【UNet3+】遥感影像分割
文章目录
1. 项目准备
1.1. 问题导入
-
图像分割
在计算机视觉领域,图像分割指的是将数字图像细分为多个图像子区域的过程,其目的是简化或改变图像的表示形式,使得图像更容易理解和分析。图像分割通常用于定位图像中的物体和边界,更精确的说,它是对图像中的每个像素加标签的一个过程,这一过程使得具有相同标签的像素具有某种共同视觉特性。 -
实验任务
本例简要介绍如何使用UNet3+
模型实现遥感影像分割,我们需要将遥感影像中存在的建筑物分割、标注出来。
1.2. 数据集简介
武汉大学2019年发布了Aerial Imagery Dataset,该数据集原始航拍数据来自新西兰土地信息服务网站,数据集共有8,189张具有0.3m分辨率、大小为512×512像素的遥感图像,数据集共包含18,7000座建筑物。数据集包含存放遥感图像的image文件夹和存放分割图像的label文件夹,例图如下图所示:
这是数据集的下载链接:Aerial Imagery Dataset - AI Studio
2. UNet3+模型
2.1. 背景介绍
Hinton等人(2006)提出了一种Encoder-Decoder
结构,当时这个Encoder-Decoder
结构提出的主要作用并不是分割,而是压缩图像和去噪声。输入是一幅图,经过下采样的编码,得到一串比原先图像更小的特征,相当于压缩,然后再经过一个解码,理想状况就是能还原到原来的图像。
后来,Jonathan等人(2015)在论文中基于该拓扑结构提出了FCN
(Fully Convolutional Networks)。自提出以后,FCN
就成为了语义分割的基本框架,后续算法(如UNet
)其实都是在这个框架中改进而来。其中的UNet
由于其对称结构简单易懂,且模型效果优秀,于是就成为了许多网络改进的范本之一。
UNet
(2015)是医学影像分割领域应用最广泛的的网络,它使用跳跃连接(skip connection)来结合来自解码器的高级语义特征图和来自编码器的相应尺度的低级语义特征图,其性能和网络中多尺度特征的融合密切相关。为了避免纯跳跃连接在语义上融合不相似的特征,此后的UNet++
(2018)引入嵌套结构和密集的跳跃连接对网络进行了改进。而最新的UNet3+
(2020)通过全尺度的跳跃连接和深度监督(deep supervisions)来融合深层和浅层特征的同时对各个尺度的特征进行监督,它还可以在减少网络参数的同时提高计算效率。
2.2. 模型介绍
Huang等人(2020)在论文中提出了UNet3+
模型,Huang等人使用该模型在肝脏和脾脏数据集上进行广泛的实验,发现它的表现得到了提高并且超过了很多baselines。下面介绍一下UNet3+
模型的三个创新点:
(1) 全尺度跳跃连接
UNet3+
充分利用多尺度特征,引入全尺度跳跃连接(Full-scale Skip Connections),该连接结合了来自全尺度特征图的低级语义和高级语义,并且参数更少。
在许多分割实验的研究中,不同尺度的特征图展示着不同的信息:低级语义特征图捕捉丰富的空间信息,能够突出物体的边界;而高级语义特征图则体现了物体所在的位置信息。为此,UNet3+
的每个解码器层都融合了来自编码器中的小尺度和同尺度的低级语义特征图,以及来自解码器的大尺度的高级语义特征图,这些特征图捕获了全尺度下的细粒度语义和粗粒度语义。
如上图所示,为了构造特征图 X D e 3 X_{De}^3 XDe3,第3层解码器不仅需要接收同尺度编码器层的特征图 X E n 3 X_{En}^3 XEn3,还需要接收小尺度编码器层的特征图 X E n 1 X_{En}^1 XEn1和 X E n 2 X_{En}^2 XEn2(为了统一特征图的分辨率,在接收前需进行下采样操作),同时也需要接收大尺度解码器层的特征图 X D e 5 X_{De}^5 XDe5和 X D e 4 X_{De}^4 XDe4(为了统一特征图的分辨率,在接收前需进行上采样操作)。在统一特征图的分辨率之后,我们还需用64个3×3的卷积核统一特征图的数量,以减少多余信息。在完成上述操作之后,我们就能用“通道维度拼接”的方法融合特征了,融合上述5个特征后便得到了320个特征图。接着,我们用320个3×3的卷积核对其进行卷积操作,最后通过批正则化(Batch Normalize)和ReLU(Rectified Linear Unit)便得到 X D e 3 X_{De}^3 XDe3。
于是,特征图
X
D
e
i
X_{De}^i
XDei的计算公式可总结为:
其中,变量
i
i
i表示沿着编码方向的编/解码层的编号,变量
N
N
N表示编码器的总数,函数
C
C
C代表卷积操作,函数
U
U
U和
D
D
D分别代表上采样和下采样操作,函数
H
H
H代表“特征融合”机制(即1个卷积层+1个批正则化层+1个ReLU函数层),
[
]
[ ]
[ ]代表“通道维度拼接”。
(2) 全尺度深度监督
UNet3+
采用全尺度深度监督(Full-scale Deep Supervision),从全面的聚合特征图中学习层次表示,优化了混合损失函数以增强器官边界。
不同于UNet++
对全分辨率特征图进行深度监督,UNet3+
中每个解码器都有一个侧输出,它是由真实标准(ground truth)来进行监督的。为实现深度监督,每个解码器的侧输出都会被送入1个3×3卷积层、1个双线性上采样层以及1个sigmoid函数层中。
为了进一步增强器官边界,UNet3+
提出了一种多尺度结构相似指数(Multi-Scale Structural Similarity index,MS-SSIM)损失函数来赋予模糊边界更大的权重。由于区域分布差异越大,MS-SSIM值越高,故UNet3+
将更加关注模糊边界。假设我们从分割结果P
和真实标准G
中分别裁剪了两个N×N的块
p
p
p和
g
g
g,并且有
p
=
{
p
j
:
j
=
1
,
.
.
.
,
N
2
}
p =\{p_j : j = 1,...,N^2\}
p={pj:j=1,...,N2}和
g
=
{
g
j
:
j
=
1
,
.
.
.
,
N
2
}
g =\{g_j : j = 1,...,N^2\}
g={gj:j=1,...,N2},那么我们可定义
p
p
p和
g
g
g的MS-SSIM损失函数为:
其中,
M
M
M表示尺度的总数(原作者将尺度总数设为5),
μ
p
,
μ
g
μ_p, μ_g
μp,μg和
σ
p
,
σ
g
σ_p, σ_g
σp,σg分别表示
p
p
p和
g
g
g的均值和方差,
σ
p
g
σ_{pg}
σpg则表示
p
p
p和
g
g
g的协方差。
β
m
,
γ
m
β_m, γ_m
βm,γm分别表示这两部分在每个尺度中的相对重要性程度,而设置小常量
C
1
=
0.01
2
,
C
2
=
0.03
2
C_1 = {0.01}^2, C_2 = {0.03}^2
C1=0.012,C2=0.032的目的是避免出现除以0的异常情况。
UNet3+
融合了focal损失函数、MS-SSIM损失函数和IoU损失函数,提出了一种用于三个不同层次级别(像素级、块级、图像级)分割的混合损失函数,它能捕获边界清晰的大尺度结构和精细结构。该混合损失函数的定义为:
(3) 分类指导模块
UNet3+
提出分类指导模块(Classification-guided Module,CGM),通过图像级分类联合训练,减少非器官图像的过度分割。
在大多数医学图像分割实验中,由于来自背景的噪声信息停留在较浅层次中,这导致非器官图像出现过度分割的现象。为解决这一问题,UNet3+
增加了一个预测输入图像是否有器官的额外分类任务。
如上图所示,最深层的特征图 X D e 5 X_{De}^5 XDe5依次通过Dropout层、1×1卷积层、最大池化层和Sigmoid函数层,以得到代表 X D e 5 X_{De}^5 XDe5中有/无器官概率的二维张量。然后,我们可以用argmax函数处理二维张量,以得到仅包含0和1的二分类结果。接着,我们用这些分类结果与每个侧边分割输出相乘,以得到修正后的侧边分割输出。我们可以通过优化二分类的交叉损失函数,来获得更准确的分类结果,以此指导模型避免对非器官图像过度分割。
3. 代码实现
3.0. 前期准备
- 导入模块
注意:本案例仅适用于
Paddle 2.0+
版本,建议根据显存大小合理调整超参数batch_size
和img_size
的大小!
import cv2
import os
import random
import zipfile
import numpy as np
from copy import deepcopy
from PIL import Image, ImageEnhance
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap as LSC
import paddle
from paddle import nn
from paddle.framework import ParamAttr
from paddle.io import DataLoader, Dataset
from paddle.nn import initializer as I, functional as F
from paddle.optimizer import Adam
from paddle.optimizer.lr import CosineAnnealingDecay
- 设置超参数
BATCH_SIZE = 4 # 每批次的样本数
EPOCHS = 16 # 模型训练的总轮数
LOG_GAP = 500 # 输出训练信息的间隔
N_CLASSES = 2 # 图像分类种类数量
IMG_SIZE = (256, 256) # 图像缩放尺寸
INIT_LR = 3e-4 # 初始学习率
SRC_PATH = "./data/data69911/BuildData.zip" # 压缩包路径
DST_PATH = "./data" # 解压路径
DATA_PATH = { # 实验数据集路径
"img": DST_PATH + "/image", # 正常图像
"lab": DST_PATH + "/label", # 分割图像
}
INFER_PATH = { # 预测数据集路径
"img": ["./work/1.jpg", "./work/2.jpg"], # 正常图像
"lab": ["./work/1.png", "./work/2.png"], # 分割图像
}
MODEL_PATH = "UNet3+.pdparams" # 模型参数保存路径
3.1. 数据准备
- 解压数据集
由于数据集中的数据是以压缩包的形式存放的,因此我们需要先解压数据压缩包。
if not os.path.isdir(DATA_PATH["img"]) or not os.path.isdir(DATA_PATH["lab"]):
z = zipfile.ZipFile(SRC_PATH, "r") # 以只读模式打开zip文件
z.extractall(path=DST_PATH) # 解压zip文件至目标路径
z.close()
print("The dataset has been unpacked successfully!")
- 划分数据集
我们需要按9:1比例划分训练集和测试集,分别生成两个包含数据路径和标签路径映射关系的列表。
train_list, test_list = [], [] # 存放图像路径与标签路径的映射
images = os.listdir(DATA_PATH["img"]) # 统计数据集下的图像文件
for idx, img in enumerate(images):
lab = os.path.join(DATA_PATH["lab"], img.replace(".jpg", ".png"))
img = os.path.join(DATA_PATH["img"], img)
if idx % 10 != 0: # 按照1:9的比例划分数据集
train_list.append((img, lab))
else:
test_list.append((img, lab))
- 数据增强
数据増广(Data Augmentation),即数据增强,数据增强的目的主要是减少网络的过拟合现象,通过对训练图片进行变换可以得到泛化能力更强的网络,更好地适应应用场景。
由于实验模型较为复杂,直接训练容易发生过拟合,故在处理实验数据集时采用数据增强的方法扩充数据集的多样性。本实验中用到的数据增强方法有:随机改变亮度,随机改变对比度,随机改变饱和度,随机改变清晰度,随机旋转图像,随机翻转图像,随机加高斯噪声等。
def random_brightness(img, lab, low=0.5, high=1.5):
''' 随机改变亮度(0.5~1.5) '''
x = random.uniform(low, high)
img = ImageEnhance.Brightness(img).enhance(x)
return img, lab
def random_contrast(img, lab, low=0.5, high=1.5):
''' 随机改变对比度(0.5~1.5) '''
x = random.uniform(low, high)
img = ImageEnhance.Contrast(img).enhance(x)
return img, lab
def random_color(img, lab, low=0.5, high=1.5):
''' 随机改变饱和度(0.5~1.5) '''
x = random.uniform(low, high)
img = ImageEnhance.Color(img).enhance(x)
return img, lab
def random_sharpness(img, lab, low=0.5, high=1.5):
''' 随机改变清晰度(0.5~1.5) '''
x = random.uniform(low, high)
img = ImageEnhance.Sharpness(img).enhance(x)
return img, lab
def random_rotate(img, lab, low=0, high=360):
''' 随机旋转图像(0~360度) '''
angle = random.choice(range(low, high))
img, lab = img.rotate(angle), lab.rotate(angle)
return img, lab
def random_flip(img, lab, prob=0.5):
''' 随机翻转图像(p=0.5) '''
if random.random() < prob: # 上下翻转
img = img.transpose(Image.FLIP_TOP_BOTTOM)
lab = lab.transpose(Image.FLIP_TOP_BOTTOM)
if random.random() < prob: # 左右翻转
img = img.transpose(Image.FLIP_LEFT_RIGHT)
lab = lab.transpose(Image.FLIP_LEFT_RIGHT)
return img, lab
def random_noise(img, lab, low=0, high=10):
''' 随机加高斯噪声(0~10) '''
img = np.asarray(img)
sigma = np.random.uniform(low, high)
noise = np.random.randn(img.shape[0], img.shape[1], 3) * sigma
img = img + np.round(noise).astype('uint8')
# 将矩阵中的所有元素值限制在0~255之间:
img[img > 255], img[img < 0] = 255, 0
img = Image.fromarray(img)
return img, lab
def image_augment(img, lab, prob=0.5):
''' 叠加多种数据增强方法 '''
opts = [random_brightness, random_contrast, random_color, random_flip,
random_noise, random_rotate, random_sharpness,] # 数据增强方法
for func in opts:
if random.random() < prob:
img, lab = func(img, lab) # 处理图像和标签
return img, lab
- 数据预处理
我们需要对数据集图像进行缩放和归一化处理。
class MyDataset(Dataset):
''' 自定义的数据集类
* `label_list`: 图像路径和标签路径的映射列表
* `transform`: 图像处理函数
* `augment`: 数据增强函数
'''
def __init__(self, label_list, transform, augment=None):
super(MyDataset, self).__init__()
random.shuffle(label_list) # 打乱映射列表
self.label_list = label_list
self.transform = transform
self.augment = augment
def __getitem__(self, index):
''' 根据位序获取对应数据 '''
img_path, lab_path = self.label_list[index]
img, lab = self.transform(img_path, lab_path, self.augment)
return img, lab
def __len__(self):
''' 获取数据集的样本总数 '''
return len(self.label_list)
def data_mapper(img_path, lab_path, augment=None):
''' 图像处理函数 '''
img = Image.open(img_path).convert("RGB")
lab = cv2.cvtColor(cv2.imread(lab_path), cv2.COLOR_RGB2GRAY)
# 将标签文件进行灰度二值化:
_, lab = cv2.threshold(src=lab, # 待处理图片
thresh=170, # 起始阈值
maxval=255, # 最大阈值
type=cv2.THRESH_BINARY_INV) # 算法类型
lab = Image.fromarray(lab).convert("L") # 转换为PIL.Image
# 将图像缩放为IMG_SIZE大小的高质量图像:
img = img.resize(IMG_SIZE, Image.ANTIALIAS)
lab = lab.resize(IMG_SIZE, Image.ANTIALIAS)
if augment is not None: # 数据增强
img, lab = augment(img, lab)
# 将图像转为numpy数组,并转换图像的格式:
img = np.array(img).astype("float32").transpose((2, 0, 1))
lab = np.array(lab).astype("int32")[np.newaxis, ...]
# 将图像数据归一化,并转换成Tensor格式:
img = paddle.to_tensor(img / 255.0)
lab = paddle.to_tensor(lab // 255)
return img, lab
train_dataset = MyDataset(train_list, data_mapper, image_augment) # 训练集
test_dataset = MyDataset(test_list, data_mapper, augment=None) # 测试集
- 定义数据提供器
我们需要分别构建用于训练和测试的数据提供器,其中训练数据提供器是乱序、按批次提供数据的。
train_loader = DataLoader(train_dataset, # 训练数据集
batch_size=BATCH_SIZE, # 每批次的样本数
num_workers=4, # 加载数据的子进程数
shuffle=True, # 打乱数据集
drop_last=False) # 不丢弃不完整的样本批次
test_loader = DataLoader(test_dataset, # 测试数据集
batch_size=1, # 每批次的样本数
num_workers=4, # 加载数据的子进程数
shuffle=False, # 不打乱数据集
drop_last=False) # 不丢弃不完整的样本批次
3.2. 网络配置
本次实验使用的是UNet3+
模型,UNet
系列模型包含下采样(编码器,特征提取)和上采样(解码器,分辨率还原)两个阶段,因模型结构比较像U型而得名。
- 定义网络初始化函数
def init_weights(net, init_type="normal"):
''' 初始化网络的权重与偏置
* `net`: 需要初始化的神经网络层
* `init_type`: 初始化机制(normal/xavier/kaiming/truncated)
'''
if init_type == "normal":
attr = ParamAttr(initializer=I.Normal())
elif init_type == "xavier":
attr = ParamAttr(initializer=I.XavierNormal())
elif init_type == "kaiming":
attr = ParamAttr(initializer=I.KaimingNormal())
elif init_type == "truncated":
attr = ParamAttr(initializer=I.TruncatedNormal())
else:
error = "Initialization method [%s] is not implemented!"
raise NotImplementedError(error % init_type)
# 初始化网络层net的权重系数和偏置系数:
net.param_attr, net.bias_attr = attr, deepcopy(attr)
- 构建编码器
class Encoder(nn.Layer):
''' 用于构建编码器模块
* `in_size`: 输入通道数
* `out_size`: 输出通道数
* `is_batchnorm`: 是否批正则化
* `n`: 卷积层数量(默认为2)
* `ks`: 卷积核大小(默认为3)
* `s`: 卷积运算步长(默认为1)
* `p`: 卷积填充大小(默认为1)
'''
def __init__(self, in_size, out_size, is_batchnorm,
n=2, ks=3, s=1, p=1):
super(Encoder, self).__init__()
self.n = n
for i in range(1, self.n+1): # 定义多层卷积神经网络
if is_batchnorm:
block = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),
nn.BatchNorm2D(out_size),
nn.ReLU())
else:
block = nn.Sequential(nn.Conv2D(in_size, out_size, ks, s, p),
nn.ReLU())
setattr(self, "block%d" % i, block)
in_size = out_size
for m in self.children(): # 初始化各层网络的系数
init_weights(m, init_type="kaiming")
def forward(self, x):
for i in range(1, self.n+1):
block = getattr(self, "block%d" % i)
x = block(x) # 进行前向传播运算
return x
- 构建解码器
class Decoder(nn.Layer):
''' 用于构建解码器模块
* `cur_stage`(int): 当前解码器所在层数
* `cat_size`(int): 统一后的特征图通道数
* `up_size`(int): 特征融合后的通道总数
* `filters`(list): 各卷积网络的卷积核数
* `ks`: 卷积核大小(默认为3)
* `s`: 卷积运算步长(默认为1)
* `p`: 卷积填充大小(默认为1)
'''
def __init__(self, cur_stage, cat_size, up_size,
filters, ks=3, s=1, p=1):
super(Decoder, self).__init__()
self.n = len(filters) # 卷积网络模块的个数
for idx, num in enumerate(filters):
idx += 1 # 待处理输出所在层数
if idx < cur_stage:
# he[idx]_PT_hd[cur_stage], Pool [ps] times
ps = 2 ** (cur_stage - idx)
block = nn.Sequential(nn.MaxPool2D(ps, ps, ceil_mode=True),
nn.Conv2D(num, cat_size, ks, s, p),
nn.BatchNorm2D(cat_size),
nn.ReLU())
elif idx == cur_stage:
# he[idx]_Cat_hd[cur_stage], Concatenate
block = nn.Sequential(nn.Conv2D(num, cat_size, ks, s, p),
nn.BatchNorm2D(cat_size),
nn.ReLU())
else:
# hd[idx]_UT_hd[cur_stage], Upsample [us] times
us = 2 ** (idx - cur_stage)
num = num if idx == 5 else up_size
block = nn.Sequential(nn.Upsample(scale_factor=us, mode="bilinear"),
nn.Conv2D(num, cat_size, ks, s, p),
nn.BatchNorm2D(cat_size),
nn.ReLU())
setattr(self, "block%d" % idx, block)
# fusion(he[]_PT_hd[], ..., he[]_Cat_hd[], ..., hd[]_UT_hd[])
self.fusion = nn.Sequential(nn.Conv2D(up_size, up_size, ks, s, p),
nn.BatchNorm2D(up_size),
nn.ReLU())
for m in self.children(): # 初始化各层网络的系数
init_weights(m, init_type="kaiming")
def forward(self, inputs):
outputs = [] # 记录各层的输出,以便于拼接起来
for i in range(self.n):
block = getattr(self, "block%d" % (i+1))
outputs.append( block(inputs[i]) )
hd = self.fusion(paddle.concat(outputs, 1))
return hd
- 定义网络结构
class UNet3Plus(nn.Layer):
''' UNet3+ with Deep Supervision and Class-guided Module
* `in_channels`: 输入通道数(默认为3)
* `n_classes`: 物体的分类种数(默认为2)
* `is_batchnorm`: 是否批正则化(默认为True)
* `deep_sup`: 是否开启深度监督机制(Deep Supervision)
* `set_cgm`: 是否设置分类引导模块(Class-guided Module)
'''
def __init__(self, in_channels=3, n_classes=2,
is_batchnorm=True, deep_sup=True, set_cgm=True):
super(UNet3Plus, self).__init__()
self.deep_sup = deep_sup
self.set_cgm = set_cgm
filters = [64, 128, 256, 512, 1024] # 各模块的卷积核大小
cat_channels = filters[0] # 统一后的特征图通道数
cat_blocks = 5 # 编(解)码器的层数
up_channels = cat_channels * cat_blocks # 特征融合后的通道数
# ====================== Encoders ======================
self.conv_e1 = Encoder(in_channels, filters[0], is_batchnorm)
self.pool_e1 = nn.MaxPool2D(kernel_size=2)
self.conv_e2 = Encoder(filters[0], filters[1], is_batchnorm)
self.pool_e2 = nn.MaxPool2D(kernel_size=2)
self.conv_e3 = Encoder(filters[1], filters[2], is_batchnorm)
self.pool_e3 = nn.MaxPool2D(kernel_size=2)
self.conv_e4 = Encoder(filters[2], filters[3], is_batchnorm)
self.pool_e4 = nn.MaxPool2D(kernel_size=2)
self.conv_e5 = Encoder(filters[3], filters[4], is_batchnorm)
# ====================== Decoders ======================
self.conv_d4 = Decoder(4, cat_channels, up_channels, filters)
self.conv_d3 = Decoder(3, cat_channels, up_channels, filters)
self.conv_d2 = Decoder(2, cat_channels, up_channels, filters)
self.conv_d1 = Decoder(1, cat_channels, up_channels, filters)
# ======================= Output =======================
if self.set_cgm:
# -------------- Class-guided Module ---------------
self.cls = nn.Sequential(nn.Dropout(p=0.5),
nn.Conv2D(filters[4], 2, 1),
nn.AdaptiveMaxPool2D(1),
nn.Sigmoid())
if self.deep_sup:
# -------------- Bilinear Upsampling ---------------
self.upscore5 = nn.Upsample(scale_factor=16, mode="bilinear")
self.upscore4 = nn.Upsample(scale_factor=8, mode="bilinear")
self.upscore3 = nn.Upsample(scale_factor=4, mode="bilinear")
self.upscore2 = nn.Upsample(scale_factor=2, mode="bilinear")
# ---------------- Deep Supervision ----------------
self.outconv5 = nn.Conv2D(filters[4], n_classes, 3, 1, 1)
self.outconv4 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
self.outconv3 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
self.outconv2 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
self.outconv1 = nn.Conv2D(up_channels, n_classes, 3, 1, 1)
# ================= Initialize Weights =================
for m in self.sublayers():
if isinstance(m, nn.Conv2D) or isinstance(m, nn.BatchNorm):
init_weights(m, init_type='kaiming')
def dot_product(self, seg, cls):
B, N, H, W = seg.shape
seg = seg.reshape((B, N, H * W))
clssp = paddle.ones((1, N))
ecls = (cls * clssp).reshape((B, N, 1))
final = (seg * ecls).reshape((B, N, H, W))
return final
def forward(self, x):
# ====================== Encoders ======================
e1 = self.conv_e1(x) # e1: 320*320*64
e2 = self.pool_e1(self.conv_e2(e1)) # e2: 160*160*128
e3 = self.pool_e2(self.conv_e3(e2)) # e3: 80*80*256
e4 = self.pool_e3(self.conv_e4(e3)) # e4: 40*40*512
e5 = self.pool_e4(self.conv_e5(e4)) # e5: 20*20*1024
# ================ Class-guided Module =================
if self.set_cgm:
cls_branch = self.cls(e5).squeeze(3).squeeze(2)
cls_branch_max = cls_branch.argmax(axis=1)
cls_branch_max = cls_branch_max[:, np.newaxis].astype("float32")
# ====================== Decoders ======================
d5 = e5
d4 = self.conv_d4((e1, e2, e3, e4, d5))
d3 = self.conv_d3((e1, e2, e3, d4, d5))
d2 = self.conv_d2((e1, e2, d3, d4, d5))
d1 = self.conv_d1((e1, d2, d3, d4, d5))
# ======================= Output =======================
if self.deep_sup:
y5 = self.upscore5( self.outconv5(d5) ) # 16 => 256
y4 = self.upscore4( self.outconv4(d4) ) # 32 => 256
y3 = self.upscore3( self.outconv3(d3) ) # 64 => 256
y2 = self.upscore2( self.outconv2(d2) ) # 128 => 256
y1 = self.outconv1(d1) # 256
if self.set_cgm:
y5 = self.dot_product(y5, cls_branch_max)
y4 = self.dot_product(y4, cls_branch_max)
y3 = self.dot_product(y3, cls_branch_max)
y2 = self.dot_product(y2, cls_branch_max)
y1 = self.dot_product(y1, cls_branch_max)
return F.sigmoid(y1), F.sigmoid(y2), F.sigmoid(y3),\
F.sigmoid(y4), F.sigmoid(y5)
else:
y1 = self.outconv1(d1) # 320*320*n_classes
if self.set_cgm:
y1 = self.dot_product(y1, cls_branch_max)
return F.sigmoid(y1)
- 实例化模型
model = UNet3Plus(n_classes=N_CLASSES, deep_sup=False, set_cgm=False)
# paddle.Model(model).summary((1, 3) + IMG_SIZE) # 可视化模型结构
- 定义损失函数
class DiceLoss(nn.Layer):
''' Dice Loss for Segmentation Tasks'''
def __init__(self,
n_classes: int = 2,
smooth: Union[float, Tuple[float, float]] = (0, 1e-6),
sigmoid_x: bool = False,
softmax_x: bool = True,
onehot_y: bool = True,
square_xy: bool = True,
include_bg: bool = True,
reduction: str = "mean"):
''' Args:
* `n_classes`: number of classes.
* `smooth`: smoothing parameters of the dice coefficient.
* `sigmoid_x`: whether using `sigmoid` to process the result.
* `softmax_x`: whether using `softmax` to process the result.
* `onehot_y`: whether using `one-hot` to encode the label.
* `square_xy`: whether using squared result and label.
* `include_bg`: whether taking account of bg-class when computering dice.
* `reduction`: reduction function of dice loss.
'''
super(DiceLoss, self).__init__()
if reduction not in ["mean", "sum"]:
raise NotImplementedError(
"`reduction` of dice loss should be 'mean' or 'sum'!"
)
if isinstance(smooth, float):
self.smooth = (smooth, smooth)
else:
self.smooth = smooth
self.n_classes = n_classes
self.sigmoid_x = sigmoid_x
self.softmax_x = softmax_x
self.onehot_y = onehot_y
self.square_xy = square_xy
self.include_bg = include_bg
self.reduction = reduction
def forward(self, pred, mask):
(sm_nr, sm_dr) = self.smooth
if self.sigmoid_x:
pred = F.sigmoid(pred)
if self.n_classes > 1:
if self.softmax_x and self.n_classes == pred.shape[1]:
pred = F.softmax(pred, axis=1)
if self.onehot_y:
mask = mask if mask.ndim < 4 else mask.squeeze(axis=1)
mask = F.one_hot(mask.astype("int64"), self.n_classes)
mask = mask.transpose((0, 3, 1, 2))
if not self.include_bg:
pred = pred[:, 1:] if pred.shape[1] > 1 else pred
mask = mask[:, 1:] if mask.shape[1] > 1 else mask
if pred.ndim != mask.ndim or pred.shape[1] != mask.shape[1]:
raise ValueError(
f"The shape of `pred`({pred.shape}) and " +
f"`mask`({mask.shape}) should be the same."
)
# only reducing spatial dimensions:
reduce_dims = paddle.arange(2, pred.ndim).tolist()
insersect = paddle.sum(pred * mask, axis=reduce_dims)
if self.square_xy:
pred, mask = paddle.pow(pred, 2), paddle.pow(mask, 2)
pred_sum = paddle.sum(pred, axis=reduce_dims)
mask_sum = paddle.sum(mask, axis=reduce_dims)
loss = 1. - (2 * insersect + sm_nr) / (pred_sum + mask_sum + sm_dr)
if self.reduction == "sum":
loss = paddle.sum(loss)
else:
loss = paddle.mean(loss)
return loss
- 定义评估方法
def dice_func(pred: np.ndarray, mask: np.ndarray,
n_classes: int, ignore_bg: bool = False):
''' compute dice (for NumpyArray) '''
def sub_dice(x: paddle.Tensor, y: paddle.Tensor, sm: float = 1e-6):
intersect = np.sum(np.sum(np.sum(x * y)))
y_sum = np.sum(np.sum(np.sum(y)))
x_sum = np.sum(np.sum(np.sum(x)))
return (2 * intersect + sm) / (x_sum + y_sum + sm)
assert pred.shape == mask.shape
assert isinstance(ignore_bg, bool)
return [
sub_dice(pred == i, mask == i)
for i in range(int(ignore_bg), n_classes)
]
3.3. 模型训练
model.train() # 开启训练模式
scheduler = CosineAnnealingDecay(
learning_rate=INIT_LR,
T_max=EPOCHS,
) # 定义学习率衰减器
optimizer = AdamW(
learning_rate=scheduler,
parameters=model.parameters(),
weight_decay=1e-5
) # 定义Adam优化器
dice_loss = DiceLoss(n_classes=N_CLASSES)
loss_list = [] # 用于可视化
for ep in range(EPOCHS):
ep_loss_list = []
for batch_id, data in enumerate(train_loader()):
image, label = data
pred = model(image) # 预测结果
loss = dice_loss(pred, label) # 计算损失函数值
if batch_id % LOG_GAP == 0: # 定期输出训练结果
print("Epoch:%2d,Batch:%3d,Loss:%.5f" % (ep, batch_id, loss))
ep_loss_list.append(loss.item())
optimizer.clear_grad()
loss.backward()
optimizer.step()
scheduler.step() # 衰减一次学习率
loss_list.append(np.mean(ep_loss_list))
print("【Train】Epoch:%2d,Loss:%.5f" % (ep, loss_list[-1]))
paddle.save(model.state_dict(), MODEL_PATH) # 保存训练好的模型
模型训练的结果如下:
Epoch: 0,Batch: 0,Loss:0.41813
Epoch: 0,Batch:500,Loss:0.51309
Epoch: 0,Batch:1000,Loss:0.32444
Epoch: 0,Batch:1500,Loss:0.22436
Epoch: 0,Batch:2000,Loss:0.07805
【Train】Epoch: 0,Loss:0.19494
Epoch: 1,Batch: 0,Loss:0.50250
Epoch: 1,Batch:500,Loss:0.50011
Epoch: 1,Batch:1000,Loss:0.05158
Epoch: 1,Batch:1500,Loss:0.06440
Epoch: 1,Batch:2000,Loss:0.09458
【Train】Epoch: 1,Loss:0.16005
Epoch: 2,Batch: 0,Loss:0.25998
Epoch: 2,Batch:500,Loss:0.03422
Epoch: 2,Batch:1000,Loss:0.50014
Epoch: 2,Batch:1500,Loss:0.35213
Epoch: 2,Batch:2000,Loss:0.04104
【Train】Epoch: 2,Loss:0.14942
Epoch: 3,Batch: 0,Loss:0.06672
Epoch: 3,Batch:500,Loss:0.05075
Epoch: 3,Batch:1000,Loss:0.03801
Epoch: 3,Batch:1500,Loss:0.05001
Epoch: 3,Batch:2000,Loss:0.03976
【Train】Epoch: 3,Loss:0.14288
Epoch: 4,Batch: 0,Loss:0.06034
Epoch: 4,Batch:500,Loss:0.08312
Epoch: 4,Batch:1000,Loss:0.50062
Epoch: 4,Batch:1500,Loss:0.03367
Epoch: 4,Batch:2000,Loss:0.03980
【Train】Epoch: 4,Loss:0.13926
Epoch: 5,Batch: 0,Loss:0.05745
Epoch: 5,Batch:500,Loss:0.04486
Epoch: 5,Batch:1000,Loss:0.06463
Epoch: 5,Batch:1500,Loss:0.08085
Epoch: 5,Batch:2000,Loss:0.03778
【Train】Epoch: 5,Loss:0.13551
Epoch: 6,Batch: 0,Loss:0.02407
Epoch: 6,Batch:500,Loss:0.50000
Epoch: 6,Batch:1000,Loss:0.50007
Epoch: 6,Batch:1500,Loss:0.05890
Epoch: 6,Batch:2000,Loss:0.03876
【Train】Epoch: 6,Loss:0.13283
Epoch: 7,Batch: 0,Loss:0.05039
Epoch: 7,Batch:500,Loss:0.02733
Epoch: 7,Batch:1000,Loss:0.02768
Epoch: 7,Batch:1500,Loss:0.03542
Epoch: 7,Batch:2000,Loss:0.14349
【Train】Epoch: 7,Loss:0.13040
Epoch: 8,Batch: 0,Loss:0.02584
Epoch: 8,Batch:500,Loss:0.11713
Epoch: 8,Batch:1000,Loss:0.04467
Epoch: 8,Batch:1500,Loss:0.04462
Epoch: 8,Batch:2000,Loss:0.02022
【Train】Epoch: 8,Loss:0.12809
Epoch: 9,Batch: 0,Loss:0.04599
Epoch: 9,Batch:500,Loss:0.01690
Epoch: 9,Batch:1000,Loss:0.02768
Epoch: 9,Batch:1500,Loss:0.50053
Epoch: 9,Batch:2000,Loss:0.04013
【Train】Epoch: 9,Loss:0.12536
Epoch:10,Batch: 0,Loss:0.03324
Epoch:10,Batch:500,Loss:0.36780
Epoch:10,Batch:1000,Loss:0.03769
Epoch:10,Batch:1500,Loss:0.50011
Epoch:10,Batch:2000,Loss:0.50002
【Train】Epoch:10,Loss:0.12356
Epoch:11,Batch: 0,Loss:0.03896
Epoch:11,Batch:500,Loss:0.11487
Epoch:11,Batch:1000,Loss:0.03414
Epoch:11,Batch:1500,Loss:0.06988
Epoch:11,Batch:2000,Loss:0.05266
【Train】Epoch:11,Loss:0.12255
Epoch:12,Batch: 0,Loss:0.02918
Epoch:12,Batch:500,Loss:0.50000
Epoch:12,Batch:1000,Loss:0.50000
Epoch:12,Batch:1500,Loss:0.05509
Epoch:12,Batch:2000,Loss:0.06147
【Train】Epoch:12,Loss:0.12097
Epoch:13,Batch: 0,Loss:0.03541
Epoch:13,Batch:500,Loss:0.03809
Epoch:13,Batch:1000,Loss:0.04672
Epoch:13,Batch:1500,Loss:0.02856
Epoch:13,Batch:2000,Loss:0.02951
【Train】Epoch:13,Loss:0.11975
Epoch:14,Batch: 0,Loss:0.06455
Epoch:14,Batch:500,Loss:0.03240
Epoch:14,Batch:1000,Loss:0.05857
Epoch:14,Batch:1500,Loss:0.02092
Epoch:14,Batch:2000,Loss:0.02371
【Train】Epoch:14,Loss:0.11936
Epoch:15,Batch: 0,Loss:0.50000
Epoch:15,Batch:500,Loss:0.03537
Epoch:15,Batch:1000,Loss:0.50006
Epoch:15,Batch:1500,Loss:0.05185
Epoch:15,Batch:2000,Loss:0.50004
【Train】Epoch:15,Loss:0.11859
- 可视化训练过程
fig = plt.figure(figsize=[10, 5])
# 训练误差图像:
ax = fig.add_subplot(111, facecolor="#E8E8F8")
ax.set_xlabel("Steps", fontsize=18)
ax.set_ylabel("Loss", fontsize=18)
plt.tick_params(labelsize=14)
ax.plot(range(len(loss_list)), loss_list, color="orangered")
ax.grid(linewidth=1.5, color="white") # 显示网格
fig.tight_layout()
plt.show()
plt.close()
3.4. 模型评估
model.eval() # 开启评估模式
model.set_state_dict(
paddle.load(MODEL_PATH)
) # 载入预训练模型参数
dice_accs = []
for batch_id, data in enumerate(test_loader()):
image, label = data
pred = model(image) # 预测结果
pred = pred.argmax(axis=1).squeeze(axis=0).cpu().numpy()
label = label.squeeze(0).squeeze(0).cpu().numpy()
dice = dice_func(pred, label, N_CLASSES) # 计算损失函数值
dice_accs.append(dice)
print("Eval \t Dice: %.5f" % (np.mean(dice_accs)))
模型评估的结果如下:
Eval Dice: 0.94400
3.5. 模型预测
def show_result(img_path, lab_path, pred):
''' 展示原图、标签以及预测结果 '''
def add_subimg(img, loc, title, cmap=None):
''' 添加子图以展示图像 '''
plt.subplot(loc)
plt.title(title)
plt.imshow(img, cmap)
plt.xticks([]) # 去除X刻度
plt.yticks([]) # 去除Y刻度
def colormap(colors=['#A0C185', '#A6A6A6']):
''' 自定义ColorMap '''
return LSC.from_list('cmap', colors, 256)
img = Image.open(img_path).resize(IMG_SIZE)
lab = Image.open(lab_path).resize(IMG_SIZE)
pred = pred.argmax(axis=1).numpy().reshape(IMG_SIZE)
plt.figure(figsize=(12, 4))
add_subimg(img, 131, "Image")
add_subimg(lab, 132, "Label")
add_subimg(pred, 133, "Predict", colormap())
plt.tight_layout()
plt.show()
plt.close()
model.eval() # 开启评估模式
model.set_state_dict(
paddle.load(MODEL_PATH)
) # 载入预训练模型参数
for i in range(len(INFER_PATH["img"])):
img_path, lab_path = INFER_PATH["img"][i], INFER_PATH["lab"][i]
img, lab = data_mapper(img_path, lab_path) # 处理预测图像
pred = model(img[np.newaxis, ...]) # 开始模型预测
show_result(img_path, lab_path, pred)
第1组图像分割结果如下:
第2组图像分割结果如下:
写在最后
- 如果您发现项目存在问题,或者如果您有更好的建议,欢迎在下方评论区中留言讨论~
- 这是本项目的链接:实验项目 - AI Studio,点击
fork
可直接在AI Studio运行~- 这是我的个人主页:个人主页 - AI Studio,来AI Studio互粉吧,等你哦~
- 【友链滴滴】欢迎大家随时访问我的个人博客~
更多推荐
所有评论(0)