This article introduces several classic modules used in deep learning semantic segmentation, namely ASPP, PP, DCM, JPU, Encoding (Context Encoding), and Criss-Cross Attention, and provides implementation code for each module.

Contents

1. ASPP (Atrous Spatial Pyramid Pooling)

2. PP (Pyramid Pooling Module)

3. DCM (Dynamic Convolutional Module)

4. JPU (Joint Pyramid Upsampling)

5. Encoding (Context Encoding Module)

6. Criss-Cross Attention Module

Disclaimer


1. ASPP (Atrous Spatial Pyramid Pooling)

The ASPP module was first proposed in DeepLabV2. It consists of 4 parallel atrous (dilated) convolution branches with different dilation rates, which enlarge the receptive field and capture more contextual information. The original structure is shown in the figure below:

In DeepLabV3 the module was improved: the dilation rates were revised, a global average pooling branch was added to capture image-level context, and Batch Normalization was introduced into the ASPP module. The structure is as follows:

DeepLabV3+ further proposed replacing the convolutions in ASPP with depthwise separable convolutions to reduce the number of parameters and speed up computation. The structure is shown below (same as the figure above):

ASPP code (from PaddleSeg):

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.models import layers


class ASPPModule(nn.Layer):

    def __init__(self,
                 aspp_ratios,
                 in_channels,
                 out_channels,
                 align_corners,
                 use_sep_conv=False,
                 image_pooling=False,
                 data_format='NCHW'):
        super().__init__()

        self.align_corners = align_corners
        self.data_format = data_format
        self.aspp_blocks = nn.LayerList()

        for ratio in aspp_ratios:
            if use_sep_conv and ratio > 1:
                conv_func = layers.SeparableConvBNReLU
            else:
                conv_func = layers.ConvBNReLU

            block = conv_func(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1 if ratio == 1 else 3,
                dilation=ratio,
                padding=0 if ratio == 1 else ratio,
                data_format=data_format)
            self.aspp_blocks.append(block)

        out_size = len(self.aspp_blocks)

        if image_pooling:
            self.global_avg_pool = nn.Sequential(
                nn.AdaptiveAvgPool2D(
                    output_size=(1, 1), data_format=data_format),
                layers.ConvBNReLU(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    bias_attr=False,
                    data_format=data_format))
            out_size += 1
        self.image_pooling = image_pooling

        self.conv_bn_relu = layers.ConvBNReLU(
            in_channels=out_channels * out_size,
            out_channels=out_channels,
            kernel_size=1,
            data_format=data_format)

        self.dropout = nn.Dropout(p=0.1)  # drop rate

    def forward(self, x):
        outputs = []
        if self.data_format == 'NCHW':
            interpolate_shape = paddle.shape(x)[2:]
            axis = 1
        else:
            interpolate_shape = paddle.shape(x)[1:3]
            axis = -1
        for block in self.aspp_blocks:
            y = block(x)
            outputs.append(y)

        if self.image_pooling:
            img_avg = self.global_avg_pool(x)
            img_avg = F.interpolate(
                img_avg,
                interpolate_shape,
                mode='bilinear',
                align_corners=self.align_corners,
                data_format=self.data_format)
            outputs.append(img_avg)

        x = paddle.concat(outputs, axis=axis)
        x = self.conv_bn_relu(x)
        x = self.dropout(x)

        return x
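A minimal usage sketch (hypothetical shapes, assuming the imports above); the dilation rates (1, 6, 12, 18) are the DeepLabV3 setting for output_stride=16:

# Hypothetical backbone feature map: batch 2, 2048 channels, 1/16 resolution of a 512x512 image
feat = paddle.randn([2, 2048, 32, 32])
aspp = ASPPModule(
    aspp_ratios=(1, 6, 12, 18),  # DeepLabV3 rates for output_stride=16
    in_channels=2048,
    out_channels=256,
    align_corners=False,
    use_sep_conv=True,           # DeepLabV3+-style separable convolutions
    image_pooling=True)          # add the global average pooling branch
out = aspp(feat)                 # [2, 256, 32, 32]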

2. PP (Pyramid Pooling Module)

The PP module was first proposed in PSPNet. It consists of 4 parallel adaptive pooling branches; pooling to different output sizes captures context at different scales and improves segmentation accuracy. The structure is as follows:

The input feature map first passes through adaptive pooling layers with different output sizes, producing feature maps at several resolutions. A convolution layer then compresses the channel dimension of each pooled map, the maps are upsampled back to the input resolution, and finally the input feature map and the 4 pooled feature maps are concatenated and fed into subsequent modules.

PP Module code (from PaddleSeg):

class PPModule(nn.Layer):
    def __init__(self, in_channels, out_channels, bin_sizes, dim_reduction,
                 align_corners):
        super().__init__()

        self.bin_sizes = bin_sizes

        inter_channels = in_channels
        if dim_reduction:
            inter_channels = in_channels // len(bin_sizes)

        # Use dimension reduction after pooling, as in the original implementation.
        self.stages = nn.LayerList([
            self._make_stage(in_channels, inter_channels, size)
            for size in bin_sizes
        ])

        self.conv_bn_relu2 = layers.ConvBNReLU(
            in_channels=in_channels + inter_channels * len(bin_sizes),
            out_channels=out_channels,
            kernel_size=3,
            padding=1)

        self.align_corners = align_corners

    def _make_stage(self, in_channels, out_channels, size):
        prior = nn.AdaptiveAvgPool2D(output_size=(size, size))
        conv = layers.ConvBNReLU(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1)

        return nn.Sequential(prior, conv)

    def forward(self, input):
        cat_layers = []
        for stage in self.stages:
            x = stage(input)
            x = F.interpolate(
                x,
                paddle.shape(input)[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            cat_layers.append(x)
        cat_layers = [input] + cat_layers[::-1]
        cat = paddle.concat(cat_layers, axis=1)
        out = self.conv_bn_relu2(cat)

        return out
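A minimal usage sketch (hypothetical shapes, assuming the same imports as the ASPP block); bin_sizes=(1, 2, 3, 6) is the setting used in PSPNet:

feat = paddle.randn([2, 2048, 64, 64])
ppm = PPModule(
    in_channels=2048,
    out_channels=512,
    bin_sizes=(1, 2, 3, 6),  # PSPNet pyramid scales
    dim_reduction=True,      # each branch is reduced to 2048 // 4 = 512 channels
    align_corners=False)
out = ppm(feat)              # [2, 512, 64, 64]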

3. DCM (Dynamic Convolutional Module)

The DCM module was proposed in DMNet (Dynamic Multi-scale Filters for Semantic Segmentation). ASPP and PP obtain multi-scale information through dilated convolution and adaptive pooling, whose parameters are fixed at inference time and cannot adapt to different input images. DCM, in contrast, generates its convolution weights dynamically, so its parameters adapt to each input image.

The figure below shows the DMNet architecture, in which several parallel DCM modules extract semantic information at different scales.

As shown below, the DCM module has 2 branches. The upper branch compresses the feature channels (to reduce computation), while the lower branch obtains the weights of a depth-wise convolution via adaptive pooling; the outputs of the two branches are then combined by that depth-wise convolution. Because the weights are generated dynamically, the module adapts better to different input images.

Network structure

DCM code (my own implementation of DMNet's DCM):

class DCM(nn.Layer):
    def __init__(self, filter_size, fusion, in_channels, channels):
        super().__init__()
        self.filter_size = filter_size
        self.fusion = fusion
        self.channels = channels

        self.filter_gen_conv = nn.Conv2D(in_channels, channels, 1)
        self.input_redu_conv = layers.ConvBNReLU(in_channels, channels, 1)

        self.norm = layers.SyncBatchNorm(channels)
        self.act = nn.ReLU()

        if self.fusion:
            self.fusion_conv = layers.ConvBNReLU(channels, channels, 1)

    def forward(self, x):
        generated_filter = self.filter_gen_conv(F.adaptive_avg_pool2d(x, self.filter_size))
        x = self.input_redu_conv(x)
        b, c, h, w = x.shape
        # Grouped-convolution trick: merge batch and channel dims so each (image, channel)
        # pair is convolved with its own dynamically generated filter.
        x = x.reshape([1, b * c, h, w])
        generated_filter = generated_filter.reshape([b * c, 1, self.filter_size, self.filter_size])
        pad = (self.filter_size - 1) // 2
        if (self.filter_size - 1) % 2 == 0:
            pad = (pad, pad, pad, pad)
        else:
            pad = (pad + 1, pad, pad + 1, pad)
        x = F.pad(x, pad, mode='constant', value=0)  # pad so the output keeps the spatial size h x w
        output = F.conv2d(x, weight=generated_filter, groups=b * c)  # depth-wise conv with the generated filters
        output = output.reshape([b, self.channels, h, w])
        output = self.norm(output)
        output = self.act(output)
        if self.fusion:
            output = self.fusion_conv(output)
        return output
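A minimal usage sketch (hypothetical shapes, assuming the same imports as above); DMNet runs several DCM branches in parallel with filter_size 1, 3, 5 and 7:

feat = paddle.randn([2, 2048, 64, 64])
dcm = DCM(filter_size=3, fusion=True, in_channels=2048, channels=512)
out = dcm(feat)  # [2, 512, 64, 64]; the 3x3 depth-wise filters are generated from feat itself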

4. JPU (Joint Pyramid Upsampling)

JPU was first proposed in FastFCN. DeepLab removes downsampling layers from the backbone (the original output_stride=32 becomes 8 or 16 in DeepLab) to keep high-resolution feature maps and adds dilated convolutions to enlarge the receptive field, but the larger feature maps increase the computational cost. With JPU, the backbone keeps its normal downsampling and the JPU module produces the high-resolution feature map, also using dilated convolutions to enlarge the receptive field.

Network structure

The FastFCN architecture is shown below: a JPU module is attached after the backbone, and the high-resolution feature map produced by JPU is fed into the subsequent network:

JPU (Joint Pyramid Upsampling)

Feature maps at different resolutions first pass through convolution layers that compress the channels, are upsampled to the same resolution and concatenated; the concatenated map is then fed into 4 parallel separable convolution branches (depth-wise + point-wise convolution) with different dilation rates; finally the outputs of the branches are concatenated and passed through a convolution layer into the subsequent network.

JPU code (from: https://github.com/wuhuikai/FastFCN):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SeparableConv2d(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, norm_layer=nn.BatchNorm2d):
        super(SeparableConv2d, self).__init__()

        self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, groups=inplanes, bias=bias)
        self.bn = norm_layer(inplanes)
        self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        x = self.pointwise(x)
        return x


class JPU(nn.Module):
    def __init__(self, in_channels, width=512, norm_layer=None, up_kwargs=None):
        super(JPU, self).__init__()
        self.up_kwargs = up_kwargs

        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False),
            norm_layer(width),
            nn.ReLU(inplace=True))
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False),
            norm_layer(width),
            nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels[-3], width, 3, padding=1, bias=False),
            norm_layer(width),
            nn.ReLU(inplace=True))

        self.dilation1 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=1, dilation=1, bias=False),
                                       norm_layer(width),
                                       nn.ReLU(inplace=True))
        self.dilation2 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=2, dilation=2, bias=False),
                                       norm_layer(width),
                                       nn.ReLU(inplace=True))
        self.dilation3 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=4, dilation=4, bias=False),
                                       norm_layer(width),
                                       nn.ReLU(inplace=True))
        self.dilation4 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=8, dilation=8, bias=False),
                                       norm_layer(width),
                                       nn.ReLU(inplace=True))

    def forward(self, *inputs):
        feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2]), self.conv3(inputs[-3])]
        _, _, h, w = feats[-1].size()
        feats[-2] = F.interpolate(feats[-2], (h, w), **self.up_kwargs)
        feats[-3] = F.interpolate(feats[-3], (h, w), **self.up_kwargs)
        feat = torch.cat(feats, dim=1)
        feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)], dim=1)

        return inputs[0], inputs[1], inputs[2], feat
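A minimal usage sketch (hypothetical shapes): the backbone stage outputs go in, and JPU returns a stride-8 feature map with 4 * width channels:

# Hypothetical ResNet feature maps at strides 4, 8, 16 and 32
c1 = torch.randn(2, 256, 128, 128)
c2 = torch.randn(2, 512, 64, 64)
c3 = torch.randn(2, 1024, 32, 32)
c4 = torch.randn(2, 2048, 16, 16)

jpu = JPU(in_channels=(512, 1024, 2048), width=512,
          norm_layer=nn.BatchNorm2d,
          up_kwargs={'mode': 'bilinear', 'align_corners': True})
_, _, _, feat = jpu(c1, c2, c3, c4)  # feat: [2, 2048, 64, 64], i.e. 4 dilation branches of 512 channels each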

5. Encoding (Context Encoding Module)

The Encoding module was first proposed in ENCNet. Context Encoding selectively highlights class-dependent feature maps by modelling the inter-dependencies between categories.

The input feature map has shape [N, C, H, W]; for each image it is viewed as a set of C-dimensional descriptors X = \left \{ x_{1},...,x_{N} \right \} with N = H \times W. A learnable codebook D = \left \{ d_{1},...,d_{K} \right \} is introduced (each d_{k} is a codeword, also called a visual center in the paper), together with learnable smoothing factors S = \left \{ s_{1},...,s_{K} \right \} corresponding to the visual centers. With the residual r_{ik} = x_{i} - d_{k}, the computation is as follows:
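The aggregation over the codebook (the computation implemented by scaled_l2 and aggregate in the code below) can be written as:

e_{k} = \sum_{i=1}^{N}\frac{\exp(-s_{k}\left \| r_{ik} \right \|^{2})}{\sum_{j=1}^{K}\exp(-s_{j}\left \| r_{ij} \right \|^{2})}\, r_{ik}

In the code the minus sign is absorbed into the learnable scale, which is initialized to a negative value.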

Because the K visual centers in the codebook are computed against feature maps from the whole training set, they learn global (dataset-level) semantic information.

The ENCNet architecture is shown below:

The Context Encoding Module takes a [N, C, H, W] feature map as input and produces 2 outputs: a [N, C, 1, 1] channel attention weight and a [N, num_classes] matrix used to compute the SE-loss (SELoss).

Context Encoding Module code (from Encoding):

class Encoding(nn.Layer):
    def __init__(self, channels, num_codes):
        super().__init__()
        self.channels, self.num_codes = channels, num_codes

        std = 1 / ((channels * num_codes) ** 0.5)
        self.codewords = self.create_parameter(
            shape=(num_codes, channels),
            default_initializer=nn.initializer.Uniform(-std, std),
        )  # codebook: the set of K visual centers (codewords)
        self.scale = self.create_parameter(
            shape=(num_codes,),
            default_initializer=nn.initializer.Uniform(-1, 0),
        )  # learnable smoothing factors, one per codeword
        self.channels = channels

    def scaled_l2(self, x, codewords, scale):
        # the term inside exp() in both the numerator and the denominator of the assignment formula
        num_codes, channels = paddle.shape(codewords)
        reshaped_scale = scale.reshape([1, 1, num_codes])
  
        expanded_x = paddle.tile(x.unsqueeze(2), [1, 1, num_codes, 1])
        reshaped_codewords = codewords.reshape([1, 1, num_codes, channels])

        scaled_l2_norm = paddle.multiply(reshaped_scale, (expanded_x - reshaped_codewords).pow(2).sum(axis=3)) # N, H*W, num_codes
        return scaled_l2_norm

    def aggregate(self, assignment_weights, x, codewords):
        num_codes, channels = paddle.shape(codewords)
        reshaped_codewords = codewords.reshape([1, 1, num_codes, channels])
        expanded_x = paddle.tile(x.unsqueeze(2), [1, 1, num_codes, 1])
        
        encoded_feat = paddle.multiply(assignment_weights.unsqueeze(3), (expanded_x - reshaped_codewords)).sum(axis=1) # N, num_codes, C
        encoded_feat = paddle.reshape(encoded_feat, [-1, self.num_codes, self.channels])
        return encoded_feat
    
    def forward(self, x):
        x_dims = x.ndim
        assert x_dims == 4, "The dimension of input tensor must be 4, but got {}.".format(x_dims)
        assert paddle.shape(x)[1] == self.channels, "Encoding channels error, expected {} but got {}.".format(self.channels, paddle.shape(x)[1])
 
        batch_size = paddle.shape(x)[0]
        
        x = x.reshape([batch_size, self.channels, -1]).transpose([0, 2, 1]) # N, H*W, C
        
        assignment_weights = F.softmax(self.scaled_l2(x, self.codewords, self.scale), axis=2) # N, H*W, num_codes

        encoded_feat = self.aggregate(assignment_weights, x, self.codewords) # N, num_codes, C
        
        return encoded_feat

class EncModule(nn.Layer):
    def __init__(self, in_channels, num_codes):
        super().__init__()
        self.encoding_project = layers.ConvBNReLU(
            in_channels,
            in_channels,
            1,
        )
        self.encoding = nn.Sequential(
            Encoding(channels=in_channels, num_codes=num_codes),
            nn.BatchNorm1D(num_codes),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.Sigmoid(),
        )
        self.in_channels = in_channels
        self.num_codes = num_codes

    def forward(self, x):
        encoding_projection = self.encoding_project(x)
        encoding_feat = self.encoding(encoding_projection) # N, num_codes, C
        
        encoding_feat = encoding_feat.mean(axis=1) # N, C
        batch_size, channels, _, _ = paddle.shape(x)
        
        gamma = self.fc(encoding_feat)
        y = gamma.reshape([batch_size, self.in_channels, 1, 1])
        output = F.relu(x + x * y)
        return encoding_feat, output
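A minimal usage sketch (hypothetical shapes, assuming the same PaddleSeg imports as above); num_codes=32 is the value used in ENCNet:

feat = paddle.randn([2, 512, 64, 64])
enc_module = EncModule(in_channels=512, num_codes=32)
encoding_feat, out = enc_module(feat)
# encoding_feat: [2, 512] global encoding (a further Linear layer, not shown here, maps it to [N, num_classes] for SELoss)
# out: [2, 512, 64, 64] feature map re-weighted by the channel attention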

6. Criss-Cross Attention Module

The Criss-Cross Attention Module, proposed in CCNet, builds global contextual dependencies. Its principle is illustrated below:

Q, K and V are computed from the input; matrix multiplications let every pixel gather semantic information from all pixels in its own row and column, so with only 2 stacked criss-cross attention modules each pixel obtains global semantic information.

For a detailed explanation of this module, see the post 【论文笔记】CCNet阅读笔记 (a CCNet reading note).

Module code:

import torch
import torch.nn as nn
from torch.nn import Softmax


def INF(B, H, W):
    # a (B*W, H, H) tensor with -inf on the diagonal, used to mask out each pixel's attention to itself on the vertical path
    return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention,self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))
 
 
    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
 
        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
        return self.gamma*(out_H + out_W) + x
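A minimal usage sketch (hypothetical shapes; the INF helper above calls .cuda(), so a GPU is assumed). In CCNet the module is applied recurrently, typically twice (R=2), so that every pixel gathers full-image context:

x = torch.randn(2, 64, 32, 48).cuda()
cca = CrissCrossAttention(in_dim=64).cuda()
out = cca(x)    # [2, 64, 32, 48]: each pixel has attended to its row and column
out = cca(out)  # second pass (R=2) propagates information across the full image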

Disclaimer

Reproduction without permission is prohibited.
