训练教程

1. 工程简介

本工程是在 Ultralytics 框架基础上扩展的语义分割训练工程,核心特点不是只支持单一模型,而是可以通过切换 yaml 配置文件,快速训练多种不同结构的语义分割模型。

当前工程已经支持的主要语义分割模型家族包括:

  • UNet
  • UNet++
  • DeepLabV3+
  • DPT
  • FPN
  • PSPNet
  • MAnet
  • PAN
  • Linknet
  • UPerNet
  • Segformer

2. 本工程的优势

2.1 用 YAML 切换模型,训练入口统一

本工程最大的优势是:

  • 不需要为每个模型单独写一套训练脚本
  • 不需要手动改模型源码才能切换结构
  • 只需要替换 ultralytics/cfg/models/... 下的模型 yaml 文件,就可以训练不同模型

也就是说,你可以保持:

  • 相同的数据集
  • 相同的训练入口
  • 相同的指标输出
  • 相同的验证流程

只通过修改模型 YAML 来完成不同网络结构之间的对比实验。

这对论文实验、模型横向对比、参数调优非常方便。

2.2 模型隔离良好,改一个家族不容易影响另一个家族

本工程对新增模型采用了隔离式集成思路:

  • 模型配置在 ultralytics/cfg/models/模型家族/
  • 模型实现放在 ultralytics/nn/模型家族/
  • 不直接污染原始 UNet 训练逻辑

这种结构的好处是:

  • 容易维护
  • 容易继续扩展新模型
  • 出问题时更容易定位
  • 不容易影响已经稳定的模型家族

2.3 保持 Ultralytics 风格

本工程保留了 Ultralytics 的很多优点:

  • 统一训练接口
  • 统一验证接口
  • 统一日志输出
  • 统一结果保存目录
  • 可以直接使用 YOLO(...) 的方式加载模型配置或权重

完整代码下载地址,包含代码+数据集+视频教程:地址
大模型注册地址:https://aigocode.com/invite/8P8F4A8U

3. 环境配置

3.1 推荐环境

建议使用与你当前工程验证一致的环境:

  • Python 3.9
  • PyTorch 2.5.1
  • CUDA 12.4

你当前实际验证环境示例:

Python 3.9.21
torch 2.5.1+cu124

3.2 创建 Conda 环境

conda create -n py39 python=3.9 -y
conda activate py39

3.3 安装 PyTorch

请按你的 CUDA 版本安装对应的 PyTorch。以 CUDA 12.4 为例:

pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

3.4 安装工程依赖

在工程根目录执行:

pip install -e .

如果你的环境里还缺少常见依赖,也可以补装:

pip install timm segmentation-models-pytorch prettytable

3.5 验证环境

python -c "import torch; print(torch.__version__); print(torch.cuda.is_available())"

如果返回 CUDA 可用,说明基础环境正常。

4. 工程目录说明

与训练最相关的目录如下:

ultralytics-main/
├─ train.py
├─ val.py
├─ ultralytics/
│  ├─ cfg/
│  │  ├─ datasets/
│  │  │  └─ my-semantic-seg.yaml
│  │  └─ models/
│  │     ├─ unet/
│  │     ├─ unetplusplus/
│  │     ├─ deeplabv3plus/
│  │     ├─ dpt/
│  │     ├─ fpn/
│  │     ├─ pspnet/
│  │     ├─ manet/
│  │     ├─ pan/
│  │     ├─ linknet/
│  │     ├─ upernet/
│  │     └─ segformer/
│  └─ nn/
│     ├─ unet/
│     ├─ deeplabv3plus/
│     ├─ dpt/
│     ├─ fpn/
│     ├─ pspnet/
│     ├─ manet/
│     ├─ pan/
│     ├─ linknet/
│     ├─ upernet/
│     └─ segformer/
└─ runs/
   └─ semantic/

5. 数据集准备

5.1 推荐数据集目录结构

数据集制作可以看视频讲解
你的语义分割数据集建议组织为:

dataset/
└─ split_dataset/
   ├─ images/
   │  ├─ train/
   │  ├─ val/
   │  └─ test/
   └─ masks/
      ├─ train/
      ├─ val/
      └─ test/

要求:

  • 图片和 mask 必须一一对应
  • 图片文件名和 mask 文件名的 stem 必须一致
  • mask 推荐使用单通道灰度图
  • mask 中每个像素值表示类别 ID

例如:

images/train/0001.jpg
masks/train/0001.png

5.2 mask 标注规则

mask 像素值约定:

  • 0 ~ (nc - 1):有效类别 ID
  • 255:忽略区域

例如二分类:

  • 0:background
  • 1:spot

5.3 数据集 YAML 配置

训练使用的数据集配置文件为:

  • ultralytics/cfg/datasets/my-semantic-seg.yaml

一个典型配置如下:

path: dataset/split_dataset

train: images/train
val: images/val
test: images/test

masks_train: masks/train
masks_val: masks/val
masks_test: masks/test

nc: 2

names:
  0: background
  1: spot

含义说明:

  • path:数据集根目录
  • train / val / test:图像目录
  • masks_train / masks_val / masks_test:mask 目录
  • nc:类别数
  • names:类别名称

6. 如何选择模型

6.1 模型 YAML 的位置

所有语义分割模型配置都放在:

ultralytics/cfg/models/

例如:

ultralytics/cfg/models/unet/unet-resnet34-expanded.yaml
ultralytics/cfg/models/unetplusplus/unetplusplus-resnet34-expanded.yaml
ultralytics/cfg/models/deeplabv3plus/deeplabv3plus-resnet34.yaml
ultralytics/cfg/models/fpn/fpn-resnet34-expanded.yaml
ultralytics/cfg/models/pspnet/pspnet-resnet34-expanded.yaml
ultralytics/cfg/models/linknet/linknet-resnet34-expanded.yaml
ultralytics/cfg/models/upernet/upernet-resnet34-expanded.yaml
ultralytics/cfg/models/segformer/segformer-mit_b0-expanded.yaml
ultralytics/cfg/models/dpt/dpt-vitb16-384-expanded.yaml

6.2 推荐理解方式

你可以把每个 YAML 理解为:

  • 一个模型家族
  • 一个具体骨干网络
  • 一套完整的模型结构定义

例如:

  • unet-resnet34-expanded.yaml:UNet + ResNet34
  • fpn-efficientnet-b0-expanded.yaml:FPN + EfficientNet-B0
  • segformer-mit_b0-expanded.yaml:Segformer + MiT-B0
  • dpt-vitb16-384-expanded.yaml:DPT + ViT-B/16-384

6.3 常见模型选择建议

  • 数据量不大,先做基线:UNet / UNet++ / DeepLabV3+
  • 希望速度和效果平衡:FPN / PSPNet / Linknet
  • 想尝试 Transformer:DPT / UPerNet / Segformer
  • 想做对比实验:保持数据集和超参数一致,只替换 YAML

7. 训练方法

7.1 使用 train.py 训练

本工程默认训练入口是根目录的:

  • train.py

你只需要修改其中的 yaml_path 即可切换模型。

例如训练 DeepLabV3+ EfficientNet-B0

yaml_path = 'ultralytics/cfg/models/deeplabv3plus/deeplabv3plus-efficientnet-b0.yaml'

例如训练 UNet ResNet34

yaml_path = 'ultralytics/cfg/models/unet/unet-resnet34-expanded.yaml'

例如训练 Segformer MiT-B0

yaml_path = 'ultralytics/cfg/models/segformer/segformer-mit_b0-expanded.yaml'

修改完成后直接运行:

python train.py

7.2 train.py 的核心训练参数

当前脚本中的核心训练配置大致如下:

model.train(
    data='ultralytics/cfg/datasets/my-semantic-seg.yaml',
    task='semantic',
    imgsz=640,
    epochs=100,
    batch=4,
    workers=0,
    optimizer='SGD',
    lr0=0.001,
    lrf=0.01,
    momentum=0.937,
    weight_decay=0.0005,
    warmup_epochs=3.0,
    warmup_momentum=0.8,
    loss=loss_fn,
    augment=True,
    mosaic=1.0,
    close_mosaic=10,
    hsv_h=0.015,
    hsv_s=0.7,
    hsv_v=0.4,
    translate=0.1,
    scale=0.5,
    fliplr=0.5,
    val=True,
    plots=True,
    patience=50,
    seed=0,
    deterministic=True,
    verbose=True,
)

7.3 参数说明

  • data:数据集 YAML
  • task='semantic':语义分割任务
  • imgsz:输入尺寸
  • epochs:训练轮数
  • batch:批大小
  • workers:DataLoader worker 数
  • optimizer:优化器
  • lr0:初始学习率
  • lrf:最终学习率比例
  • momentum:动量
  • weight_decay:权重衰减
  • loss:损失函数类型
  • augment:是否开启数据增强
  • val:训练时是否验证
  • plots:是否保存训练曲线和可视化结果

7.4 损失函数切换

本工程支持多种损失函数。在 train.py 中通过:

loss_fn = 'ce'

可选项包括:

  • ce
  • dice
  • focal
  • tversky
  • jaccard
  • lovasz
  • mcc
  • bce_dice
  • ce_dice
  • focal_dice

例如改为 Dice:

loss_fn = 'dice'

7.5 是否使用预训练

当前训练脚本中还保留了一个可选入口:

yolo_pretrained_weights = None

如果你希望加载某个 YOLO 权重的 backbone,可以设置为:

yolo_pretrained_weights = 'yolo11n.pt'

如果不需要,则保持:

yolo_pretrained_weights = None

注意:

  • 是否适合加载预训练,要看模型结构兼容性
  • 如果只是做稳定对比实验,建议先统一设为 None

8. 多模型训练示例

8.1 训练 UNet

yaml_path = 'ultralytics/cfg/models/unet/unet-resnet34-expanded.yaml'

8.2 训练 UNet++

yaml_path = 'ultralytics/cfg/models/unetplusplus/unetplusplus-resnet34-expanded.yaml'

8.3 训练 DeepLabV3+

yaml_path = 'ultralytics/cfg/models/deeplabv3plus/deeplabv3plus-resnet34.yaml'

8.4 训练 FPN

yaml_path = 'ultralytics/cfg/models/fpn/fpn-resnet34-expanded.yaml'

8.5 训练 PSPNet

yaml_path = 'ultralytics/cfg/models/pspnet/pspnet-resnet34-expanded.yaml'

8.6 训练 Linknet

yaml_path = 'ultralytics/cfg/models/linknet/linknet-resnet34-expanded.yaml'

8.7 训练 UPerNet

yaml_path = 'ultralytics/cfg/models/upernet/upernet-resnet34-expanded.yaml'

8.8 训练 Segformer

yaml_path = 'ultralytics/cfg/models/segformer/segformer-mit_b0-expanded.yaml'

8.9 训练 DPT

yaml_path = 'ultralytics/cfg/models/dpt/dpt-vitb16-384-expanded.yaml'

说明:

  • 你完全不需要换训练框架
  • 只需要替换 yaml_path
  • 这就是本工程最核心的优势

9. 训练结果保存位置

训练结果默认保存在:

runs/semantic/

例如:

runs/semantic/train39/
├─ args.yaml
├─ results.csv
└─ weights/
   ├─ best.pt
   └─ last.pt

含义:

  • args.yaml:本次训练参数
  • results.csv:每轮指标
  • best.pt:最佳权重
  • last.pt:最后一轮权重

10. 模型验证与结果导出

验证脚本为:

  • val.py

直接运行:

python val.py

脚本会默认加载你指定的权重,例如:

model_path = 'runs/semantic/train39/weights/best.pt'

验证结果会输出:

  • 单类指标表
  • 全局指标表
  • 模型参数量和 GFLOPs

并保存到:

paper_data.txt

11. 当前指标解释

本工程当前验证输出中:

  • 单类表显示:IoU / Dice / Precision / Recall
  • 全局表显示:mIoU / fwIoU / mPA / Dice / Precision / PixAcc

说明:

  • Recall = TP / (TP + FN)
  • 在当前实现中,这个值与单类像素准确率数值相同
  • 因此单类表不再额外重复显示 Acc
  • mPA 是各类别像素准确率的平均
  • PixAcc 是全局像素准确率

所以:

  • mPA 和全局 mean Recall 在当前实现中数值一致
  • PixAcc 不是 mPA

12. 常见注意事项

12.1 路径尽量使用正斜杠

建议在 Python 脚本中使用:

model_path = 'runs/semantic/train39/weights/best.pt'

不要写成:

model_path = 'runs\semantic\train39\weights\best.pt'

因为反斜杠会触发转义字符问题,例如:

  • \t
  • \b

12.2 DPT 输入尺寸要注意

某些 DPT 变体和 ViT 输入尺寸强相关,例如:

  • 224
  • 384

训练前要确认模型 YAML 与 imgsz 是否匹配。

12.3 不同模型的显存需求差异很大

建议经验:

  • 显存紧张时先减小 batch
  • 再减小 imgsz
  • Transformer 模型通常更吃显存

12.4 某些模型对 batch 更敏感

例如某些带 BN 的模型:

  • batch=1 可能不稳定
  • 更适合 batch>=2

如果训练报 BN 相关错误,优先尝试增大 batch 或减小输入尺寸。

12.5 验证与训练建议使用同一环境

建议始终在同一个 Conda 环境里完成:

  • 训练
  • 验证
  • 导出结果

避免因为 PyTorch、timm、CUDA 版本差异造成额外问题。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐