master
/ Jianhai / lab5 / models / mlp.py

mlp.py @754b50d raw · history · blame

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleMLP(nn.Module):
    """MLP with a single hidden layer.

    The input is flattened (all dimensions after the batch dimension),
    passed through one ``Linear -> ReLU`` hidden layer and then projected
    to ``output_dim`` raw scores (no softmax is applied).

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output scores per sample.
    """

    def __init__(self, input_dim=3*32*32, hidden_dim=512, output_dim=10):
        super().__init__()
        # Sub-modules are created in a fixed order so that parameter
        # initialisation and state_dict keys stay stable.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the raw output scores for the batch ``x``."""
        flat = self.flatten(x)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)


class DeepMLP(nn.Module):
    """Deep MLP with three hidden layers, optional batch norm and dropout.

    The hidden widths are fixed at 1024 -> 512 -> 256. Every hidden stage
    applies ``Linear -> BatchNorm1d -> ReLU -> Dropout``; the batch-norm and
    dropout steps are replaced by ``nn.Identity`` when disabled, so the
    forward pass has the same shape in every configuration.

    Args:
        input_dim: Number of features per sample after flattening.
        dropout_rate: Drop probability used by the shared dropout layer.
            Ignored when ``use_dropout`` is False.
        use_bn: Whether to apply batch normalisation after each hidden
            linear layer.
        use_dropout: Whether to apply dropout after each hidden activation.
        output_dim: Number of output scores per sample. Defaults to 10,
            which was previously hard-coded.
    """

    def __init__(self, input_dim=3*32*32, dropout_rate=0.5, use_bn=True,
                 use_dropout=True, output_dim=10):
        super(DeepMLP, self).__init__()
        self.flatten = nn.Flatten()
        self.use_bn = use_bn
        self.use_dropout = use_dropout

        # Hidden layer 1
        self.fc1 = nn.Linear(input_dim, 1024)
        self.bn1 = nn.BatchNorm1d(1024) if use_bn else nn.Identity()

        # Hidden layer 2
        self.fc2 = nn.Linear(1024, 512)
        self.bn2 = nn.BatchNorm1d(512) if use_bn else nn.Identity()

        # Hidden layer 3
        self.fc3 = nn.Linear(512, 256)
        self.bn3 = nn.BatchNorm1d(256) if use_bn else nn.Identity()

        # Output layer
        self.fc4 = nn.Linear(256, output_dim)

        # Activation and dropout (both stateless, shared by all stages)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate) if use_dropout else nn.Identity()

    def forward(self, x):
        """Return the raw output scores for the batch ``x``."""
        x = self.flatten(x)

        # The three hidden stages are structurally identical.
        hidden_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for fc, bn in hidden_stages:
            x = self.dropout(self.relu(bn(fc(x))))

        # Output layer: no normalisation, activation or dropout.
        return self.fc4(x)


class ResidualBlock(nn.Module):
    """Residual block for an MLP.

    Main branch: ``Linear -> BatchNorm1d -> activation -> Dropout ->
    Linear -> BatchNorm1d``. The shortcut branch is added to the main
    branch and the activation is applied once more to the sum.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        activation: Callable applied after the first batch norm and after
            the residual sum.
        dropout_rate: Drop probability of the dropout layer.
    """

    def __init__(self, input_dim, output_dim, activation, dropout_rate=0.5):
        super().__init__()

        # Main branch.
        self.linear1 = nn.Linear(input_dim, output_dim)
        self.bn1 = nn.BatchNorm1d(output_dim)
        self.linear2 = nn.Linear(output_dim, output_dim)
        self.bn2 = nn.BatchNorm1d(output_dim)

        self.activation = activation
        self.dropout = nn.Dropout(dropout_rate)

        # Shortcut branch: pass-through when the widths match, otherwise a
        # linear projection (plus batch norm) to the output width.
        if input_dim == output_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Linear(input_dim, output_dim),
                nn.BatchNorm1d(output_dim)
            )

    def forward(self, x):
        """Return ``activation(main_branch(x) + shortcut(x))``."""
        hidden = self.dropout(self.activation(self.bn1(self.linear1(x))))
        branch = self.bn2(self.linear2(hidden))

        # Accumulate the shortcut into the main-branch tensor in place.
        branch += self.shortcut(x)
        return self.activation(branch)


class ResidualMLP(nn.Module):
    """MLP with residual connections.

    Architecture: flatten -> input stem (``Linear -> BatchNorm1d ->
    activation -> Dropout``) mapping to ``hidden_dims[0]`` -> one
    ``ResidualBlock`` per consecutive pair of widths in ``hidden_dims`` ->
    final ``Linear`` to ``output_dim``.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dims: Sequence of hidden widths. ``None`` (the default)
            selects ``[1024, 1024, 1024, 512, 512, 512]``. Must contain at
            least one width.
        output_dim: Number of output scores per sample.
        dropout_rate: Drop probability used in the stem and in every
            residual block.
        activation: One of ``'relu'``, ``'leaky_relu'``, ``'gelu'`` or
            ``'swish'``.

    Raises:
        ValueError: If ``activation`` is not a supported name or
            ``hidden_dims`` is empty.
    """

    def __init__(self, input_dim=3*32*32, hidden_dims=None, output_dim=10,
                 dropout_rate=0.5, activation='relu'):
        super(ResidualMLP, self).__init__()
        self.flatten = nn.Flatten()

        # A None sentinel avoids sharing one mutable default list between
        # instances; copying also lets callers pass tuples.
        if hidden_dims is None:
            hidden_dims = [1024, 1024, 1024, 512, 512, 512]
        else:
            hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")

        # Select the activation function.
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'leaky_relu':
            self.activation = nn.LeakyReLU(0.1)
        elif activation == 'gelu':
            self.activation = nn.GELU()
        elif activation == 'swish':
            # Swish is x * sigmoid(x), which nn.SiLU implements. It must be
            # an nn.Module (not a lambda) because it is placed inside the
            # nn.Sequential below, which only accepts Module instances.
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"不支持的激活函数: {activation}")

        # Input stem.
        layers = [
            nn.Linear(input_dim, hidden_dims[0]),
            nn.BatchNorm1d(hidden_dims[0]),
            self.activation,
            nn.Dropout(dropout_rate),
        ]

        # Hidden layers with residual connections, one block per
        # consecutive pair of widths. The activation module is stateless
        # and shared by every block.
        for in_dim, out_dim in zip(hidden_dims, hidden_dims[1:]):
            layers.append(ResidualBlock(in_dim, out_dim, self.activation, dropout_rate))

        # Output layer.
        layers.append(nn.Linear(hidden_dims[-1], output_dim))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw output scores for the batch ``x``."""
        x = self.flatten(x)
        x = self.layers(x)
        return x