unet模型学习笔记
一、核心文件夹 (The Folders)
1.nets/ (神经网络定义库)
-
职责:存放 U-Net 的核心代码。
-
unet.py:定义了 U-Net 的整体架构(编码器、解码器、跳跃连接)。 -
vgg.py/resnet.py:U-Net 通常需要一个“主干网络”来提取特征,这里存放的就是常用的主干网络实现。

nets/unet.py 是整个仓库的核心,它像一个“拼装车间”,把特征提取网络(VGG或ResNet)和上采样网络组合在一起。
1. 初始化部分 (__init__):定义网络组件
这部分代码的作用是准备好所有需要的“积木”。
import torch.nn as nn
from nets.vgg import VGG16
from nets.resnet import resnet50
class Unet(nn.Module):
def __init__(self, num_classes = 21, pretrained = False, backbone = 'vgg'):
super(Unet, self).__init__()
# num_classes: 分类数量(如:背景+猫+狗=3类)。
# pretrained: 是否加载预训练权重(即别人练好的“经验”)。
# backbone: 选择主干网络,默认用 vgg。
if backbone == 'vgg':
self.vgg = VGG16(pretrained = pretrained)
# 加载 VGG16 模型作为编码器(左侧下采样部分)。
in_filters = [192, 384, 768, 1024]
# 定义解码器(右侧)每一层接收到的通道数总和。
elif backbone == "resnet50":
self.resnet = resnet50(pretrained = pretrained)
# 或者选择 ResNet50 作为编码器。
in_filters = [192, 512, 1024, 3072]
# ResNet 通道多,所以这里的数字比 VGG 大。
out_filters = [64, 128, 256, 512]
# 定义解码器每一层输出的通道数。
# 下面定义 4 个上采样模块(U-Net右侧的四个上升台阶)
# unetUp 是作者在下面自定义的一个类,负责“上采样 + 特征融合 + 两次卷积”
self.up_concat4 = unetUp(in_filters[3], out_filters[3]) # 处理最小的特征图
self.up_concat3 = unetUp(in_filters[2], out_filters[2]) # 处理倒数第二层
self.up_concat2 = unetUp(in_filters[1], out_filters[1]) # 处理中间层
self.up_concat1 = unetUp(in_filters[0], out_filters[0]) # 恢复到较大的特征图
# 最后的一层卷积,将通道数调整为最终的类别数
self.final = nn.Conv2d(out_filters[0], num_classes, 1)
# kernel_size=1 代表 1x1 卷积,作用是把特征图映射到类别概率上。
2. 前向传播部分 (forward):定义数据流动逻辑
这是代码运行时的逻辑路线图,决定了图片进来到结果出去的过程。
def forward(self, inputs):
# inputs: 输入的原始图片,比如 [1, 3, 512, 512] (1张图, 3通道RGB, 512x512分辨率)
# 第一步:利用主干网络提取 5 组不同尺度的特征图
if self.backbone == "vgg":
feat1, feat2, feat3, feat4, feat5 = self.vgg.forward(inputs)
# feat1 是最大的特征图(细节多),feat5 是最小的(语义强)。
elif self.backbone == "resnet50":
feat1, feat2, feat3, feat4, feat5 = self.resnet.forward(inputs)
# 第二步:开始由小变大,进行上采样和特征融合(跳跃连接)
# 这一步是 U-Net 的灵魂:把深层信息(feat5)和浅层信息(feat4)拼在一起
up4 = self.up_concat4(feat4, feat5)
# 把 feat5 放大,然后和 feat4 左右拼接。
up3 = self.up_concat3(feat3, up4)
# 把刚才的结果 up4 再次放大,并和 feat3 拼接。
up2 = self.up_concat2(feat2, up3)
# 继续放大,和 feat2 拼接。
up1 = self.up_concat1(feat1, up2)
# 最后一次放大,和最初始的浅层特征 feat1 拼接,找回丢失的边缘细节。
# 第三步:输出最终结果
final = self.final(up1)
# 把拼接完的所有信息变成类别预测图。
return final
3. 解码器核心类 (unetUp) 的逻辑
在 unet.py 文件的上方,你会看到一个 unetUp 类,它是实现“跳跃连接(Skip Connection)”的具体工具:
class unetUp(nn.Module):
def __init__(self, in_size, out_size):
super(unetUp, self).__init__()
# nn.Upsample: 上采样。将图片尺寸放大 2 倍(比如从 16x16 变 32x32)。
self.upsample = nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
# 拼接后通道数会增加,这里用两次卷积来融合这些信息
self.conv1 = nn.Conv2d(in_size, out_size, kernel_size = 3, padding = 1)
self.conv2 = nn.Conv2d(out_size, out_size, kernel_size = 3, padding = 1)
def forward(self, inputs1, inputs2):
# inputs1: 来自左边的特征(浅层),inputs2: 来自下方的特征(深层)
outputs2 = self.upsample(inputs2) # 先把深层特征图放大
# torch.cat: 【最关键】在通道维度(dim=1)把两张特征图“粘”在一起
outputs = torch.cat([inputs1, outputs2], dim = 1)
# 经过两次卷积,让网络自己学习如何融合这两种不同的特征
outputs = self.conv1(outputs)
outputs = self.conv2(outputs)
return outputs
直观总结:
-
backbone:是挖掘机,负责把图片里的特征(轮廓、颜色、物体含义)挖出来。 -
torch.cat:是缝合线,负责把刚才挖出来的深层信息(知道这是猫)和浅层信息(知道猫的胡须在哪)缝在一起。 -
nn.Upsample:是放大镜,把被挖出来的、已经变小的特征图变回原来图片的大小。 -
final层:是调色板,把模型对每个像素的理解涂上颜色(比如:红色代表细胞,背景是黑色)。
nets/resnet.py 通常被用作 U-Net 的主干特征提取网络(Backbone)。相比于原始的 U-Net(使用简单的卷积堆叠),使用 ResNet50 作为主干可以提取更深层的特征,且更容易训练。
由于代码较长,我为你拆解最核心的 Bottleneck(瓶颈结构) 和 ResNet 主类,这是理解 ResNet 的关键。
1. 核心组件:Bottleneck 类
这是 ResNet50/101/152 的基本单元。它之所以叫“瓶颈”,是因为中间用了一个 3x3 的小卷积,两头用 1x1 卷积加宽。
class Bottleneck(nn.Module):
# expansion = 4 意味着输出通道数是输入通道数的 4 倍
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
# 1x1 卷积:用来压缩通道数,减少计算量
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
# BatchNorm:对卷积后的结果做标准化,让数据分布更稳定
self.bn1 = nn.BatchNorm2d(planes)
# 3x3 卷积:核心特征提取,stride=2 时会缩小图片尺寸(下采样)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
# 1x1 卷积:用来放大通道数(planes * 4)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
# ReLU:激活函数,给网络增加非线性能力
self.relu = nn.ReLU(inplace=True)
# downsample:捷径分支(Shortcut)。如果输入输出尺寸对不上,用它来调整
self.downsample = downsample
self.stride = stride
def forward(self, x):
# identity:保存原始输入,为了后面的“残差连接”
identity = x
# 第一层卷积 + 标准化 + 激活
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
# 第二层卷积 + 标准化 + 激活
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
# 第三层卷积 + 标准化(注意这里最后不激活,要加完残差再激活)
out = self.conv3(out)
out = self.bn3(out)
# 如果输入和输出的尺寸/通道数不同,就对输入进行调整
if self.downsample is not None:
identity = self.downsample(x)
# 【核心:残差连接】将处理后的特征和原始输入相加
out += identity
# 最后再统一做一次激活
out = self.relu(out)
return out
2. 主类:ResNet 结构
这个类负责把上面的 Bottleneck 积木搭成一个高楼。
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000):
# block: 使用哪种积木(比如 Bottleneck)
# layers: 每一层有多少个积木(ResNet50 是 [3, 4, 6, 3])
self.inplanes = 64
super(ResNet, self).__init__()
# 第一层:大的 7x7 卷积,直接把图片缩小一半
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
# 最大池化:再次缩小图片尺寸
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 下面是 ResNet 的四个大阶段(Stage)
# 每个 stage 都会提取不同层次的特征,U-Net 会在这些地方“取走”特征做跳跃连接
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
3. 为什么 U-Net 要看这个文件?
在 nets/unet.py 中,你会看到类似这样的逻辑:
# 在 unet 的 forward 函数里
feat1 = self.resnet.layer1(x) # 拿到浅层特征(边缘、颜色)
feat2 = self.resnet.layer2(feat1) # 拿到中层特征(形状)
feat3 = self.resnet.layer3(feat2) # 拿到深层特征(语义)
代码中关键参数的意思:
-
inplanes:输入到当前层的通道数(水的流量)。 -
planes:这一层基础的输出通道数。 -
stride(步长):如果等于 2,图片的长宽就会缩小一半。 -
padding(填充):在图片外围补一圈 0,保证卷积后图片尺寸不会因为边缘损耗而莫名缩小。 -
bias=False:因为后面接了BatchNorm,偏置项(bias)会被抵消掉,所以为了省计算量直接关掉。 -
直观总结:
-
卷积层 (Conv):是过滤器,负责找特征。
-
标准化层 (BN):是调节器,防止模型练着练着“走火入魔”(梯度爆炸)。
-
残差连接 (
out += identity):是 ResNet 的神来之笔,它允许信息直接跨层传递,解决了深层网络难以训练的问题。 -
_make_layer:是一个自动化工厂,你告诉它你需要多少个Bottleneck,它就自动帮你串联起来。
nets/vgg.py 在这个项目中被用作 U-Net 的特征提取器(Backbone)。VGG16 的结构非常规律,由一系列“卷积+激活+池化”的积木块组成。
为了让 U-Net 实现“跳跃连接”,这个 VGG 的实现与标准的分类 VGG 不同,它会在中间阶段把特征图“截断”并输出。
1. 构造卷积层序列的工具函数 (make_layers)
这个函数是按照预定义的配置表,自动像搭积木一样叠放卷积层。
import torch.nn as nn
def make_layers(cfg, batch_norm=False):
# cfg: 一个列表,数字代表卷积核数量,'M' 代表最大池化层。
# batch_norm: 是否使用标准归一化(通常选 False 以节省内存)。
layers = [] # 创建一个空列表,用来存放层级。
in_channels = 3 # 输入通道数,初始是 RGB 3 通道。
for v in cfg:
if v == 'M':
# 如果是 'M',就加一个 2x2 的最大池化层,图片长宽缩小一半。
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
# 否则就是一个卷积层:3x3 卷积核,保持图片尺寸(padding=1)。
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
# 如果开启 BN,则按顺序加入:卷积 -> 标准化 -> 激活。
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
# 否则只加入:卷积 -> 激活。
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v # 更新输入通道数,下一层的输入等于这一层的输出。
return nn.Sequential(*layers) # 将列表里的层组合成一个连续的神经网络。
2. VGG16 主类结构 (VGG 类)
这里定义了 VGG16 的具体层数配置。
# 这是 VGG16 的经典配置:数字代表通道数,'M' 代表下采样
cfgs = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
class VGG(nn.Module):
def __init__(self, features, num_classes=1000, image_set='voc'):
super(VGG, self).__init__()
# features: 就是上面 make_layers 生成的那一长串卷积层。
self.features = features
# 对于 U-Net 来说,我们通常不需要下面这些全连接层(分类头),
# 但为了代码完整性,作者保留了这部分结构。
self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) # 自适应平均池化。
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096), # 全连接层。
nn.ReLU(True),
nn.Dropout(), # 随机失活,防止过拟合。
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes), # 映射到最终的分类数量。
)
3. 前向传播:为 U-Net 量身定制 (forward)
这是初学者最需要关注的地方。标准的 VGG 只返回最后的结果,但 U-Net 需要五个不同阶段的特征图。
def forward(self, x):
# x: 输入的原始图片 [batch, 3, 512, 512]
# 阶段 1:经过前两个卷积层,图片还是 512x512
feat1 = self.features[:4](x)
# 阶段 2:经过第一次下采样,变为 256x256
feat2 = self.features[4:9](feat1)
# 阶段 3:经过第二次下采样,变为 128x128
feat3 = self.features[9:16](feat2)
# 阶段 4:经过第三次下采样,变为 64x64
feat4 = self.features[16:23](feat3)
# 阶段 5:经过第四次下采样,变为 32x32
feat5 = self.features[23:30](feat4)
# 最后,把这 5 个不同大小的特征图全部返回。
# U-Net 的右侧(上采样部分)会像接力赛一样,挨个把它们接过去。
return [feat1, feat2, feat3, feat4, feat5]
核心函数意思总结:
-
nn.Conv2d(..., padding=1): 卷积。由于设置了padding=1,它在提取特征的同时,不会改变图片的长宽尺寸。 -
nn.MaxPool2d: 最大池化。它的作用是“降维”,就像把 512 像素的图拍扁成 256 像素,保留最明显的特征(比如最亮的点)。 -
nn.ReLU(inplace=True): 激活函数。inplace=True表示直接在原始内存上修改,这样可以节省显存,防止在训练大模型时内存溢出。 -
self.features[:4]: 这是 Python 的切片语法。作者根据 VGG 的固定结构,算好了前 4 层是第一个特征阶段,16 到 23 层是第四个阶段。
为什么 VGG 适合做 Backbone?
VGG 的结构非常直观:它通过不断的池化('M')让特征图越来越小,同时通过卷积让通道数(64 -> 128 -> 512)越来越多。这符合人类视觉系统的逻辑:看的范围越来越广(感受野变大),但看到的内容越来越抽象(从线条变成物体概念)。
2.utils/ (工具箱)
-
职责:存放所有辅助训练和测试的脚本。
-
dataloader.py:负责把硬盘里的图片加载到内存,并进行格式转换(变成张量)。 -
loss.py:定义“损失函数”,即如何判断模型画得准不准。放在了nets/unet_training.py这个文件里。 -
utils_metrics.py:计算评价指标(如 mIoU、准确率等)。 -
callbacks.py:在训练过程中自动保存模型、记录日志。
以下是 nets/unet_training.py 中关键代码的逐行解释:
1. 损失函数类:CE_Loss (交叉熵损失)
这是语义分割最基础的损失函数,用来判断每个像素点分类是否正确。
def CE_Loss(inputs, target, cls_weights, num_classes=21):
# inputs: 模型预测的结果
# target: 真实的标签图(Ground Truth)
# cls_weights: 类别权重,如果某些物体很小,可以给它更高的权重
n, c, h, w = inputs.size() # 获取预测图的:数量、类别数、高度、宽度
nt, ht, wt = target.size() # 获取标签图的尺寸
# 如果尺寸对不上,进行插值缩放,保证能对比
if h != ht or w != wt:
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
# 重点:转置并改变形状,为了符合 PyTorch 交叉熵函数的输入要求
inputs = inputs.transpose(1, 2).transpose(2, 3).contiguous().view(-1, num_classes)
target = target.view(-1)
# 调用 PyTorch 内置的交叉熵计算误差
loss = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)(inputs, target)
return loss
2. 损失函数类:Dice_Loss (Dice 损失)
这是 U-Net 的灵魂损失函数。它专门用于处理“目标很小、背景很大”的情况(比如医学图像里的小肿瘤)。
def Dice_Loss(inputs, target, beta=1, smooth=1e-5):
# inputs: 预测结果,target: 真实标签
# smooth: 一个很小的数,防止分母为 0 导致报错
n, c, h, w = inputs.size()
nt, ht, wt = target.size()
# 同样进行尺寸对齐
if h != ht or w != wt:
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
# 将预测值转化为 0-1 之间的概率值
inputs = torch.softmax(inputs, dim=1)
# 【核心逻辑】计算预测结果和真实结果的交集(Intersection)
# 就像两个圆的重合面积越大,Dice 损失就越小
tp = torch.sum(target * inputs, axis=[0, 2, 3])
fp = torch.sum(inputs, axis=[0, 2, 3]) - tp
fn = torch.sum(target, axis=[0, 2, 3]) - tp
# 根据公式计算 Dice 系数
score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
dice_loss = 1 - torch.mean(score) # 1 减去相似度,就是误差值
return dice_loss
3. 其他重要函数意思
除了损失函数,这个文件里还有几个初学者必须知道的函数:
-
weights_init(权重初始化):-
意思:神经网络刚出生时是一张白纸,这个函数负责给模型里的每一个“神经元”分配一个初始的随机数字。
-
作用:合理的初始值(如 Xavier 或 Kaiming 初始化)能让模型练得更快,不至于在一开始就跑偏。
-
-
get_lr_scheduler(学习率调度器):-
意思:学习率就像模型走路的“步子”。
-
作用:刚开始练的时候步子大一点(快速接近目标),快练好时步子小一点(精细打磨),这个函数就是控制步子大小的。
-
-
set_optimizer_lr(设置优化器学习率):-
意思:手动把计算好的步长(学习率)告诉给“修理工”(优化器)。
-
总结:为什么放在 nets 而不是 utils?
在 bubbliiiing 的架构逻辑里:
-
utils放置的是通用的、不依赖于具体模型的工具(比如图片缩放、计算 mIoU 指标)。 -
nets/unet_training.py放置的是专门针对 U-Net 训练定制的代码。因为不同的网络(如 YOLO, DeepLab)往往需要不同的 Loss 组合方式。
utils/dataloader.py 文件在深度学习项目中相当于“物流部门”。它的任务是从硬盘读取原始图片和标签(Mask),把它们处理成统一的大小,进行各种“变花样”的数据增强,最后打包成机器能听懂的张量(Tensor)送入 GPU。
1. 导入库与预处理函数
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
def preprocess_input(image):
image /= 255.0 # 将像素值从 0-255 压缩到 0-1 之间,这叫“归一化”,有助于模型收敛。
return image
2. UnetDataset 类:数据的“加工厂”
这个类继承了 PyTorch 的 Dataset。它是告诉程序:给出一个索引 i,你应该去哪里找第 i 张图,并把它加工成什么样。
(1) 初始化部分 __init__
class UnetDataset(Dataset):
def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
super(UnetDataset, self).__init__()
self.annotation_lines = annotation_lines # 存储所有图片路径的列表(通常来自 train.txt)。
self.length = len(self.annotation_lines) # 数据集的总长度(图片总数)。
self.input_shape = input_shape # 模型要求的图片尺寸,如 [512, 512]。
self.num_classes = num_classes # 分类的数量。
self.train = train # 是否为训练模式(训练模式下会开启随机翻转等增强)。
self.dataset_path = dataset_path # 数据集所在的根目录。
(2) 获取长度 __len__
def __len__(self):
return self.length # 告诉 PyTorch 这个数据集里一共有多少张照片。
(3) 核心获取函数 __getitem__
这是整个文件最重要的部分。每当训练需要一张新图时,就会调用这个函数。
def __getitem__(self, index):
# 1. 根据索引找到对应的图片名称
line = self.annotation_lines[index].split()
# 2. 从硬盘打开原始图片,并转换成 RGB 模式
image = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), line[0] + ".jpg"))
# 3. 从硬盘打开标签图(即 Mask),转换成灰度模式(L)
# 标签图中,0 通常代表背景,1 代表第一个目标,以此类推。
png = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), line[0] + ".png"))
# 4. 数据增强(Data Augmentation)
# 如果是训练模式,会随机对图片和标签同时进行缩放、翻转、颜色抖动等。
# 目的是让模型见多识广,不至于只认识这一张图。
if self.train:
image, png = self.get_random_data(image, png, self.input_shape, random = self.train)
else:
# 如果是验证模式,只进行不失真的大小调整。
image, png = self.get_random_data(image, png, self.input_shape, random = False)
# 5. 格式转换:从图片格式转为数组
image = np.transpose(preprocess_input(np.array(image, np.float32)), (2, 0, 1))
# 为什么要 transpose?因为图片是 [H, W, C],而 PyTorch 要求 [C, H, W](通道数在前)。
png = np.array(png)
# 将标签图里大于类别数的像素点强行设为“背景”,防止数据标注错误。
png[png >= self.num_classes] = self.num_classes
# 6. 返回处理好的图片和标签
return image, png
3. unet_dataset_collate:打包函数
如果说 __getitem__ 是加工一件商品,那么 collate 就是把加工好的商品装进集装箱(Batch)。
def unet_dataset_collate(batch):
images = []
pngs = []
for img, png in batch:
images.append(img)
pngs.append(png)
# 将一组图片和标签堆叠在一起,形成一个批次(Batch)。
# 比如 batch_size 是 4,那这里就会返回 [4, 3, 512, 512] 的张量。
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
pngs = torch.from_numpy(np.array(pngs)).long()
return images, pngs
直观总结:
-
Dataset就像是一个“菜单”:它知道厨房里有什么菜(图片),并且知道怎么把每一道菜洗干净、切好(预处理)。 -
input_shape:这是厨房的“统一盘子尺寸”。不管原始图片多大,最后都要缩放到这个尺寸,否则模型吃不下。 -
get_random_data:这是“加佐料”。通过旋转、裁剪,把一张图变成“多张图”,让模型学习得更扎实。 -
np.transpose:这是“换个摆盘姿势”。模型这个“客人”很挑剔,它习惯先看通道(RGB),再看像素坐标。
3.model_data/ (模型资产)
-
职责:存放静态资源。
-
.pth文件:这是权重文件,保存了模型学到的“经验”。 -
cls_classes.txt:定义了你的分类目标(比如:第一类是背景,第二类是细胞,第三类是血管)。
4.VOCdevkit/ (数据集仓库)
-
-
职责:存放你的训练素材。
-
VOC2007/JPEGImages:存放原始彩色图片。 -
VOC2007/SegmentationClass:存放对应的标签图(Mask),即告诉机器哪里是什么。
-
-
img/-
职责:仅用于展示项目 README 文档里的示意图,对代码运行没有实际功能影响。
-
5.img/
二、根目录下的关键脚本 (The Scripts)
1.train.py (训练总控制台)
-
这是最核心的文件。它把
nets的模型、utils的数据加载器、model_data的权重全部组合起来,开始训练。 -
初学者注意:你大部分的学习率、训练次数(Epoch)、批次大小(Batch Size)都在这里设置。
1. 参数配置区 (Configuration)
这部分代码主要是设置“游戏规则”,比如用什么显卡、分多少类。
# 是否使用 GPU 进行训练。如果显卡支持,设为 True 能极大地提高速度。
Cuda = True
# 分类个数 + 1(1是背景)。比如你要识别猫和狗,这里就是 2+1=3。
num_classes = 21
# 主干网络选择。可选 'vgg' 或 'resnet50'。
backbone = "vgg"
# 是否使用预训练权重。建议初学者开启,这样模型就像是“读过高中”后再来学专业课,比从零开始快。
pretrained = False
# 预训练权重文件的路径。
model_path = "model_data/unet_vgg_voc.pth"
# 输入图片的大小。必须是 32 的倍数(因为 U-Net 有多次下采样,每次缩小一半)。
input_shape = [512, 512]
2. 训练阶段设置 (Training Phases)
这个仓库采用了**“冻结训练”**的策略。这是一种非常聪明的训练方法:先冻结主干特征提取网络(Backbone),只练后面的解码器;等稳定了,再全线开启训练。
# 冻结阶段:只训练 U-Net 的加强特征提取部分。
Init_Epoch = 0 # 起始世代
Freeze_Epoch = 50 # 冻结训练到第几个世代结束
Freeze_batch_size = 2 # 冻结阶段每批次喂给模型多少张图(显存小就调小点)
# 解冻阶段:全网络一起训练,精细打磨参数。
UnFreeze_Epoch = 100 # 总训练世代
Unfreeze_batch_size = 2 # 解冻后每批次的数量
3. 模型与初始化 (Model & Weights)
这部分代码负责把模型“造”出来,并给它装载记忆(权重)。
# 实例化 U-Net 模型
model = Unet(num_classes=num_classes, pretrained=pretrained, backbone=backbone).train()
# 如果没有预训练权重,就进行随机初始化(给神经元分配随机的初始数字)
if not pretrained:
weights_init(model)
# 如果指定了 model_path,就加载已经练过的权重
if model_path != '':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location=device)
# 只加载匹配的部分,防止报错
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
4. 数据加载 (Data Loading)
这里负责把 voc_annotation.py 生成的名单读进来,并交给“搬运工”。
# 读取训练集和验证集的图片路径和标签信息
with open(train_annotation_path, "r") as f:
train_lines = f.readlines()
with open(val_annotation_path, "r") as f:
val_lines = f.readlines()
# 实例化 Dataset 类(还记得 dataloader.py 吗?这里就是调用它)
train_dataset = UnetDataset(train_lines, input_shape, num_classes, True, VOCdevkit_path)
val_dataset = UnetDataset(val_lines, input_shape, num_classes, False, VOCdevkit_path)
# 实例化 DataLoader,负责开启多线程,把数据打包成 Batch 喂给 GPU
gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_iter=num_workers, pin_memory=True, drop_last=True, collate_fn=unet_dataset_collate)
gen_val = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, num_iter=num_workers, pin_memory=True, drop_last=True, collate_fn=unet_dataset_collate)
5. 核心训练循环 (The Training Loop)
这是最热闹的地方。程序会在这里跑几十到上百个 Epoch(世代)。
for epoch in range(start_epoch, end_epoch):
# 如果到了解冻的时机,就把主干网络的锁定解开
if epoch >= Freeze_Epoch and not UnFreeze_Flag:
batch_size = Unfreeze_batch_size
# ... 重新设置优化器和学习率 ...
UnFreeze_Flag = True
# 调用 fit_one_epoch 函数。这是实际干活的地方。
# 每一代训练都会经历:读图 -> 预测 -> 算误差 -> 找差距 -> 改参数。
fit_one_epoch(model_train, model, loss_history, optimizer, epoch,
gen, gen_val, end_epoch, Cuda, dice_loss, focal_loss, cls_weights, num_classes)
6. 关键函数 fit_one_epoch 里的逻辑意思
在 train.py 内部或调用的工具函数中,这几行是灵魂:
-
optimizer.zero_grad():擦除黑板。把上一轮计算的梯度(差距)清空。 -
outputs = model_train(images):模型考试。模型根据输入的图给出它认为的分割结果。 -
loss = loss_calc(outputs, pngs):对答案。计算预测图和真实标签图之间的差距。 -
loss.backward():反向传播。把差距一层层传回去,告诉前面的神经元:“你刚才猜错了”。 -
optimizer.step():调整。根据刚才传回来的信息,微调模型内部的参数螺丝。
2.predict.py (预测/推理入口)
当你训练好模型后,运行这个文件。它会启动一个窗口或循环,让你输入图片并查看分割结果。
3.unet.py (预测逻辑包装)
注意:根目录下的这个 unet.py 不同于 nets/ 里的。它是一个“工具类”,专门用于加载模型权重并对单张图片进行预处理,方便 predict.py 调用。
4.voc_annotation.py (数据索引生成器)
在训练前必须运行。它会扫描 VOCdevkit,并生成 train.txt 等文件,告诉程序:“这 800 张图用来学习,剩下的 200 张用来考试。
5.get_miou.py (性能体检脚本)
用于量化评估你的模型到底有多优秀,它会计算出平均交并比(mIoU),这是语义分割领域最权威的指标
三、函数级别的“黑话”解释
-
__init__:积木准备。比如定义“我要一个卷积层”、“我要一个池化层”。 -
forward:数据流动。定义图片进入网络后,先过哪个门,再进哪个房。 -
Dataset/DataLoader:搬运工。负责从硬盘一张张地把图片塞给 GPU。 -
Optimizer(优化器):修理工。根据误差,不断旋转模型内部的“螺丝钉”(权重),让误差越来越小。
四、基于模型修改自己的图片
1. 准备数据集(最关键的一步)
你需要按照 VOC 格式 整理你的图片。将你的数据放入 VOCdevkit/VOC2007 文件夹中:
-
JPEGImages:存放所有的原始彩色图片(建议统一为
.jpg格式)。 -
SegmentationClass:存放所有的标签图片(Mask)(必须是
.png格式)。-
注意:标签图里的像素值必须是数字。比如背景像素是 0,目标 A 是 1,目标 B 是 2。在视觉上,这些图片看起来可能是全黑的,这是正常的。
-
2. 修改类别文件 (model_data/cls_classes.txt)
打开这个文件,把里面的内容换成你自己的分类目标。
-
注意:第一行通常保留为
background(背景),后面紧跟着你的目标。 -
示例:如果你要识别肺部结节,文件内容应为:
background
nodule
3. 生成索引文件 (voc_annotation.py)
模型训练前需要知道哪些图是用来“学习”的,哪些是用来“考试”的。
-
打开根目录下的
voc_annotation.py。 -
修改其中的
classes_path指向你刚刚改好的cls_classes.txt。 -
运行该脚本。它会在根目录生成
2007_train.txt和2007_val.txt,里面记录了图片的路径。
4. 配置训练参数 (train.py)
这是你的控制中心。在运行之前,检查以下几个核心参数:
-
num_classes:改为你的总类别数(背景 + 目标数量)。 -
backbone:根据你的显卡配置选择vgg或resnet50。 -
model_path:如果你是第一次训练,建议填""(从零开始)或者指向作者提供的预训练权重(收敛更快)。 -
input_shape:根据你的图片大小修改,通常设为[512, 512]。
5. 开始训练
在终端或 PyCharm 中直接运行 train.py。
-
观察日志:你会看到
Loss(误差)在不断下降。 -
产出物:训练完成后,在
logs/文件夹下会生成一系列.pth文件。这些就是你辛苦训练出来的“大脑”。
6. 使用你的模型进行预测
当你训练好模型后,想要测试一张新图片:
-
修改 根目录下的
unet.py(注意不是nets/里的):-
将
model_path指向你刚才生成的最好的.pth文件。 -
将
classes_path指向你的cls_classes.txt。
-
-
运行
predict.py。 -
根据提示输入图片的路径,你就能看到模型对你新图片的分割效果了。
给初学者的 3 个避坑小贴士:
-
尺寸一致性:确保你的标签图 (Mask) 和原始图片的长宽比例是一致的,否则模型会学得很痛苦。
-
显存溢出 (OOM):如果运行
train.py报错提示内存不足,请尝试调小batch_size(比如从 4 改为 2)。 -
标签格式:很多初学者直接用彩色分割图当标签,这是不对的。标签图必须是单通道的索引图。你可以用
PIL库通过代码检查一下标签图的像素值是否在[0, num_classes-1]范围内。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)