任务

MNIST:识别 28×28 灰度手写数字(0–9),堪称深度学习的"Hello World"。

完整代码

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Prefer the CUDA GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------- 1. Data ----------
# Preprocessing applied to every sample: convert the PIL image to a tensor in
# [0, 1], then normalize with the MNIST mean (0.1307) and standard deviation
# (0.3081). Note the second value is a standard deviation, not a variance.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Train and test splits share the same root directory and the same transform.
train_set = datasets.MNIST(root="./mnist", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="./mnist", train=False, download=True, transform=transform)

# Training batches are shuffled; test batches are larger and keep the
# loader's default ordering.
train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=256)

# ---------- 2. 模型(简单 MLP)----------
class MLP(nn.Module):
    """Small fully connected classifier producing 10 class scores per sample.

    Each input is flattened to 784 values (e.g. a 1x28x28 image) and passed
    through 784 -> 128 -> 64 -> 10 linear layers with ReLU after the first
    two. The output is unnormalized scores; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order so the indices inside
        # ``self.net`` (and therefore the state_dict keys) stay stable.
        layers = [
            nn.Flatten(),  # (N, ...) -> (N, 784) for 28x28 inputs
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 10 class scores for each sample in batch ``x``."""
        scores = self.net(x)
        return scores


# Build the network and move it to the selected device before creating the
# optimizer, so the optimizer holds the parameters that will be trained.
model = MLP()
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# ---------- 3. Training ----------
for epoch in range(5):
    # Training pass: one optimizer step per mini-batch.
    model.train()
    for images, targets in train_loader:
        images = images.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()  # clear gradients left from the previous step
        outputs = model(images)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

    # Evaluation pass: count correct top-1 predictions over the test set,
    # with the model in eval mode and gradient tracking turned off.
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, targets in test_loader:
            images = images.to(device)
            targets = targets.to(device)
            predicted = model(images).argmax(1)
            hits = predicted == targets
            correct += hits.sum().item()
    acc = correct / len(test_set)
    print(f"Epoch {epoch+1}: 测试集准确率 {acc:.4f}")

# ---------- 4. Save ----------
# Write only the model's state_dict to disk, not the module object itself.
torch.save(obj=model.state_dict(), f="mnist_mlp.pth")

跑完 5 个 epoch,测试集准确率大约能到 97%。

推理(用训练好的模型)

# Rebuild the architecture, then restore the trained weights into it.
model = MLP().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written during a CUDA run still loads on a CPU-only machine.
state_dict = torch.load("mnist_mlp.pth", map_location=device)
model.load_state_dict(state_dict)
# Switch to evaluation mode before running inference.
model.eval()

# Predict the class of a single test image.
img, label = test_set[0]
batch = img.unsqueeze(0).to(device)  # add a leading batch dimension
with torch.no_grad():
    logits = model(batch)
    probs = logits.softmax(dim=-1)
    pred = probs.argmax().item()
print(f"真实: {label}, 预测: {pred}")

用 CNN 替代 MLP(轻松破 99%)

class CNN(nn.Module):
    """Two-stage convolutional classifier producing 10 class scores per sample.

    Each stage is a 3x3 convolution (padding 1), ReLU, and a 2x2 max-pool.
    For 1x28x28 inputs the feature map goes 28 -> 14 -> 7, is flattened to
    64 * 7 * 7 values, and is classified by a 128-unit hidden layer followed
    by a 10-way linear output. The output is unnormalized scores.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order so the indices inside
        # ``self.net`` (and therefore the state_dict keys) stay stable.
        layers = [
            # Stage 1: 1 -> 32 channels, then halve the spatial size.
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 28 -> 14
            # Stage 2: 32 -> 64 channels, then halve the spatial size again.
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # 14 -> 7
            # Classifier head on the flattened feature map.
            nn.Flatten(),
            nn.Linear(in_features=64 * 7 * 7, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 10 class scores for each sample in batch ``x``."""
        scores = self.net(x)
        return scores

把 `model = MLP().to(device)` 改成 `model = CNN().to(device)`(优化器也要用新模型的参数重新创建),然后重新训练——5 个 epoch 后测试集准确率能到约 99.2%。

你刚刚学到了什么

  • torchvision.datasets 下载经典数据集
  • transforms.Compose 做数据预处理
  • DataLoader 批量加载
  • nn.Module + Sequential 定义模型
  • 标准训练循环(zero_grad / backward / step)
  • 模型保存与加载
  • MLP vs CNN 性能对比

整个流程在 PyTorch 项目里大同小异——换数据、换模型、换损失函数,模板不变。

下一篇深入讲 CNN 原理。