master
rm data chenjh919 1 year, 3 days ago
10 changed file(s) with 9 addition(s) and 58 deletion(s). Raw diff Collapse all Expand all
11 "cells": [
22 {
33 "cell_type": "code",
4 "execution_count": 1,
5 "id": "4f3d7435",
4 "execution_count": null,
5 "id": "cf465d7b",
66 "metadata": {},
77 "outputs": [],
88 "source": [
2727 },
2828 {
2929 "cell_type": "code",
30 "execution_count": 4,
31 "id": "9c8a2cb3",
30 "execution_count": null,
31 "id": "068c088d",
3232 "metadata": {
3333 "inputHidden": false
3434 },
35 "outputs": [
36 {
37 "name": "stdout",
38 "output_type": "stream",
39 "text": [
40 "使用设备: cpu\n",
41 "Files already downloaded and verified\n",
42 "Files already downloaded and verified\n",
43 "训练集大小: 45000\n",
44 "验证集大小: 5000\n",
45 "测试集大小: 10000\n",
46 "使用模型: SimpleMLP\n"
47 ]
48 }
49 ],
35 "outputs": [],
5036 "source": [
5137 "# 设置参数\n",
5238 "model_type = 'simple_mlp' # 可选: 'simple_mlp', 'deep_mlp', 'residual_mlp', 'simple_cnn', 'medium_cnn', 'vgg_style', 'resnet'\n",
10288 },
10389 {
10490 "cell_type": "code",
105 "execution_count": 5,
106 "id": "51f4362c",
91 "execution_count": null,
92 "id": "63cb4425",
10793 "metadata": {},
108 "outputs": [
109 {
110 "name": "stdout",
111 "output_type": "stream",
112 "text": [
113 "\n",
114 "分析模型复杂度:\n",
115 "参数量: 1,578,506\n",
116 "每批次(128个样本)推理时间: 8.96ms\n",
117 "Epoch 1/20\n"
118 ]
119 },
120 {
121 "ename": "KeyboardInterrupt",
122 "evalue": "",
123 "output_type": "error",
124 "traceback": [
125 "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
126 "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
127 "\u001b[0;32m/tmp/ipykernel_246/3850660409.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m trained_model, history = train_model(\n\u001b[1;32m 17\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscheduler\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msave_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msave_directory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m )\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
128 "\u001b[0;32m~/work/Jianhai/lab5/utils/train_utils.py\u001b[0m in \u001b[0;36mtrain_model\u001b[0;34m(model, train_loader, valid_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;31m# 反向传播和优化\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
129 "\u001b[0;32m~/.virtualenvs/basenv/lib/python3.7/site-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m inputs=inputs)\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
130 "\u001b[0;32m~/.virtualenvs/basenv/lib/python3.7/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 145\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 146\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
131 "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
132 ]
133 }
134 ],
94 "outputs": [],
13595 "source": [
13696 "# 计算模型复杂度\n",
13797 "print(\"\\n分析模型复杂度:\")\n",
185145 {
186146 "cell_type": "code",
187147 "execution_count": null,
188 "id": "d9379a62",
189 "metadata": {},
190 "outputs": [],
191 "source": []
192 },
193 {
194 "cell_type": "code",
195 "execution_count": null,
196 "id": "554f08d9",
148 "id": "92cbeb64",
197149 "metadata": {},
198150 "outputs": [],
199151 "source": []
data/cifar-10-batches-py/batches.meta less more
Binary diff not shown
data/cifar-10-batches-py/data_batch_1 less more
Binary diff not shown
data/cifar-10-batches-py/data_batch_2 less more
Binary diff not shown
data/cifar-10-batches-py/data_batch_3 less more
Binary diff not shown
data/cifar-10-batches-py/data_batch_4 less more
Binary diff not shown
data/cifar-10-batches-py/data_batch_5 less more
Binary diff not shown
+0
-1
data/cifar-10-batches-py/readme.html less more
0 <meta HTTP-EQUIV="REFRESH" content="0; url=http://www.cs.toronto.edu/~kriz/cifar.html">
data/cifar-10-batches-py/test_batch less more
Binary diff not shown
data/cifar-10-python.tar.gz less more
Binary diff not shown