PyTorch 入门第01个项目:MNIST 手写数字识别完整流程

一、项目简介

前面已经完成了 Python 开发环境配置和 PyTorch GPU 版安装。本文开始运行第一个 PyTorch 入门项目:MNIST 手写数字识别

MNIST 是深度学习入门中非常经典的数据集,包含 0-9 共 10 类手写数字图片。每张图片大小为 28×28,是灰度图像。

本文使用一个简单的全连接神经网络完成手写数字分类任务。重点不是追求最高准确率,而是理解 PyTorch 项目的基本流程:

数据加载 → 模型定义 → 前向传播 → 损失计算 → 反向传播 → 参数更新 → 准确率测试 → 预测可视化

二、运行环境

项目 配置
操作系统 Windows 11
Python 版本 Python 3.10
深度学习框架 PyTorch
开发工具 PyCharm
环境管理 Anaconda / conda
数据集 MNIST

如果使用之前创建的 PyTorch 环境,可以先激活环境:

三、MNIST 数据集说明

MNIST 图片大小为:

28 × 28

由于本文使用的是全连接神经网络,所以需要把每张图片从二维图像展平成一维向量:

28 × 28 = 784

代码中对应操作是:

x.view(-1, 28 * 28)

其中,-1 表示自动推断 batch size。

四、完整代码

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt


# 定义神经网络模型
class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()

        # MNIST 图片为 28×28,展平后是 784 维
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)

        # 输出 10 个类别,对应数字 0-9
        self.fc4 = torch.nn.Linear(64, 10)

    def forward(self, x):
        # 全连接层 + ReLU 激活函数
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))

        # 输出每个类别的对数概率
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x


# 加载 MNIST 数据集
def get_data_loader(is_train):
    to_tensor = transforms.Compose([
        transforms.ToTensor()
    ])

    data_set = MNIST(
        root="",
        train=is_train,
        transform=to_tensor,
        download=True
    )

    return DataLoader(
        data_set,
        batch_size=15,
        shuffle=True
    )


# 评估模型准确率
def evaluate(test_data, net):
    n_correct = 0
    n_total = 0

    # 测试阶段不需要计算梯度
    with torch.no_grad():
        for (x, y) in test_data:
            # 将图片展平成 784 维
            outputs = net.forward(x.view(-1, 28 * 28))

            for i, output in enumerate(outputs):
                # 取预测概率最大的类别
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1

    return n_correct / n_total


def main():
    # 加载训练集和测试集
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)

    # 创建模型
    net = Net()

    # 训练前测试初始准确率
    print("initial accuracy:", evaluate(test_data, net))

    # Adam 优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    # 训练 2 个 epoch
    for epoch in range(2):
        for (x, y) in train_data:
            # 清空梯度
            net.zero_grad()

            # 前向传播
            output = net.forward(x.view(-1, 28 * 28))

            # 计算损失
            loss = torch.nn.functional.nll_loss(output, y)

            # 反向传播
            loss.backward()

            # 更新参数
            optimizer.step()

        print("epoch", epoch, "accuracy:", evaluate(test_data, net))

    # 显示部分预测结果
    for (n, (x, _)) in enumerate(test_data):
        if n > 3:
            break

        predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))

        plt.figure(n)
        plt.imshow(x[0].view(28, 28), cmap="gray")
        plt.title("prediction: " + str(int(predict)))

    plt.show()


if __name__ == "__main__":
    main()

五、关键代码解释

1. 为什么输入是 28×28?

MNIST 每张图片大小是 28×28。由于本文使用全连接网络,所以需要将图片展平成:

784 维向量

对应代码:

x.view(-1, 28 * 28)

2. 为什么最后输出是 10?

因为 MNIST 是 0-9 的数字分类任务,一共有 10 个类别:

self.fc4 = torch.nn.Linear(64, 10)

3. log_softmaxnll_loss 的关系

模型最后使用:

torch.nn.functional.log_softmax()

损失函数使用:

torch.nn.functional.nll_loss()

这两个通常搭配使用,适合分类任务。

4. 训练过程包括哪些步骤?

核心训练流程如下:

net.zero_grad()
output = net.forward(x.view(-1, 28 * 28))
loss = torch.nn.functional.nll_loss(output, y)
loss.backward()
optimizer.step()

对应含义:

代码 含义
net.zero_grad() 清空上一轮梯度
net.forward() 前向传播
nll_loss() 计算损失
loss.backward() 反向传播
optimizer.step() 更新模型参数

六、运行结果

运行后,控制台会输出训练前和训练后的准确率:

初始准确率较低,是因为模型参数是随机初始化的。经过训练后,准确率明显提高,说明模型已经从训练集中学习到了手写数字的基本特征。

训练完成后,程序会显示几张测试集图片,并在标题中显示模型预测结果:

七、常见问题与避坑

1. PyCharm 提示找不到 torch 怎么办?

大概率是解释器选错了。需要选择安装 PyTorch 的 conda 环境,例如:

pytorch-gpu 环境中的 python.exe

不要选择 base 环境或系统自带 Python。

2. MNIST 下载失败怎么办?

可能是网络问题,可以多运行几次,或者更换网络环境。

3. 这段代码有没有使用 GPU?

这段代码默认运行在 CPU 上。对于 MNIST 这种小项目,CPU 也可以正常运行。

如果想使用 GPU,需要将模型和数据都移动到 CUDA 设备上,例如:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = net.to(device)
x = x.to(device)
y = y.to(device)

不过作为第一个入门项目,建议先把完整训练流程跑通。

4. 为什么不用 CNN?

本文使用全连接网络,是为了让入门者更容易理解完整训练流程。对于图像任务,CNN 通常效果更好,后续可以继续使用卷积神经网络改进 MNIST 分类效果。

八、总结

本文完成了第一个 PyTorch 入门项目:MNIST 手写数字识别。

通过这个项目,主要理解了:

  1. 如何加载 MNIST 数据集;
  2. 如何定义简单神经网络;
  3. 如何进行前向传播;
  4. 如何计算损失;
  5. 如何反向传播并更新参数;
  6. 如何评估准确率;
  7. 如何显示预测结果。

这个项目虽然简单,但已经包含了深度学习训练的完整流程。后续可以继续学习 CNN、CIFAR-10 图像分类、迁移学习和 YOLO 目标检测项目。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐