{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "cf465d7b",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import os\n",
"\n",
"# 导入项目中的模块\n",
"from models import SimpleMLP, DeepMLP, ResidualMLP, SimpleCNN, MediumCNN, VGGStyleNet, SimpleResNet\n",
"from utils import (\n",
" load_cifar10, \n",
" set_seed, \n",
" train_model, \n",
" evaluate_model, \n",
" plot_training_history,\n",
" visualize_model_predictions,\n",
" visualize_conv_filters,\n",
" model_complexity\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "068c088d",
"metadata": {
"inputHidden": false
},
"outputs": [],
"source": [
"# 设置参数\n",
"model_type = 'simple_mlp' # 可选: 'simple_mlp', 'deep_mlp', 'residual_mlp', 'simple_cnn', 'medium_cnn', 'vgg_style', 'resnet'\n",
"epochs = 20\n",
"learning_rate = 0.001\n",
"batch_size = 128\n",
"use_data_augmentation = True # CNN通常受益于数据增强\n",
"save_directory = './ck'\n",
"visualize_filters = True # 是否可视化卷积核(仅对CNN有效)\n",
"visualize_predictions = True # 是否可视化预测结果\n",
"\n",
"# 设置随机种子\n",
"set_seed()\n",
"\n",
"#因为mo平台的提交任务机制,需要手动切换到该文件夹下。\n",
"os.chdir(os.path.expanduser(\"~/work/Jianhai/lab5\"))\n",
"\n",
"# 检查是否有可用的GPU\n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"print(f\"使用设备: {device}\")\n",
"\n",
"# 加载数据\n",
"train_loader, valid_loader, test_loader, classes = load_cifar10(\n",
" use_augmentation=use_data_augmentation, \n",
" batch_size=batch_size\n",
")\n",
"\n",
"# 初始化选择的模型\n",
"if model_type == 'simple_mlp':\n",
" model = SimpleMLP()\n",
" model_name = \"SimpleMLP\"\n",
"elif model_type == 'deep_mlp':\n",
" model = DeepMLP(dropout_rate=0.5, use_bn=True, use_dropout=True)\n",
" model_name = \"DeepMLP\"\n",
"elif model_type == 'residual_mlp':\n",
" model = ResidualMLP(activation='relu')\n",
" model_name = \"ResidualMLP\"\n",
"elif model_type == 'simple_cnn':\n",
" model = SimpleCNN()\n",
" model_name = \"SimpleCNN\"\n",
"elif model_type == 'medium_cnn':\n",
" model = MediumCNN(use_bn=True)\n",
" model_name = \"MediumCNN\"\n",
"elif model_type == 'vgg_style':\n",
" model = VGGStyleNet()\n",
" model_name = \"VGGStyleNet\"\n",
"else: # resnet\n",
" model = SimpleResNet(num_blocks=[2, 2, 2])\n",
" model_name = \"SimpleResNet\"\n",
"\n",
"print(f\"使用模型: {model_name}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63cb4425",
"metadata": {},
"outputs": [],
"source": [
"# 计算模型复杂度\n",
"print(\"\\n分析模型复杂度:\")\n",
"model_complexity(model, device=device)\n",
"\n",
"# 定义损失函数和优化器\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
"\n",
"# 可以添加学习率调度器\n",
"scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)\n",
"\n",
"# 确保checkpoints目录存在\n",
"os.makedirs(save_directory, exist_ok=True)\n",
"\n",
"# 训练模型\n",
"trained_model, history = train_model(\n",
" model, train_loader, valid_loader, criterion, optimizer, scheduler,\n",
" num_epochs=epochs, device=device, save_dir=save_directory\n",
")\n",
"\n",
"# 绘制训练历史\n",
"plot_training_history(history, title=f\"{model_name} Training History\")\n",
"\n",
"# 在测试集上评估模型\n",
"print(\"\\n在测试集上评估模型:\")\n",
"test_loss, test_acc = evaluate_model(trained_model, test_loader, criterion, device, classes)\n",
"\n",
"print(f\"{model_name} 最终测试准确率: {test_acc:.4f}\")\n",
"\n",
"# 如果是CNN模型并且需要可视化卷积核\n",
"if visualize_filters and model_type in ['simple_cnn', 'medium_cnn', 'vgg_style', 'resnet']:\n",
" print(\"\\n可视化卷积核:\")\n",
" if model_type == 'simple_cnn':\n",
" visualize_conv_filters(trained_model, 'conv1')\n",
" elif model_type == 'medium_cnn':\n",
" visualize_conv_filters(trained_model, 'conv1')\n",
" elif model_type == 'vgg_style':\n",
" visualize_conv_filters(trained_model, 'features.0')\n",
" else: # resnet\n",
" visualize_conv_filters(trained_model, 'conv1')\n",
"\n",
"# 如果需要可视化模型预测\n",
"if visualize_predictions:\n",
" print(\"\\n可视化模型预测:\")\n",
" visualize_model_predictions(trained_model, test_loader, classes, device)\n",
"\n",
"print(f\"\\n{model_name}的训练和评估已完成!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92cbeb64",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}