前言

本文记录了我从零开始使用PyTorch搭建LeNet-5模型,对8类蔬菜水果图像分类的完整过程。重点分享训练中遇到的类别混淆问题(苹果 vs 草莓,洋葱 vs 橙子)以及如何通过数据增强辅助分类头将准确率从83.75%提升到96.25%。代码已开源,适合新手参考。

一、项目背景

最近接到一个考核任务:基于PyTorch实现LeNet-5模型,适配自定义数据集(类别数>5),完成训练、测试及可视化。我选择了8类常见的蔬菜水果:

  • 类别:apple, banana, cucumber, nut, onion, orange, strawberry, tomato

  • 数据量:每类训练50张,验证10张,总计480张(所有图片自己收集并清洗)

  • 要求:不能使用MNIST、CIFAR等标准数据集

我的起点:有YOLO使用经验(目标检测),但PyTorch分类任务是首次接触。

二、环境配置

# 创建虚拟环境
conda create -n lenet5_env python=3.10 -y
conda activate lenet5_env

# 安装PyTorch(GPU版本)

pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128

# 安装其他依赖
pip install matplotlib numpy scikit-learn seaborn pillow

项目结构:

Vegetable_Classification_LeNet5/
├── dataset/
│   ├── train/        # 8个类别文件夹,每类50张
│   └── val/          # 8个类别文件夹,每类10张
├── train.py          # 训练脚本
├── evaluate.py       # 评估脚本
└── run_xxx/          # 训练结果保存

三、LeNet-5模型搭建

LeNet-5原始结构是针对MNIST灰度图(1×32×32,10类)。我的适配修改:

class LeNet5(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            # 修改1:输入通道从1改为3(彩色图)
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.Tanh(),
            nn.Linear(120, 84),
            nn.Tanh(),
            # 修改2:输出从10类改为num_classes(8)
            nn.Linear(84, num_classes)
        )
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

四、第一次训练结果:83.75%但存在严重混淆

训练50轮,验证集准确率83.75%。看着还行?但一画混淆矩阵,问题暴露了:

混淆矩阵(初始)

真实\预测 apple banana ... strawberry orange tomato
apple 6 0 ... 4 0 0
onion 0 0 ... 1 8 0
其他6类 10 10 ... 10 10 10

关键问题

  • apple:10张只对了6张,4张被误认为strawberry(都是红色圆形)

  • onion:10张只对了1张,8张被误认为orange,1张被误认为strawberry(洋葱外皮橙黄,与橙子、草莓颜色相近)

教训:不能只看总体准确率,混淆矩阵才是诊断问题的金标准。

五、优化策略一:增强数据增强

原始transform只有Resize+ToTensor+Normalize。我增加了:

train_transform = transforms.Compose([
    transforms.Resize(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=25),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

效果

  • apple准确率:60% → 70%

  • onion准确率:10% → 80%

  • 总体准确率:83.75% → 88%

但apple仍有2张被误认为strawberry,说明仅靠数据增强不够。

六、优化策略二:辅助分类头(针对易混淆类别)

核心思想:在模型中间层添加辅助分类器,专门让模型学习区分混淆对(apple, strawberry)和(onion, orange)。

实现代码

class LeNet5WithAuxiliary(nn.Module):
    def __init__(self, num_classes, confusing_pairs):
        super().__init__()
        # 主分支(特征提取+分类)
        self.features = ...   # 同原始LeNet-5
        self.classifier = nn.Sequential(
            nn.Linear(16*5*5, 120), nn.Tanh(),
            nn.Linear(120, 84), nn.Tanh(),
            nn.Linear(84, num_classes)
        )
        # 辅助分支:输入flatten后的特征,输出每个混淆对的2分类
        self.auxiliary = nn.Sequential(
            nn.Linear(16*5*5, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, len(confusing_pairs)*2)
        )
        self.confusing_pairs = confusing_pairs  # [(0,6), (4,5)] 对应(apple,strawberry)和(onion,orange)

    def forward(self, x, return_aux=False):
        x = self.features(x)
        x = torch.flatten(x, 1)
        main_out = self.classifier(x)
        if return_aux:
            aux_out = self.auxiliary(x)
            return main_out, aux_out
        return main_out

组合损失:主损失(CrossEntropy)+ 辅助损失(仅对属于混淆对的样本计算二分类交叉熵),aux_weight=0.3

训练过程:每步同时计算主损失和辅助损失,反向传播更新全部参数。

效果

  • apple准确率:70% → 90%(9/10)

  • onion准确率:80% → 100%(10/10)

  • 总体准确率:96.25%

  • 唯一错误:1个apple仍被误认为strawberry(那个苹果特别红且形状圆润)

七、最终结果对比

指标 优化前 优化后
总体准确率 83.75% 96.25%
apple准确率 60% (6/10) 90% (9/10)
onion准确率 10% (1/10) 100% (10/10)
apple→strawberry混淆 4张 1张
onion→orange混淆 8张 0张

混淆矩阵(最终)

注:数据集来源于kaggle,训练集每类50张,验证集每类10张。

八、踩坑与解决汇总

问题 解决方案
PyTorch安装失败 使用官方CPU安装命令,加--index-url
卷积输出尺寸不匹配 打印每层shape,对照LeNet-5论文调整
总体准确率高但个别类很差 画出混淆矩阵,定位混淆对
apple与strawberry混淆 数据增强+辅助分类头专门优化
onion与orange混淆 同上
辅助损失权重怎么选 实验0.1/0.3/0.5,0.3效果最好
过拟合倾向 增加权重衰减(weight_decay=1e-4)+数据增强

九、完整代码

代码已上传至GitHub:点击查看

核心文件:

  • train.py:含数据增强+辅助分类头的完整训练脚本

  • evaluate.py:加载模型生成混淆矩阵和分类报告

十、总结与展望

收获

  1. 从零跑通了PyTorch分类任务完整流程

  2. 学会了用混淆矩阵诊断模型问题,而不是盲目调参

  3. 理解了数据增强和辅助损失的实际应用场景

后续计划

  • 收集更多带叶子的苹果图片,进一步减少与草莓的混淆

  • 尝试ResNet-18等现代网络对比效果

  • 增加验证集数量(每类10张偏少,评估波动大)

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐