任务
MNIST:识别 28×28 灰度手写数字(0–9)。深度学习的"Hello World"。
完整代码
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# ---------- 1. Data ----------
# ToTensor converts a PIL image into a float tensor scaled to [0, 1];
# Normalize then standardises it with the MNIST mean (0.1307) and
# standard deviation (0.3081) -- note: std, not variance.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Both splits live under ./mnist and are downloaded on first use.
train_set = datasets.MNIST(
    "./mnist", train=True, download=True, transform=transform
)
test_set = datasets.MNIST(
    "./mnist", train=False, download=True, transform=transform
)

# Shuffle only the training data; evaluation order does not matter, so the
# test loader keeps the dataset order and uses a larger batch.
train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_set, batch_size=256, shuffle=False)
# ---------- 2. Model (simple MLP) ----------
class MLP(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10.

    The input is flattened, passed through two ReLU-activated hidden
    layers, and projected to 10 outputs. No softmax is applied, so the
    outputs are raw logits.
    """

    def __init__(self):
        super().__init__()
        # Keep this exact layer order: the positions inside the Sequential
        # determine the state_dict keys (net.1, net.3, net.5).
        layers = [
            nn.Flatten(),  # 28x28 -> 784
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 10 class logits for each sample in the batch ``x``."""
        logits = self.net(x)
        return logits
# Build the network and move its parameters to the selected device.
model = MLP()
model = model.to(device)

# Cross-entropy consumes the raw logits together with integer class labels.
loss_fn = nn.CrossEntropyLoss()

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# ---------- 3. Training ----------
for epoch in range(5):
    # --- one optimisation pass over the training set ---
    model.train()
    for images, targets in train_loader:
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()  # clear gradients left from the previous step
        logits = model(images)
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

    # --- evaluate on the test set after every epoch ---
    model.eval()
    correct = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, targets in test_loader:
            images = images.to(device)
            targets = targets.to(device)
            predictions = model(images).argmax(1)
            correct += (predictions == targets).sum().item()

    acc = correct / len(test_set)
    print(f"Epoch {epoch+1}: 测试集准确率 {acc:.4f}")
# ---------- 4. Save ----------
# Persist only the parameters (the state dict), not the module object itself.
torch.save(obj=model.state_dict(), f="mnist_mlp.pth")
跑下来 5 个 epoch 大约能到 97% 准确率。
推理(用训练好的模型)
# Rebuild the architecture, then restore the trained weights.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
model = MLP().to(device)
state_dict = torch.load("mnist_mlp.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()  # switch the module to evaluation mode

# Predict a single image taken from the test set.
img, label = test_set[0]
with torch.no_grad():
    logits = model(img.unsqueeze(0).to(device))  # add the batch dimension
    probs = torch.softmax(logits, dim=-1)
    # Take the argmax over the class dimension; the batch holds one sample.
    pred = probs.argmax(dim=-1).item()
print(f"真实: {label}, 预测: {pred}")
用 CNN 替代 MLP(轻松破 99%)
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2), # 28 → 14
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2), # 14 → 7
nn.Flatten(),
nn.Linear(64 * 7 * 7, 128),
nn.ReLU(),
nn.Linear(128, 10),
)
def forward(self, x):
return self.net(x)
把 `model = MLP().to(device)` 换成 `model = CNN().to(device)`,其余训练代码保持不变,重新训练——5 个 epoch 大约能到 99.2%。
你刚刚学到了什么
- 用 `torchvision.datasets` 下载经典数据集
- 用 `transforms.Compose` 做数据预处理
- 用 `DataLoader` 批量加载
- 写 `nn.Module` + `Sequential` 定义模型
- 标准训练循环(zero_grad / backward / step)
- 模型保存与加载
- MLP vs CNN 性能对比
整个流程在 PyTorch 项目里千篇一律——换数据、换模型、换损失函数,模板不变。
下一篇深入讲 CNN 原理。