master
/ Jianhai / lab5 / train.ipynb

train.ipynb @754b50d

754b50d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4f3d7435",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import os\n",
    "\n",
    "# 导入项目中的模块\n",
    "from models import SimpleMLP, DeepMLP, ResidualMLP, SimpleCNN, MediumCNN, VGGStyleNet, SimpleResNet\n",
    "from utils import (\n",
    "    load_cifar10, \n",
    "    set_seed, \n",
    "    train_model, \n",
    "    evaluate_model, \n",
    "    plot_training_history,\n",
    "    visualize_model_predictions,\n",
    "    visualize_conv_filters,\n",
    "    model_complexity\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9c8a2cb3",
   "metadata": {
    "inputHidden": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "使用设备: cpu\n",
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n",
      "训练集大小: 45000\n",
      "验证集大小: 5000\n",
      "测试集大小: 10000\n",
      "使用模型: SimpleMLP\n"
     ]
    }
   ],
   "source": [
    "# Run configuration\n",
    "model_type = 'simple_mlp'  # one of: 'simple_mlp', 'deep_mlp', 'residual_mlp', 'simple_cnn', 'medium_cnn', 'vgg_style', 'resnet'\n",
    "epochs = 20\n",
    "learning_rate = 0.001\n",
    "batch_size = 128\n",
    "use_data_augmentation = True  # CNNs usually benefit from data augmentation\n",
    "save_directory = './ck'\n",
    "visualize_filters = True  # whether to visualize convolution filters (only applies to CNN models)\n",
    "visualize_predictions = True  # whether to visualize the model's predictions\n",
    "\n",
    "# Set the random seed\n",
    "set_seed()\n",
    "\n",
    "# The Mo platform's job-submission mechanism requires switching into this folder manually.\n",
    "os.chdir(os.path.expanduser(\"~/work/Jianhai/lab5\"))\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"使用设备: {device}\")\n",
    "\n",
    "# Load the data\n",
    "train_loader, valid_loader, test_loader, classes = load_cifar10(\n",
    "    use_augmentation=use_data_augmentation,\n",
    "    batch_size=batch_size\n",
    ")\n",
    "\n",
    "# Supported model types mapped to (display name, factory). The factories are\n",
    "# only called for the selected entry, so a single model is instantiated.\n",
    "model_registry = {\n",
    "    'simple_mlp': (\"SimpleMLP\", lambda: SimpleMLP()),\n",
    "    'deep_mlp': (\"DeepMLP\", lambda: DeepMLP(dropout_rate=0.5, use_bn=True, use_dropout=True)),\n",
    "    'residual_mlp': (\"ResidualMLP\", lambda: ResidualMLP(activation='relu')),\n",
    "    'simple_cnn': (\"SimpleCNN\", lambda: SimpleCNN()),\n",
    "    'medium_cnn': (\"MediumCNN\", lambda: MediumCNN(use_bn=True)),\n",
    "    'vgg_style': (\"VGGStyleNet\", lambda: VGGStyleNet()),\n",
    "    'resnet': (\"SimpleResNet\", lambda: SimpleResNet(num_blocks=[2, 2, 2])),\n",
    "}\n",
    "\n",
    "# Reject an unknown model_type explicitly so a typo cannot silently select\n",
    "# a different architecture.\n",
    "if model_type not in model_registry:\n",
    "    raise ValueError(\n",
    "        f\"Unknown model_type {model_type!r}; expected one of {sorted(model_registry)}\"\n",
    "    )\n",
    "\n",
    "# Instantiate the selected model\n",
    "model_name, build_model = model_registry[model_type]\n",
    "model = build_model()\n",
    "\n",
    "print(f\"使用模型: {model_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "51f4362c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "分析模型复杂度:\n",
      "参数量: 1,578,506\n",
      "每批次(128个样本)推理时间: 8.96ms\n",
      "Epoch 1/20\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_246/3850660409.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     16\u001b[0m trained_model, history = train_model(\n\u001b[1;32m     17\u001b[0m     \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscheduler\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m     \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msave_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msave_directory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m )\n\u001b[1;32m     20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/work/Jianhai/lab5/utils/train_utils.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir)\u001b[0m\n\u001b[1;32m     63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     64\u001b[0m             \u001b[0;31m# 反向传播和优化\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m             \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     66\u001b[0m             \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     67\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.virtualenvs/basenv/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    243\u001b[0m                 \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    244\u001b[0m                 inputs=inputs)\n\u001b[0;32m--> 245\u001b[0;31m         \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    247\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.virtualenvs/basenv/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    145\u001b[0m     Variable._execution_engine.run_backward(\n\u001b[1;32m    146\u001b[0m         \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m         allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag\n\u001b[0m\u001b[1;32m    148\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Report the model's complexity\n",
    "print(\"\\n分析模型复杂度:\")\n",
    "model_complexity(model, device=device)\n",
    "\n",
    "# Loss function and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# Cosine-annealing learning-rate scheduler with T_max set to the epoch count\n",
    "scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)\n",
    "\n",
    "# Create the directory that is handed to train_model as save_dir, if missing\n",
    "os.makedirs(save_directory, exist_ok=True)\n",
    "\n",
    "# Train the model\n",
    "trained_model, history = train_model(\n",
    "    model,\n",
    "    train_loader,\n",
    "    valid_loader,\n",
    "    criterion,\n",
    "    optimizer,\n",
    "    scheduler,\n",
    "    num_epochs=epochs,\n",
    "    device=device,\n",
    "    save_dir=save_directory,\n",
    ")\n",
    "\n",
    "# Plot the training history\n",
    "plot_training_history(history, title=f\"{model_name} Training History\")\n",
    "\n",
    "# Evaluate the trained model on the test set\n",
    "print(\"\\n在测试集上评估模型:\")\n",
    "test_loss, test_acc = evaluate_model(trained_model, test_loader, criterion, device, classes)\n",
    "\n",
    "print(f\"{model_name} 最终测试准确率: {test_acc:.4f}\")\n",
    "\n",
    "# Layer name passed to visualize_conv_filters for each CNN model type.\n",
    "# Model types missing from this table get no filter visualization.\n",
    "conv_layer_by_model = {\n",
    "    'simple_cnn': 'conv1',\n",
    "    'medium_cnn': 'conv1',\n",
    "    'vgg_style': 'features.0',\n",
    "    'resnet': 'conv1',\n",
    "}\n",
    "if visualize_filters and model_type in conv_layer_by_model:\n",
    "    print(\"\\n可视化卷积核:\")\n",
    "    visualize_conv_filters(trained_model, conv_layer_by_model[model_type])\n",
    "\n",
    "# Visualize the model's predictions when requested\n",
    "if visualize_predictions:\n",
    "    print(\"\\n可视化模型预测:\")\n",
    "    visualize_model_predictions(trained_model, test_loader, classes, device)\n",
    "\n",
    "print(f\"\\n{model_name}的训练和评估已完成!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9379a62",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "554f08d9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}