为什么图像不能直接用 MLP

一张 224×224 RGB 图片 = 150,528 个像素。如果第一层 MLP 是 1024 维:

  • 参数 = 150,528 × 1024 ≈ 1.5 亿

爆炸了。CNN 用卷积把这件事做成 ~万级参数。

卷积:可学习的"小窗口"

import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)

参数 = 3 × 32 × 3 × 3 = 864——远少于 MLP

原图 (3, 224, 224)
       ↓ conv 3x3
特征 (32, 224, 224)

每个输出通道都是用一个 3×3 卷积核扫遍整张图——不论图多大,参数只看 kernel 大小

池化:缩小尺寸

nn.MaxPool2d(kernel_size=2)

每 2×2 取最大值——尺寸减半(224 → 112),参数为 0。

经典架构:从 LeNet 到 ResNet

LeNet(1998):CNN 的开山祖

nn.Sequential(
    nn.Conv2d(1, 6, 5), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(16*5*5, 120), nn.ReLU(),
    nn.Linear(120, 84), nn.ReLU(),
    nn.Linear(84, 10),
)

AlexNet(2012):让深度学习"出圈"

  • 加深到 8 层
  • 用 ReLU 替代 sigmoid
  • 用 Dropout
  • 用 GPU 训练

VGG(2014):纯 3×3 卷积堆叠到 16/19 层

ResNet(2015):残差连接,解锁百层网络

class ResBlock(nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.conv1 = nn.Conv2d(ch, ch, 3, padding=1)
        self.conv2 = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x):
        return x + self.conv2(torch.relu(self.conv1(x)))    # 关键:x +

x + ... 让梯度能直接传过来——百层都能训。

现成的预训练模型(绝大多数任务的正确起点)

from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.DEFAULT)
model.eval()

不要从零训练——拿预训练模型 fine-tune。

迁移学习:拿别人训好的当起点

import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.DEFAULT)

# 冻结所有参数
for p in model.parameters():
    p.requires_grad = False

# 换最后一层(自己的分类数)
model.fc = nn.Linear(model.fc.in_features, 10)

# 只训新层
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

用 1000 张图也能达到不错的效果——预训练模型让你省下 99% 的算力。

数据增强

from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

随机裁剪、翻转、颜色扰动——人为造数据让模型更鲁棒。

2026 的现实:Vision Transformer 抢了 CNN 的活

from torchvision.models import vit_b_16, ViT_B_16_Weights
model = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

ViT 在大数据上比 CNN 强——但小数据 + 边缘设备 CNN 仍然主流

下一篇讲 RNN/LSTM。