UNET,UNET++,DEEPLABV3+,DPT,PAN,Segformer保姆级训练教程!
训练教程
1. 工程简介
本工程是在 Ultralytics 框架基础上扩展的语义分割训练工程,核心特点不是只支持单一模型,而是可以通过切换 yaml 配置文件,快速训练多种不同结构的语义分割模型。
当前工程已经支持的主要语义分割模型家族包括:
UNetUNet++DeepLabV3+DPTFPNPSPNetMAnetPANLinknetUPerNetSegformer
2. 本工程的优势
2.1 用 YAML 切换模型,训练入口统一
本工程最大的优势是:
- 不需要为每个模型单独写一套训练脚本
- 不需要手动改模型源码才能切换结构
- 只需要替换
ultralytics/cfg/models/...下的模型yaml文件,就可以训练不同模型
也就是说,你可以保持:
- 相同的数据集
- 相同的训练入口
- 相同的指标输出
- 相同的验证流程
只通过修改模型 YAML 来完成不同网络结构之间的对比实验。
这对论文实验、模型横向对比、参数调优非常方便。
2.2 模型隔离良好,改一个家族不容易影响另一个家族
本工程对新增模型采用了隔离式集成思路:
- 模型配置在
ultralytics/cfg/models/模型家族/ - 模型实现放在
ultralytics/nn/模型家族/ - 不直接污染原始
UNet训练逻辑
这种结构的好处是:
- 容易维护
- 容易继续扩展新模型
- 出问题时更容易定位
- 不容易影响已经稳定的模型家族
2.3 保持 Ultralytics 风格
本工程保留了 Ultralytics 的很多优点:
- 统一训练接口
- 统一验证接口
- 统一日志输出
- 统一结果保存目录
- 可以直接使用
YOLO(...)的方式加载模型配置或权重
完整代码下载地址,包含代码+数据集+视频教程:地址
大模型注册地址:https://aigocode.com/invite/8P8F4A8U
3. 环境配置
3.1 推荐环境
建议使用与你当前工程验证一致的环境:
- Python
3.9 - PyTorch
2.5.1 - CUDA
12.4
你当前实际验证环境示例:
Python 3.9.21
torch 2.5.1+cu124
3.2 创建 Conda 环境
conda create -n py39 python=3.9 -y
conda activate py39
3.3 安装 PyTorch
请按你的 CUDA 版本安装对应的 PyTorch。以 CUDA 12.4 为例:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
3.4 安装工程依赖
在工程根目录执行:
pip install -e .
如果你的环境里还缺少常见依赖,也可以补装:
pip install timm segmentation-models-pytorch prettytable
3.5 验证环境
python -c "import torch; print(torch.__version__); print(torch.cuda.is_available())"
如果返回 CUDA 可用,说明基础环境正常。
4. 工程目录说明
与训练最相关的目录如下:
ultralytics-main/
├─ train.py
├─ val.py
├─ ultralytics/
│ ├─ cfg/
│ │ ├─ datasets/
│ │ │ └─ my-semantic-seg.yaml
│ │ └─ models/
│ │ ├─ unet/
│ │ ├─ unetplusplus/
│ │ ├─ deeplabv3plus/
│ │ ├─ dpt/
│ │ ├─ fpn/
│ │ ├─ pspnet/
│ │ ├─ manet/
│ │ ├─ pan/
│ │ ├─ linknet/
│ │ ├─ upernet/
│ │ └─ segformer/
│ └─ nn/
│ ├─ unet/
│ ├─ deeplabv3plus/
│ ├─ dpt/
│ ├─ fpn/
│ ├─ pspnet/
│ ├─ manet/
│ ├─ pan/
│ ├─ linknet/
│ ├─ upernet/
│ └─ segformer/
└─ runs/
└─ semantic/
5. 数据集准备
5.1 推荐数据集目录结构
数据集制作可以看视频讲解
你的语义分割数据集建议组织为:
dataset/
└─ split_dataset/
├─ images/
│ ├─ train/
│ ├─ val/
│ └─ test/
└─ masks/
├─ train/
├─ val/
└─ test/
要求:
- 图片和 mask 必须一一对应
- 图片文件名和 mask 文件名的
stem必须一致 - mask 推荐使用单通道灰度图
- mask 中每个像素值表示类别 ID
例如:
images/train/0001.jpg
masks/train/0001.png
5.2 mask 标注规则
mask 像素值约定:
0 ~ (nc - 1):有效类别 ID255:忽略区域
例如二分类:
0:background1:spot
5.3 数据集 YAML 配置
训练使用的数据集配置文件为:
ultralytics/cfg/datasets/my-semantic-seg.yaml
一个典型配置如下:
path: dataset/split_dataset
train: images/train
val: images/val
test: images/test
masks_train: masks/train
masks_val: masks/val
masks_test: masks/test
nc: 2
names:
0: background
1: spot
含义说明:
path:数据集根目录train/val/test:图像目录masks_train/masks_val/masks_test:mask 目录nc:类别数names:类别名称
6. 如何选择模型
6.1 模型 YAML 的位置
所有语义分割模型配置都放在:
ultralytics/cfg/models/
例如:
ultralytics/cfg/models/unet/unet-resnet34-expanded.yaml
ultralytics/cfg/models/unetplusplus/unetplusplus-resnet34-expanded.yaml
ultralytics/cfg/models/deeplabv3plus/deeplabv3plus-resnet34.yaml
ultralytics/cfg/models/fpn/fpn-resnet34-expanded.yaml
ultralytics/cfg/models/pspnet/pspnet-resnet34-expanded.yaml
ultralytics/cfg/models/linknet/linknet-resnet34-expanded.yaml
ultralytics/cfg/models/upernet/upernet-resnet34-expanded.yaml
ultralytics/cfg/models/segformer/segformer-mit_b0-expanded.yaml
ultralytics/cfg/models/dpt/dpt-vitb16-384-expanded.yaml
6.2 推荐理解方式
你可以把每个 YAML 理解为:
- 一个模型家族
- 一个具体骨干网络
- 一套完整的模型结构定义
例如:
unet-resnet34-expanded.yaml:UNet + ResNet34fpn-efficientnet-b0-expanded.yaml:FPN + EfficientNet-B0segformer-mit_b0-expanded.yaml:Segformer + MiT-B0dpt-vitb16-384-expanded.yaml:DPT + ViT-B/16-384
6.3 常见模型选择建议
- 数据量不大,先做基线:
UNet / UNet++ / DeepLabV3+ - 希望速度和效果平衡:
FPN / PSPNet / Linknet - 想尝试 Transformer:
DPT / UPerNet / Segformer - 想做对比实验:保持数据集和超参数一致,只替换 YAML
7. 训练方法
7.1 使用 train.py 训练
本工程默认训练入口是根目录的:
train.py
你只需要修改其中的 yaml_path 即可切换模型。
例如训练 DeepLabV3+ EfficientNet-B0:
yaml_path = 'ultralytics/cfg/models/deeplabv3plus/deeplabv3plus-efficientnet-b0.yaml'
例如训练 UNet ResNet34:
yaml_path = 'ultralytics/cfg/models/unet/unet-resnet34-expanded.yaml'
例如训练 Segformer MiT-B0:
yaml_path = 'ultralytics/cfg/models/segformer/segformer-mit_b0-expanded.yaml'
修改完成后直接运行:
python train.py
7.2 train.py 的核心训练参数
当前脚本中的核心训练配置大致如下:
model.train(
data='ultralytics/cfg/datasets/my-semantic-seg.yaml',
task='semantic',
imgsz=640,
epochs=100,
batch=4,
workers=0,
optimizer='SGD',
lr0=0.001,
lrf=0.01,
momentum=0.937,
weight_decay=0.0005,
warmup_epochs=3.0,
warmup_momentum=0.8,
loss=loss_fn,
augment=True,
mosaic=1.0,
close_mosaic=10,
hsv_h=0.015,
hsv_s=0.7,
hsv_v=0.4,
translate=0.1,
scale=0.5,
fliplr=0.5,
val=True,
plots=True,
patience=50,
seed=0,
deterministic=True,
verbose=True,
)
7.3 参数说明
data:数据集 YAMLtask='semantic':语义分割任务imgsz:输入尺寸epochs:训练轮数batch:批大小workers:DataLoader worker 数optimizer:优化器lr0:初始学习率lrf:最终学习率比例momentum:动量weight_decay:权重衰减loss:损失函数类型augment:是否开启数据增强val:训练时是否验证plots:是否保存训练曲线和可视化结果
7.4 损失函数切换
本工程支持多种损失函数。在 train.py 中通过:
loss_fn = 'ce'
可选项包括:
cedicefocaltverskyjaccardlovaszmccbce_dicece_dicefocal_dice
例如改为 Dice:
loss_fn = 'dice'
7.5 是否使用预训练
当前训练脚本中还保留了一个可选入口:
yolo_pretrained_weights = None
如果你希望加载某个 YOLO 权重的 backbone,可以设置为:
yolo_pretrained_weights = 'yolo11n.pt'
如果不需要,则保持:
yolo_pretrained_weights = None
注意:
- 是否适合加载预训练,要看模型结构兼容性
- 如果只是做稳定对比实验,建议先统一设为
None
8. 多模型训练示例
8.1 训练 UNet
yaml_path = 'ultralytics/cfg/models/unet/unet-resnet34-expanded.yaml'
8.2 训练 UNet++
yaml_path = 'ultralytics/cfg/models/unetplusplus/unetplusplus-resnet34-expanded.yaml'
8.3 训练 DeepLabV3+
yaml_path = 'ultralytics/cfg/models/deeplabv3plus/deeplabv3plus-resnet34.yaml'
8.4 训练 FPN
yaml_path = 'ultralytics/cfg/models/fpn/fpn-resnet34-expanded.yaml'
8.5 训练 PSPNet
yaml_path = 'ultralytics/cfg/models/pspnet/pspnet-resnet34-expanded.yaml'
8.6 训练 Linknet
yaml_path = 'ultralytics/cfg/models/linknet/linknet-resnet34-expanded.yaml'
8.7 训练 UPerNet
yaml_path = 'ultralytics/cfg/models/upernet/upernet-resnet34-expanded.yaml'
8.8 训练 Segformer
yaml_path = 'ultralytics/cfg/models/segformer/segformer-mit_b0-expanded.yaml'
8.9 训练 DPT
yaml_path = 'ultralytics/cfg/models/dpt/dpt-vitb16-384-expanded.yaml'
说明:
- 你完全不需要换训练框架
- 只需要替换
yaml_path - 这就是本工程最核心的优势
9. 训练结果保存位置
训练结果默认保存在:
runs/semantic/
例如:
runs/semantic/train39/
├─ args.yaml
├─ results.csv
└─ weights/
├─ best.pt
└─ last.pt
含义:
args.yaml:本次训练参数results.csv:每轮指标best.pt:最佳权重last.pt:最后一轮权重
10. 模型验证与结果导出
验证脚本为:
val.py
直接运行:
python val.py
脚本会默认加载你指定的权重,例如:
model_path = 'runs/semantic/train39/weights/best.pt'
验证结果会输出:
- 单类指标表
- 全局指标表
- 模型参数量和 GFLOPs
并保存到:
paper_data.txt
11. 当前指标解释
本工程当前验证输出中:
- 单类表显示:
IoU / Dice / Precision / Recall - 全局表显示:
mIoU / fwIoU / mPA / Dice / Precision / PixAcc
说明:
Recall = TP / (TP + FN)- 在当前实现中,这个值与单类像素准确率数值相同
- 因此单类表不再额外重复显示
Acc mPA是各类别像素准确率的平均PixAcc是全局像素准确率
所以:
mPA和全局 mean Recall 在当前实现中数值一致PixAcc不是mPA
12. 常见注意事项
12.1 路径尽量使用正斜杠
建议在 Python 脚本中使用:
model_path = 'runs/semantic/train39/weights/best.pt'
不要写成:
model_path = 'runs\semantic\train39\weights\best.pt'
因为反斜杠会触发转义字符问题,例如:
\t\b
12.2 DPT 输入尺寸要注意
某些 DPT 变体和 ViT 输入尺寸强相关,例如:
224384
训练前要确认模型 YAML 与 imgsz 是否匹配。
12.3 不同模型的显存需求差异很大
建议经验:
- 显存紧张时先减小
batch - 再减小
imgsz - Transformer 模型通常更吃显存
12.4 某些模型对 batch 更敏感
例如某些带 BN 的模型:
batch=1可能不稳定- 更适合
batch>=2
如果训练报 BN 相关错误,优先尝试增大 batch 或减小输入尺寸。
12.5 验证与训练建议使用同一环境
建议始终在同一个 Conda 环境里完成:
- 训练
- 验证
- 导出结果
避免因为 PyTorch、timm、CUDA 版本差异造成额外问题。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)