# File: train.py (branch: master)

from loadData import DataLayer
from model import VGG16, Network
import torch
import tensorboardX as tb
import os

# Directory where training checkpoints (params<iteration>.pkl) are saved and
# from which training resumes.
weight_output = "./weight_output"
# Create the directory if it does not exist; exist_ok avoids the race between
# checking os.path.exists() and calling makedirs().
os.makedirs(weight_output, exist_ok=True)

# Optional: attach the PyCharm remote debugger (uncomment and adjust host/port to use).
# import pydevd_pycharm
# pydevd_pycharm.settrace('10.214.160.245', port=8011, stdoutToServer=True, stderrToServer=True)

def _latest_checkpoint(directory):
    """Return the path of the newest ``.pkl`` checkpoint in *directory*, or None.

    ``os.listdir`` returns entries in arbitrary order, so the names are sorted
    explicitly. ``train`` saves checkpoints as ``params`` + the iteration
    zero-padded to 9 digits + ``.pkl``, so lexicographic order equals
    iteration order. Files without a ``.pkl`` suffix are ignored rather than
    being handed to ``torch.load``.
    """
    checkpoints = sorted(
        name for name in os.listdir(directory) if name.endswith(".pkl")
    )
    return os.path.join(directory, checkpoints[-1]) if checkpoints else None


def _build_param_groups(net, lr, weight_decay=0.0001):
    """Build per-parameter SGD groups for every trainable parameter of *net*.

    Parameters whose name contains ``'bias'`` get twice the base learning
    rate and a fixed *weight_decay*. All other parameters use the base
    learning rate and their own ``weight_decay`` attribute when present,
    falling back to *weight_decay*.

    Returns:
        A list of param-group dicts suitable for ``torch.optim.SGD``.
    """
    groups = []
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue  # parameters that are not trained are not given to the optimizer
        if 'bias' in name:
            groups.append({'params': [param], 'lr': lr * 2, 'weight_decay': weight_decay})
        else:
            groups.append({
                'params': [param],
                'lr': lr,
                'weight_decay': getattr(param, 'weight_decay', weight_decay),
            })
    return groups


def train():
    """Train the VGG16-based network, logging to TensorBoard and checkpointing.

    Training batches are read from ``./dataset/train`` and validation batches
    from ``./dataset/val``. Weights are initialised from the newest checkpoint
    in ``weight_output`` if one exists, otherwise from
    ``./imagenet_weights/vgg16.pth``. Every 100 iterations train and
    validation summaries are written under ``./tensorboard``; every 10 the
    losses are printed; every 1000 a checkpoint is saved to ``weight_output``.
    """
    data_layer = DataLayer("./dataset/train")
    data_layer_val = DataLayer("./dataset/val")

    net = VGG16()
    # Construct the computation graph. NOTE(review): the leading 2 is presumably
    # the number of classes (incl. background) — confirm against model.Network.
    net.create_architecture(2, tag='default', anchor_scales=[8, 16, 32], anchor_ratios=[0.5, 1, 2])

    # SGD with momentum; biases train at 2x the base learning rate.
    lr = 0.0001
    optimizer = torch.optim.SGD(_build_param_groups(net, lr), momentum=0.9)
    print("model create successfully!")

    # Resume from the newest checkpoint, or start from ImageNet weights.
    # Tensors are loaded onto the CPU first; net.to(net._device) below moves them,
    # so the script also starts on machines lacking the device the file was saved from.
    # NOTE(review): resuming restarts the iteration counter at 0, so new checkpoints
    # are named from params000000000.pkl again and may sort before older ones.
    checkpoint = _latest_checkpoint(weight_output)
    if checkpoint is None:
        net.load_pretrained_cnn(torch.load("./imagenet_weights/vgg16.pth", map_location='cpu'))
    else:
        net.load_state_dict(torch.load(checkpoint, map_location='cpu'))

    print("model params load successfully!")
    net.train()
    net.to(net._device)

    # Write the train and validation information to tensorboard.
    writer = tb.writer.FileWriter("./tensorboard/train")
    valwriter = tb.writer.FileWriter("./tensorboard/val")
    max_iter = 5001
    try:
        for step in range(max_iter):
            # Exactly one training batch per iteration; fetching a second one in
            # the else-branch would silently skip every other batch.
            blobs = data_layer.forward()
            if step % 100 == 0:
                # Train step that also returns TensorBoard summaries.
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss, summary = \
                    net.train_step_with_summary(blobs, optimizer)
                for _sum in summary:
                    writer.add_summary(_sum, float(step))
                # Also check the summary on the validation set.
                blobs_val = data_layer_val.forward()
                for _sum in net.get_summary(blobs_val):
                    valwriter.add_summary(_sum, float(step))
            else:
                # Compute the graph without summary.
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                    net.train_step(blobs, optimizer)

            if step % 10 == 0:
                # Display the last image training information.
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
                      (step, max_iter, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr))

            if step % 1000 == 0:
                # Zero-padded so that lexicographic order matches iteration order.
                torch.save(net.state_dict(),
                           os.path.join(weight_output, "params{0:0>9}.pkl".format(step)))
    finally:
        # Flush and close the event files even if training fails part-way.
        writer.close()
        valwriter.close()

if __name__ == '__main__':
    # Default to GPU 0, but respect a CUDA_VISIBLE_DEVICES the caller already
    # exported. This only takes effect if CUDA has not been initialised yet.
    # NOTE(review): if the imported `model` module queries CUDA at import time,
    # this runs too late — confirm.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')
    train()