| 1 | 1 |
"cells": [
|
| 2 | 2 |
{
|
| 3 | 3 |
"cell_type": "code",
|
| 4 | |
"execution_count": 1,
|
| 5 | |
"id": "4f3d7435",
|
|
4 |
"execution_count": null,
|
|
5 |
"id": "cf465d7b",
|
| 6 | 6 |
"metadata": {},
|
| 7 | 7 |
"outputs": [],
|
| 8 | 8 |
"source": [
|
|
| 27 | 27 |
},
|
| 28 | 28 |
{
|
| 29 | 29 |
"cell_type": "code",
|
| 30 | |
"execution_count": 4,
|
| 31 | |
"id": "9c8a2cb3",
|
|
30 |
"execution_count": null,
|
|
31 |
"id": "068c088d",
|
| 32 | 32 |
"metadata": {
|
| 33 | 33 |
"inputHidden": false
|
| 34 | 34 |
},
|
| 35 | |
"outputs": [
|
| 36 | |
{
|
| 37 | |
"name": "stdout",
|
| 38 | |
"output_type": "stream",
|
| 39 | |
"text": [
|
| 40 | |
"使用设备: cpu\n",
|
| 41 | |
"Files already downloaded and verified\n",
|
| 42 | |
"Files already downloaded and verified\n",
|
| 43 | |
"训练集大小: 45000\n",
|
| 44 | |
"验证集大小: 5000\n",
|
| 45 | |
"测试集大小: 10000\n",
|
| 46 | |
"使用模型: SimpleMLP\n"
|
| 47 | |
]
|
| 48 | |
}
|
| 49 | |
],
|
|
35 |
"outputs": [],
|
| 50 | 36 |
"source": [
|
| 51 | 37 |
"# 设置参数\n",
|
| 52 | 38 |
"model_type = 'simple_mlp' # 可选: 'simple_mlp', 'deep_mlp', 'residual_mlp', 'simple_cnn', 'medium_cnn', 'vgg_style', 'resnet'\n",
|
|
| 102 | 88 |
},
|
| 103 | 89 |
{
|
| 104 | 90 |
"cell_type": "code",
|
| 105 | |
"execution_count": 5,
|
| 106 | |
"id": "51f4362c",
|
|
91 |
"execution_count": null,
|
|
92 |
"id": "63cb4425",
|
| 107 | 93 |
"metadata": {},
|
| 108 | |
"outputs": [
|
| 109 | |
{
|
| 110 | |
"name": "stdout",
|
| 111 | |
"output_type": "stream",
|
| 112 | |
"text": [
|
| 113 | |
"\n",
|
| 114 | |
"分析模型复杂度:\n",
|
| 115 | |
"参数量: 1,578,506\n",
|
| 116 | |
"每批次(128个样本)推理时间: 8.96ms\n",
|
| 117 | |
"Epoch 1/20\n"
|
| 118 | |
]
|
| 119 | |
},
|
| 120 | |
{
|
| 121 | |
"ename": "KeyboardInterrupt",
|
| 122 | |
"evalue": "",
|
| 123 | |
"output_type": "error",
|
| 124 | |
"traceback": [
|
| 125 | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 126 | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 127 | |
"\u001b[0;32m/tmp/ipykernel_246/3850660409.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m trained_model, history = train_model(\n\u001b[1;32m 17\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscheduler\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msave_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msave_directory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m )\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 128 | |
"\u001b[0;32m~/work/Jianhai/lab5/utils/train_utils.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;31m# 反向传播和优化\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 129 | |
"\u001b[0;32m~/.virtualenvs/basenv/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m inputs=inputs)\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 130 | |
"\u001b[0;32m~/.virtualenvs/basenv/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 145\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 146\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 131 | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
| 132 | |
]
|
| 133 | |
}
|
| 134 | |
],
|
|
94 |
"outputs": [],
|
| 135 | 95 |
"source": [
|
| 136 | 96 |
"# 计算模型复杂度\n",
|
| 137 | 97 |
"print(\"\\n分析模型复杂度:\")\n",
|
|
| 185 | 145 |
{
|
| 186 | 146 |
"cell_type": "code",
|
| 187 | 147 |
"execution_count": null,
|
| 188 | |
"id": "d9379a62",
|
| 189 | |
"metadata": {},
|
| 190 | |
"outputs": [],
|
| 191 | |
"source": []
|
| 192 | |
},
|
| 193 | |
{
|
| 194 | |
"cell_type": "code",
|
| 195 | |
"execution_count": null,
|
| 196 | |
"id": "554f08d9",
|
|
148 |
"id": "92cbeb64",
|
| 197 | 149 |
"metadata": {},
|
| 198 | 150 |
"outputs": [],
|
| 199 | 151 |
"source": []
|