from loadData import DataLayer
from model import VGG16, Network
import torch
import tensorboardX as tb
import os
# Directory where train() saves checkpoints and looks for one to resume from.
weight_output = "./weight_output"
# Create the directory if it does not exist yet. exist_ok=True replaces the
# separate exists() check, so there is no window between checking and creating.
os.makedirs(weight_output, exist_ok=True)
# import pydevd_pycharm
# pydevd_pycharm.settrace('10.214.160.245', port=8011, stdoutToServer=True, stderrToServer=True)
def _build_param_groups(net, lr):
    """Return per-parameter SGD option groups for the trainable parameters of *net*.

    Parameters whose name contains ``'bias'`` get twice the base learning rate
    and a weight decay of 0.0001. Every other parameter uses *lr* and its own
    ``weight_decay`` attribute when it has one (0.0001 otherwise). Parameters
    with ``requires_grad`` false are left out.
    """
    groups = []
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        if 'bias' in name:
            groups.append({
                'params': [param],
                'lr': lr * 2,
                'weight_decay': 0.0001,
            })
        else:
            groups.append({
                'params': [param],
                'lr': lr,
                'weight_decay': getattr(param, 'weight_decay', 0.0001),
            })
    return groups


def _latest_checkpoint(directory):
    """Return the path of the last entry of *directory* in sorted name order, or None if it is empty."""
    # os.listdir() returns entries in arbitrary order; sorting makes the pick
    # deterministic. Checkpoint names written by train() are zero-padded, so
    # the lexicographically last one has the highest iteration number.
    entries = sorted(os.listdir(directory))
    if not entries:
        return None
    return os.path.join(directory, entries[-1])


def train():
    """Run the training loop, logging summaries and saving checkpoints.

    Builds a ``VGG16`` network, an SGD optimizer (momentum 0.9) and two
    tensorboard writers, then loads weights: the last checkpoint found in
    ``weight_output`` if there is one, otherwise the pretrained weights in
    ``./imagenet_weights/vgg16.pth``. It then runs ``max_iter`` training
    steps. Every 100 steps summaries for one training batch and one validation
    batch are written, every 10 steps the losses are printed, and every 1000
    steps the state dict is saved under ``weight_output``.

    The tensorboard writers are closed even if training raises.
    """
    # Imported locally so the function does not rely on a new file-level import.
    from contextlib import closing

    data_layer = DataLayer("./dataset/train")
    data_layer_val = DataLayer("./dataset/val")

    net = VGG16()
    # Construct the computation graph.
    # NOTE(review): the first argument (2) is presumably the number of
    # classes -- confirm against model.py.
    net.create_architecture(2, tag='default', anchor_scales=[8, 16, 32], anchor_ratios=[0.5, 1, 2])

    # Base learning rate; bias parameters use twice this value.
    lr = 0.0001
    optimizer = torch.optim.SGD(_build_param_groups(net, lr), momentum=0.9)

    # Write the train and validation information to tensorboard. closing()
    # guarantees both writers are closed on normal exit and on error.
    with closing(tb.writer.FileWriter("./tensorboard/train")) as writer, \
            closing(tb.writer.FileWriter("./tensorboard/val")) as valwriter:
        print("model create successfully!")

        # Load model weights: resume from a saved checkpoint when one exists,
        # otherwise start from the pretrained weights.
        checkpoint = _latest_checkpoint(weight_output)
        if checkpoint is None:
            net.load_pretrained_cnn(torch.load("./imagenet_weights/vgg16.pth"))
        else:
            # Load onto the CPU first; the network is moved to its device below.
            net.load_state_dict(torch.load(checkpoint, map_location='cpu'))
        print("model params load successfully!")

        net.train()
        net.to(net._device)

        max_iter = 5001
        # NOTE: the step counter always starts at 0, even when resuming from
        # a checkpoint, so checkpoint files are overwritten from step 0 on.
        for step in range(max_iter):
            # Fetch exactly one training batch per step.
            blobs = data_layer.forward()
            if step % 100 == 0:
                # Compute the graph with summary.
                (rpn_loss_cls, rpn_loss_box, loss_cls, loss_box,
                 total_loss, summary) = net.train_step_with_summary(blobs, optimizer)
                for _sum in summary:
                    writer.add_summary(_sum, float(step))
                # Also check the summary on the validation set.
                blobs_val = data_layer_val.forward()
                summary_val = net.get_summary(blobs_val)
                for _sum in summary_val:
                    valwriter.add_summary(_sum, float(step))
            else:
                # Compute the graph without summary, reusing the batch fetched
                # above instead of fetching (and wasting) a second one.
                rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, total_loss = \
                    net.train_step(blobs, optimizer)

            if step % 10 == 0:
                # Display the training information of the last batch.
                print('iter: %d / %d, total loss: %.6f\n >>> rpn_loss_cls: %.6f\n '
                      '>>> rpn_loss_box: %.6f\n >>> loss_cls: %.6f\n >>> loss_box: %.6f\n >>> lr: %f' %
                      (step, max_iter, total_loss, rpn_loss_cls, rpn_loss_box, loss_cls, loss_box, lr))

            if step % 1000 == 0:
                # The zero-padded step number keeps checkpoint names sortable.
                torch.save(net.state_dict(),
                           os.path.join(weight_output, "params" + '{0:0>9}'.format(step) + ".pkl"))
if __name__ == "__main__":
    # Restrict CUDA to device 0 before training starts, then run training.
    os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
    train()