在这里插入图片描述

题目:ITFuse: An interactive transformer for infrared and visible image fusion

会议:Pattern Recognition(PR)

论文:https://www.sciencedirect.com/science/article/pii/S0031320324005739?via%3Dihub

代码:https://github.com/tthinking/ITFuse

年份:2024

1.摘要和引言

近年来,一些基于Transformer的融合模型被提出用于全局特征提取。然而,这些方法要么是特征级融合模型,要么是图像级融合模型。这些基于Transformer的融合模型均未考虑信息交互,导致对互补信息的挖掘不足。

在本文中,我们提出一种新颖的用于红外与可见光图像融合(IVIF)的端到端无监督交互式Transformer,称为ITFuse

它由**特征交互模块(FIMs)特征重建模块(FRM)**组成,用于交替捕获和融合重要信息。

具体而言,为了充分捕获不同模态的同质信息,我们设计了**残差注意力块(RAB)用于特征表示。此外,为了融合各自源图像中存在的独特属性,我们利用交互式注意力(ITA)**来聚合互补特征,实现充分的特征交互与整合。

另外,我们设计了跨模态注意力(CMA)Transformer块(TRB),用于合并和重建所提取的独特特征与共同特征,从而得到信息丰富的融合结果。

此外,我们还设计了像素损失结构损失,以无监督方式训练ITFuse,从而实现性能的进一步提升。

主要贡献:

  • 我们提出了一种用于 IVIF 的端到端交互式 Transformer(名为 ITFuse),其中设计了交互注意力(ITA)和残差注意力块(RAB),以选择性地提取个体特征并保留共同信息,从而实现充分的互补特征保留。
  • 我们提出了跨模态注意力(CMA)和 Transformer 块(TRB),以充分整合所提取的特征并构建多模态长距离关系,从而获得良好的性能。
  • 我们设计了像素损失结构损失,以无监督方式训练 ITFuse,从而进一步提升性能。

2.方法

2.1 问题表述

考虑到 IVIF(红外-可见光融合)的基本目标,即提取互补信息以生成信息丰富的合成融合图像,我们的模型基于以交互方式保留来自各个模态的同质和异质特征,以实现充分的重要信息挖掘和交互。具体而言:

  • 红外图像提供的信息表示为:

    F i r = F i r u + F c \mathbb{F}_{ir}=\mathbb{F}_{ir}^{u}+\mathbb{F}_{c} Fir=Firu+Fc
    其中 F i r u \mathbb{F}_{ir}^{u} Firu 表示红外图像的独特特征(反映物体的热辐射信息), F c \mathbb{F}_{c} Fc 表示红外与可见光图像共享的共同特征(如形状、边缘)。

  • 可见光图像包含的信息表示为:
    F v i = F v i u + F c \mathbb{F}_{vi}=\mathbb{F}_{vi}^{u}+\mathbb{F}_{c} Fvi=Fviu+Fc
    其中 F v i u \mathbb{F}_{vi}^{u} Fviu 表示可见光图像的独特特征(如场景的纹理细节)。

  • 红外与可见光图像融合得到的合成图像(同时具有清晰的物体和明确的场景)表示为:
    F f = F i r u + F v i u + F c \mathbb{F}_{f}=\mathbb{F}_{ir}^{u}+\mathbb{F}_{vi}^{u}+\mathbb{F}_{c} Ff=Firu+Fviu+Fc

因此,从不同模态中充分挖掘重要特征是关键。通过研究不同模态图像的内在特性,我们设计了交互式 Transformer 以解决这一挑战。

2.2 框架概述

所提出的 ITFuse 框架如图 3 所示。输入图像 I 1 I_{1} I1 I 2 I_{2} I2 被送入 N N N 个特征交互模块(FIMs)进行互补信息提取与整合,随后通过特征重建模块(FRM)生成融合结果 I f I_{f} If。具体流程如下:
在这里插入图片描述

  • 特征提取:为捕获不同模态的独特特征, I 1 I_{1} I1 I 2 I_{2} I2 分别输入 FIM。FIM 可准确保留各模态的独特特征。
  • 特征交互:对 I 1 I_{1} I1 I 2 I_{2} I2 进行通道拼接后输入 FIM,以挖掘输入图像的共同属性。共同特征与独特特征并非单独提取,而是通过交互注意力(ITA)动态交互,从而挖掘更多上下文信息(而非直接构建二维特征图的自注意力)。

2.3 特征交互模块

IVIF 的核心目标是聚合多模态重要特征,因此特征提取对融合结果至关重要。针对多模态特性,我们提出由 I 1 I_{1} I1 分支、 I c I_{c} Ic 分支(共同特征提取)和 I 2 I_{2} I2 分支组成的特征交互模块,以交互方式挖掘共同特征与独特特征(而非单独或统一提取)。

2.3.1 共同特征提取分支( I c I_{c} Ic 分支)

为充分挖掘输入图像的同质属性, I 1 I_{1} I1 I 2 I_{2} I2通道拼接后输入 I c I_{c} Ic 分支(由卷积块 CB 和残差注意力块 RAB 组成):

卷积块(CB)

用于浅层特征提取,包含两个连续卷积层,公式为:

F c b o = R ( B ( C 3 C → C ( R ( B ( C 3 2 → C ( F c b i ) ) ) ) ) ) \mathbb{F}_{cb}^{o}=\mathcal{R}\left(\mathcal{B}\left(C_{3}^{C \to C}\left(\mathcal{R}\left(\mathcal{B}\left(C_{3}^{2 \to C}\left(\mathbb{F}_{cb}^{i}\right)\right)\right)\right)\right)\right) Fcbo=R(B(C3CC(R(B(C32C(Fcbi))))))
其中 F c b i \mathbb{F}_{cb}^{i} Fcbi F c b o \mathbb{F}_{cb}^{o} Fcbo 分别为 CB 的输入和输出特征; C k c 1 → c 2 ( ⋅ ) C_{k}^{c_{1} \to c_{2}}(\cdot) Ckc1c2() 表示 k × k k \times k k×k 卷积(输入通道 c 1 c_{1} c1,输出通道 c 2 c_{2} c2); B ( ⋅ ) \mathcal{B}(\cdot) B() 为批归一化, R ( ⋅ ) \mathcal{R}(\cdot) R() 为整流线性单元(ReLU)。

cat=torch.cat((ir , vi ), 1)  # 将原始红外与可见光图在通道维度拼接

conv1_1=self.conv2_16( cat)  # 将双通道输入映射到16通道融合特征
conv1_2 = self.conv16_16(conv1_1)  # 对融合特征进行卷积增强
残差注意力块 (RAB)

在这里插入图片描述

用于强调红外与可见光共享的重要特征,包含两个协同注意力(CA)残差排列,以非平凡方式提取特征。RAB 从水平垂直方向构建注意力图:

  • 对 CA 输入特征 F c a i ∈ R H × W × C \mathbb{F}_{ca}^{i} \in \mathbb{R}^{H \times W \times C} FcaiRH×W×C 同时进行 X 方向(池化核 ( H , 1 ) (H, 1) (H,1))和 Y 方向(池化核 ( 1 , W ) (1, W) (1,W))平均池化;
  • 拼接双向池化特征后输入卷积层,输出为:
    F c l o = N ( B ( C 1 C → C / R ( F c l i ) ) ) \mathbb{F}_{cl}^{o}=\mathcal{N}\left(\mathcal{B}\left(C_{1}^{C \to C / R}\left(\mathbb{F}_{cl}^{i}\right)\right)\right) Fclo=N(B(C1CC/R(Fcli)))
    其中 F c l i ∈ R ( H + W ) × 1 × C \mathbb{F}_{cl}^{i} \in \mathbb{R}^{(H+W) \times 1 \times C} FcliR(H+W)×1×C(输入), F c l o ∈ R ( H + W ) × 1 × C / R \mathbb{F}_{cl}^{o} \in \mathbb{R}^{(H+W) \times 1 \times C/R} FcloR(H+W)×1×C/R(输出), R R R 为缩减比, N ( ⋅ ) \mathcal{N}(\cdot) N() 为非线性层(计算为 F n l o = F n l i × S ( F n l i ) \mathbb{F}_{nl}^{o}=\mathbb{F}_{nl}^{i} \times S\left(\mathbb{F}_{nl}^{i}\right) Fnlo=Fnli×S(Fnli) S ( ⋅ ) S(\cdot) S() 为 sigmoid 激活函数)。

随后, F n l o \mathbb{F}_{nl}^{o} Fnlo 分割为 F x ∈ R H × 1 × C / R \mathbb{F}_{x} \in \mathbb{R}^{H \times 1 \times C/R} FxRH×1×C/R F y ∈ R 1 × W × C / R \mathbb{F}_{y} \in \mathbb{R}^{1 \times W \times C/R} FyR1×W×C/R,分别输入卷积层得到 F x o = S ( C 1 C / R → C ( F x ) ) \mathbb{F}_{x}^{o}=S\left(C_{1}^{C/R \to C}\left(\mathbb{F}_{x}\right)\right) Fxo=S(C1C/RC(Fx)) F y o = S ( C 1 C / R → C ( F y ) ) \mathbb{F}_{y}^{o}=S\left(C_{1}^{C/R \to C}\left(\mathbb{F}_{y}\right)\right) Fyo=S(C1C/RC(Fy))(与 F c a i \mathbb{F}_{ca}^{i} Fcai 通道数一致)。最终, F x o \mathbb{F}_{x}^{o} Fxo F y o \mathbb{F}_{y}^{o} Fyo F c a i \mathbb{F}_{ca}^{i} Fcai 相乘得到 F c a o \mathbb{F}_{ca}^{o} Fcao

coorAtt1_1=self.CoordAtt( conv1_2)  # 第一次坐标注意力,捕获方向敏感信息,x
coorAtt1_2 = self.CoordAtt( coorAtt1_1)  # 第二次坐标注意力,进一步强化空间响应,y
class CoordAtt(nn.Module):  # 坐标注意力模块
    def __init__(self, inp, oup, reduction=4):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))   # 竖着看:只保留每一列的平均
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))   # 横着看:只保留每一行的平均

        mip =  inp // reduction  # 先把通道降一点,省算力

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0) # 卷积
        self.bn1 = nn.BatchNorm2d(mip) # 归一化
        self.act = h_swish()  # 一个激活函数

        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x          # 先拷一份原始输入,最后要乘回去
        b, c, h, w = x.size()

        # 竖直方向的全局平均:保留高度 h,宽度压成1
        x_h = self.pool_h(x)                  # [B, C, H, 1]

        # 水平方向的全局平均:保留宽度 w,高度压成1
        x_w = self.pool_w(x).permute(0,1,3,2) # [B, C, 1, W] → [B, C, W, 1]

        y = torch.cat([x_h, x_w], dim=2)      # 把“竖着看”和“横着看”的结果拼起来

        y = self.conv1(y)                     # 通道降维
        y = self.bn1(y)
        y = self.act(y)

        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()      # 得到“竖直方向”的注意力图
        a_w = self.conv_w(x_w).sigmoid()      # 得到“水平方向”的注意力图

        out = identity * a_w * a_h            # 原图 × 横向权重 × 纵向权重

        return out

2.3.2 独特特征提取分支( I 1 I_{1} I1 I 2 I_{2} I2 分支)

RB

为挖掘各模态独特特征, I 1 I_{1} I1 I 2 I_{2} I2 分别输入独立分支(含残差块 RB 和交互注意力 ITA)。RB 用于挖掘浅层内在属性,如图 5 所示,由三个带跳跃连接的卷积层组成(每个层含 3×3 卷积、BN 和 ReLU)。

在这里插入图片描述

把最开始那层 layerA1_1 和最后那层 layerA1_3 相加,形成一个残差结构: “输出 = 原始 + 深层卷积后的结果”,这样有利于保留原始信息,又加一点“加工”。

layerA1_1 = self.conv16(ir)  # 对红外输入提取初始16通道特征
layerA1_2 = self.conv16_16(layerA1_1)  # 红外分支第一次残差卷积
layerA1_3 = self.conv16_16(layerA1_2)  # 红外分支第二次残差卷积
resA1=layerA1_1+layerA1_3  # 形成红外分支的第一阶段残差输出
layerB1_1 = self.conv16(vi)  # 对可见光输入提取初始16通道特征
layerB1_2= self.conv16_16(layerB1_1)  # 可见分支第一次残差卷积
layerB1_3 = self.conv16_16(layerB1_2)  # 可见分支第二次残差卷积
resB1 = layerB1_1 + layerB1_3  # 形成可见分支的第一阶段残差输出
ITA

源图像的典型特征通过 ITA 动态交互(而非单独利用),设计为动态关注上下文信息(而非直接构建二维特征图的自注意力)。

在这里插入图片描述

F K \mathbb{F}_{K} FK(对应 I 1 I_{1} I1 分支的 F 1 i \mathbb{F}_{1}^{i} F1i I 2 I_{2} I2 分支的 F 2 i \mathbb{F}_{2}^{i} F2i)采用 3×3 卷积编码为单模态静态上下文信息;
F V \mathbb{F}_{V} FV(表示 I c I_{c} Ic 分支的 F c i \mathbb{F}_{c}^{i} Fci)通过 1×1 卷积编码为共同特征

thbb{F}_{K}$ 与 F C o \mathbb{F}_{Co} FCo I 1 I_{1} I1 分支的 F 2 i \mathbb{F}_{2}^{i} F2i I 2 I_{2} I2 分支的 F 1 i \mathbb{F}_{1}^{i} F1i元素级求和,经两个 1×1 卷积构建动态注意力矩阵 F D \mathbb{F}_{D} FD
F V \mathbb{F}_{V} FV F D \mathbb{F}_{D} FD 相乘实现多模态综合表示,最终 ITA 输出 F i t a \mathbb{F}_{ita} Fita 通过 F K \mathbb{F}_{K} FK 与相乘特征的加和整合静态与动态信息: F i t a o = ( F D × F V ) + F K \mathbb{F}_{ita}^{o}=\left(\mathbb{F}_{D} × \mathbb{F}_{V}\right)+\mathbb{F}_{K} Fitao=(FD×FV)+FK

    def forward(self,  v, k, q):
        k = self.key_embed(k)   #  先对键 k 做卷积编码(增强结构信息)
        qk = q + k              #  将查询 q 和键 k 相加(融合两路模态信息)

        w = self.embed(qk)      #  用 embed 生成单通道权重图 w(注意力权重)

        v = self.conv1x1(v)     #  对值 v 做1x1卷积,调整通道

        mul = w * v             #  用权重图 w 逐点调制值特征 v
        out = mul + k           #  与键特征 k 做残差相加,得到输出

        return out              #  输出交互后的特征

这里对应关系是:

  • v1:两张图一起算出来的共享特征(公共 Value)
  • 红外这条线:
    • Q_A1 = resA1:红外自己的观点
    • K_A1 = resB1:可见给红外看的“参考”
  • 可见这条线:
    • Q_B1 = resB1:可见自己的观点
    • K_B1 = resA1:红外给可见看的“参考”

2.4 特征重建模块

为将潜在空间的特征重建为融合结果,设计特征重建模块(FRM):包含跨模态注意力(CMA)、残差注意力块(RAB)、Transformer 块(TRB)和卷积块(CB)。

CMA

在这里插入图片描述

整合 I 1 I_{1} I1 I c I_{c} Ic I 2 I_{2} I2 分支的挖掘特征( F 1 n \mathbb{F}_{1}^{n} F1n F c n \mathbb{F}_{c}^{n} Fcn F 2 n \mathbb{F}_{2}^{n} F2n)。 F 1 n \mathbb{F}_{1}^{n} F1n F 2 n \mathbb{F}_{2}^{n} F2n 经 3×3 卷积编码为独特静态特征,相加后经两个卷积层生成独特动态属性; F c n \mathbb{F}_{c}^{n} Fcn 经 1×1 卷积保留共同特征。最终,独特动态特征与共同特征相乘得到整合特征 F c m a o \mathbb{F}_{cma}^{o} Fcmao(保留重要特征,过滤冗余)。

class EqualAtt(nn.Module):  # 平衡红外与可见注意力贡献的模块
    def __init__(self, dim):
        super(EqualAtt, self).__init__()
        self.key_embed = Convlutioanl(dim, dim)  # 对两个分支的特征做卷积编码

        factor = 8  # 通道压缩比
        self.embed = nn.Sequential(
            nn.Conv2d(dim, dim // factor, 1, bias=False),  # 压缩通道
            nn.BatchNorm2d(dim // factor),                # 归一化
            nn.ReLU(inplace=True),                        # 激活
            nn.Conv2d(dim // factor, 1, kernel_size=1),   # 生成单通道权重图 w
            nn.BatchNorm2d(1)
        )

        self.conv1x1 = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, bias=False),  # 调整值 v 的通道
            nn.BatchNorm2d(dim)
        )


    def forward(self, v, k, q):
        k = self.key_embed(k)           #  对“键”做卷积编码
        q = self.key_embed(q)           #  对“查询”做卷积编码
        qk = q + k                      #  把两个模态的特征加在一起

        w = self.embed(qk)              #  从 q+k 里算出单通道权重图 w
        v = self.conv1x1(v)             #  对公共特征 v 做 1×1 卷积调整通道

        mul = w * v                     #  用 w 对 v 做逐点加权

        return mul                      #  返回融合后的特征
  
RAB

对 CMA 整合的特征从水平和垂直方向重建表示。

coorAtt4_1 = self.CoordAtt(att)  # 对平衡后的特征施加坐标注意力
coorAtt4_2 = self.CoordAtt(coorAtt4_1)  # 再次加强坐标注意力
nlb4 =  att + coorAtt4_2  # 残差融合得到最终共享特征
TRB

在这里插入图片描述

把经过 CMA + RAB 的特征丢进 TRB 里,让 Transformer 再整体“咀嚼”一遍。让 整幅图里远处的区域也能互相影响,把细节、结构和整体关系再梳理一遍。

MLP = 多层感知机 = 两个全连接层 + 一个激活函数。

为什么要用“窗口”来算注意力?普通 Transformer 的自注意力,是 整张图里所有位置互相看,太费算力了。先把图按窗口切成一块一块,每个窗口内部自己玩自注意力,下一层换一种切法(移位窗口),再让“原来不同窗口的点”有机会见面。

L L L 个基本单元(BUs)组成,每个 BU 包含两个加法操作:
先做一次“窗口自注意力 + 残差”
再做一次“MLP 前馈网络 + 残差”

  • 第一个操作: F A = M ( Y ( F b u i ) ) + F b u i \mathbb{F}_{A}=\mathcal{M}\left(\mathcal{Y}\left(\mathbb{F}_{bu}^{i}\right)\right)+\mathbb{F}_{bu}^{i} FA=M(Y(Fbui))+Fbui M ( ⋅ ) \mathcal{M}(\cdot) M() 为多头自注意力, Y ( ⋅ ) \mathcal{Y}(\cdot) Y() 为层归一化);

  • 第二个操作: F b u o = P ( Y ( F A ) ) + F A \mathbb{F}_{bu}^{o}=\mathcal{P}\left(\mathcal{Y}\left(\mathbb{F}_{A}\right)\right)+\mathbb{F}_{A} Fbuo=P(Y(FA))+FA P ( ⋅ ) \mathcal{P}(\cdot) P() 为多层感知机)。

    TRB 通过上述操作构建长距离依赖关系。

class SwinTransformerBlock(nn.Module):  # Swin Transformer基本模块

    def __init__(self, dim, input_resolution, num_heads, window_size=1, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim  # 特征通道数
        self.input_resolution = input_resolution  # 输入特征的空间尺寸(H, W)
        self.num_heads = num_heads  # 多头数量
        self.window_size = window_size  # 注意力窗口尺寸
        self.shift_size = shift_size  # Shift大小,实现交错窗口
        self.mlp_ratio = mlp_ratio  # MLP隐藏层倍率

        if min(self.input_resolution) <= self.window_size:  # 若输入尺寸小于窗口则取消移位
            self.shift_size = 0
            self.window_size = min(self.input_resolution)

        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"  # 保证参数合法

        self.norm1 = norm_layer(dim)  # 第一层LayerNorm
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)  # 窗口注意力

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()  # 随机深度
        self.norm2 = norm_layer(dim)  # 第二层LayerNorm
        mlp_hidden_dim = int(dim * mlp_ratio)  # 计算MLP隐藏维度
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)  # 前馈网络
        
        
		#移位窗口:为了让不同块的信息能逐层互相传递;
		#遮罩(mask):防止那些原本“不相邻的远处点”,因为移位“假装在同一窗口里”,结果乱互相关注。
        
        
        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)  # 预计算移位窗口遮罩
        else:
            attn_mask = None  # 普通窗口无需遮罩

        self.register_buffer("attn_mask", attn_mask)  # 保存遮罩以便前向使用

    def calculate_mask(self, x_size):  # 构造移位窗口的注意力遮罩
        H, W = x_size  # 输入尺寸
        img_mask = torch.zeros((1, H, W, 1))  # 初始化遮罩图

        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))  # 垂直方向切片
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))  # 水平方向切片
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt  # 对每个区域赋予不同标识
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # 划分窗口
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # 展平窗口
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # 计算区域差异
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)) \
                               .masked_fill(attn_mask == 0, float(0.0))  # 不同块之间给极小值

        return attn_mask  # 返回注意力遮罩

    def forward(self, x, x_size):  # Swin模块前向
        B, C, H, W = x.shape  # 当前特征的批次与空间尺寸

        x = x.view(B, H, W, C)  # 调整为(B,H,W,C)便于窗口划分
        shortcut = x  # 保存残差分支
        shape = x.view(H * W * B, C)  # 展平成序列
        x = self.norm1(shape)  # 首先做LayerNorm
        x = x.view(B, H, W, C)  # 再恢复回空间布局

        # 1. 处理“移位窗口”
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))  # 平移窗口实现交错
        else:
            shifted_x = x  # 不移位直接使用

        # 2. 划分窗口,做窗口注意力
        x_windows = window_partition(shifted_x, self.window_size)  # 划分成窗口
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # 每窗展平为序列

        # 根据尺寸选择合适的mask
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # 使用预计算遮罩
        else:
            attn_windows = self.attn(
                x_windows, mask=self.calculate_mask(x_size).to(x.device))  # 尺寸变了就重算遮罩

        # 3. 再把每个窗口结果拼回整幅图
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # 拼接回特征图

        # 把前面的平移操作“挪回去”
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))  # 恢复位置
        else:
            x = shifted_x  # 无需恢复

        # 4. 第一次残差:注意力分支
        x = shortcut + self.drop_path(x)  # 与残差分支相加,并应用随机深度
        #随机深度:训练时让某些层偶尔“休息”,不工作。

        # 5. 第二次残差:MLP 分支
        x = self.norm2(x.view(B * H * W, C))
        x = x.view(B, H, W, C)
        x = x + self.drop_path(self.mlp(x))  # 通过 MLP 再做一次残差

        B, H, W, C = x.shape  # 更新尺寸
        x = x.view(B, C, H, W)  # 转回 (B,C,H,W) 形式

        return x  # 返回 Swin 块输出

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0  # 计算FLOPs
        H, W = self.input_resolution

        flops += self.dim * H * W  # LayerNorm等线性操作
        nW = H * W / self.window_size / self.window_size  # 窗口数量
        flops += nW * self.attn.flops(self.window_size * self.window_size)  # 注意力部分计算量
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio  # MLP两层计算量
        flops += self.dim * H * W  # 其他线性操作
        return flops

encode_size_DTRM1 = (nlb4.shape[2],nlb4.shape[3])  # 记录输入Swin层的空间尺寸
swinTransformer_DTRM1 = self.basicLayer(nlb4 , encode_size_DTRM1)  # 通过Swin Transformer进一步建模长程依赖

out1 = self.conv16_16(swinTransformer_DTRM1 )  # 对Transformer输出做卷积细化
out2 =self.convolutional_out (out1)  # 压缩为单通道融合结果并归一化
CB

对 TRB 输出特征经 3×3 卷积和 1×1 卷积生成最终融合图像 I f I_{f} If

I f = S ( C 1 C → 1 ( R ( B ( C 3 C → C ( F t r b ) ) ) ) ) I_{f}=S\left(C_{1}^{C \to 1}\left(\mathcal{R}\left(\mathcal{B}\left(C_{3}^{C \to C}\left(\mathbb{F}_{trb}\right)\right)\right)\right)\right) If=S(C1C1(R(B(C3CC(Ftrb)))))

2.5 损失函数

由于 IVIF 无 ground truth,设计**像素损失 L p L_{p} Lp 结构损失 L s L_{s} Ls **以无监督方式训练模型,总损失为:
L I T F u s e = L p + L s L_{ITFuse}=L_{p}+L_{s} LITFuse=Lp+Ls

  • 像素损失 L p L_{p} Lp:确保融合结果与源图像像素属性相似,计算为:
    L p = L c p 1 + L c p 2 L_{p}=L_{cp}^{1}+L_{cp}^{2} Lp=Lcp1+Lcp2
    其中 L c p z L_{cp}^{z} Lcpz 为 Charbonnier 惩罚函数:
    L c p z = ( I f − I z ) 2 + ε 2 L_{cp}^{z}=\sqrt{\left(I_{f}-I_{z}\right)^{2}+\varepsilon^{2}} Lcpz=(IfIz)2+ε2 ε \varepsilon ε 为惩罚系数)。

  • 结构损失 L s L_{s} Ls:引导融合图像与输入图像结构分布相似,计算为:
    L s = L s 1 + L s 2 L_{s}=L_{s}^{1}+L_{s}^{2} Ls=Ls1+Ls2
    其中 L s z L_{s}^{z} Lsz 为结构相似性度量:
    L s z = 1 − ( 2 μ z μ f + C 1 ) ( 2 σ z f + C 2 ) ( μ z 2 + μ f 2 + C 1 ) ( σ z 2 + σ f 2 + C 2 ) L_{s}^{z}=1-\frac{\left(2 \mu_{z} \mu_{f}+C_{1}\right)\left(2 \sigma_{z f}+C_{2}\right)}{\left(\mu_{z}^{2}+\mu_{f}^{2}+C_{1}\right)\left(\sigma_{z}^{2}+\sigma_{f}^{2}+C_{2}\right)} Lsz=1(μz2+μf2+C1)(σz2+σf2+C2)(2μzμf+C1)(2σzf+C2)
    μ z \mu_{z} μz μ f \mu_{f} μf 为平均强度; σ z f \sigma_{zf} σzf 为协方差; σ z 2 \sigma_{z}^{2} σz2 σ f 2 \sigma_{f}^{2} σf2 为方差; C 1 C_{1} C1 C 2 C_{2} C2 为常数)。

总损失可重表示为:
L I T F u s e = α ⋅ L c p 1 + β ⋅ L c p 2 + γ ⋅ L s 1 + λ ⋅ L s 2 L_{ITFuse}=\alpha \cdot L_{cp}^{1}+\beta \cdot L_{cp}^{2}+\gamma \cdot L_{s}^{1}+\lambda \cdot L_{s}^{2} LITFuse=αLcp1+βLcp2+γLs1+λLs2
α \alpha α β \beta β γ \gamma γ λ \lambda λ 为权重参数,控制各项权衡)。

3. 实验

3.1 数据集和实现细节

在本文中,训练、测试和验证数据集基于 TNO [1]和 RoadScene [11] 数据库构建。具体而言,下载了 348 对源图像,并随机分为包含 288 对图像的训练集、包含 40 对图像的测试集和包含 20 对图像的验证集。为获得足够的训练样本,将训练集中的图像裁剪为 120×120 像素(裁剪步长为 20),最终获得 58708 对图像块,并归一化到 [0,1]。

所有实验均在配备 NVIDIA GeForce RTX 3090 GPU 的设备上使用 PyTorch 框架实现,采用 Adam 优化器 [37] 更新参数(学习率 0.0003,批处理大小 64,总 epoch 数 20)。

超参数设置如下: N = 3 N=3 N=3(特征交互模块数量)、 C = 16 C=16 C=16(卷积块通道数)、 R = 4 R=4 R=4(残差注意力块缩减比)、 L = 3 L=3 L=3(Transformer 块数量), ε = 0.001 \varepsilon=0.001 ε=0.001(Charbonnier 惩罚系数)。对于 3×3 卷积,步长和填充均为 1,采用 “replicate” 边界模式。权重参数 α = 1 \alpha=1 α=1 β = 1 \beta=1 β=1 γ = 4 \gamma=4 γ=4 λ = 4 \lambda=4 λ=4(控制像素损失、结构损失的权衡)。


3.2 对比方法和评估指标

将提出的 ITFuse 与以下方法对比:

  • 传统方法:GTF [5](基于多尺度变换);
  • 深度学习方法:FusionGAN [6](生成对抗网络)、U2Fusion [11](多模态融合网络)、GANMcC [42](基于生成对抗网络的融合)、RFN-Nest [17](残差特征网络)、AUIF [43](自适应交互融合)、TarDAL [20](基于Transformer的跨模态对齐)、UNFusion [18](无监督融合);
  • Transformer方法:SwinFuse [7](基于Swin Transformer)、CMTFusion [30](基于交叉注意力机制)。

所有对比算法的代码均为公开可用或由作者提供,参数设置遵循原始文献以保公平性。

定量评估采用五种指标:

  • Tsallis 熵 Q T E Q_{TE} QTE [38]:衡量源图像与融合结果的依赖程度,聚焦边缘信息保留,值越高表示融合结果边缘保留越好;
  • Chen-Varshney 指标 Q C V Q_{CV} QCV [39]:符合人类视觉系统,通过纹理和细节评估视觉自然度,值越低表示融合结果越自然;
  • 峰值信噪比 PSNR [40]:衡量像素级失真,值越高表示图像质量越好;
  • 多尺度结构相似性指数 MS-SSIM [41]:从亮度、对比度、结构三方面评估相似性(扩展自 SSIM),值越高表示结构保留越好;
  • 均方误差 MSE [36]:反映融合结果与源图像的细节差异,值越低表示细节保留越完整。

指标特性总结: Q T E Q_{TE} QTE、PSNR、MS-SSIM 越高越好; Q C V Q_{CV} QCV、MSE 越低越好。

3.3 结果与讨论

3.3.1 在 TNO 数据集上的结果

在这里插入图片描述

表 1 显示了 ITFuse 与其他十种算法的定量对比。ITFuse 在 Q T E Q_{TE} QTE、PSNR、MS-SSIM 和 MSE 上均取得最佳性能(仅 Q C V Q_{CV} QCV 第二)。图 9 和图 10 展示了 TNO 数据库的融合结果(源图像对及 11 种方法的融合图像):

  • ITFuse 能有效保留源图像的互补信息(红外的热辐射与可见光的纹理细节),生成信息丰富的融合结果;
  • GTF 融合结果模糊,丢失部分结构细节(图 9 第二个放大框);
  • FusionGAN、GANMcC、RFN-Nest 场景细节提取较好但仍有模糊;
  • AUIF、SwinFuse 存在过度强调柱子、背景过暗等不自然问题(图 9 放大框);
  • U2Fusion 颜色失真(图 10 右侧特写);
  • TarDAL、UNFusion、CMTFusion 生成过亮或不真实图像(图 10 右侧特写)。

主观与客观结果均表明,ITFuse 在视觉保真度和定量评估上优于现有方法。

3.3.2 在 RoadScene 数据集上的结果

在这里插入图片描述

表 2 展示了 ITFuse 与其他方法的客观评估结果。ITFuse 在 Q T E Q_{TE} QTE、PSNR、MSE 上优势显著,表明其能有效保留输入图像的重要特征(高散度、低失真)。图 11 和图 12 展示了 RoadScene 数据库的融合结果:

  • 融合目标需同时揭示红外的热信息与可见光的场景细节,避免不真实特征;
  • GTF、TarDAL 未能充分提取红外特征,目标不清晰(图 11 第一个特写);
  • FusionGAN 图像过暗、对比度低(图 12 放大框);
  • GANMcC、RFN-Nest、UNFusion、CMTFusion 可见光特征保留较好但对比度不足(图 12 第一个特写);
  • U2Fusion、AUIF、SwinFuse 生成旗杆、道路标记过突出等不真实结果(图 11 放大框)。

ITFuse 在挖掘红外热辐射与可见光纹理细节、减少冗余属性上表现最佳。

3.4 消融研究

3.4.1 关于网络架构的消融研究

为验证交互式 Transformer 网络的有效性,在验证集上对 ITFuse 进行五种架构修改的消融实验:

  1. Uniformly Input:移除 I 1 I_{1} I1 I 2 I_{2} I2 分支,FIM 仅保留 I c I_{c} Ic 分支;
  2. Separately Input:删除 I c I_{c} Ic 分支, I 1 I_{1} I1 I 2 I_{2} I2 分别输入 FIM(不使用 ITA,因 ITA 需三输入);
  3. w/o RAB:用 CB 替换残差注意力块(RAB);
  4. w/o CMA:用 CB 替换跨模态注意力(CMA)( F 1 n \mathbb{F}_{1}^{n} F1n F c n \mathbb{F}_{c}^{n} Fcn F 2 n \mathbb{F}_{2}^{n} F2n 相加后输入 CB);
  5. w/o TRB:用 CB 替换 Transformer 块(TRB)。

表 3 列出不同架构的定量结果:完整模型在所有指标上最优。图 13 展示主观融合性能:

  • 退化模型(如 Uniformly Input、Separately Input、w/o TRB)丢失可见光纹理细节,融合结果模糊;
  • w/o RAB、w/o CMA 存在严重伪影,生成不理想图像。

3.4.2 关于损失函数的消融研究

由于 IVIF 无 ground truth,损失函数对融合结果的生成至关重要。为验证所设计的像素损失 L p L_p Lp 和结构损失 L s L_s Ls 的重要性,在验证集上开展了消融实验:

  • 移除像素损失:从总损失 L I T F u s e L_{ITFuse} LITFuse 中删除 L p L_p Lp,新损失函数定义为 L 1 = L s L_1 = L_s L1=Ls(式19);
  • 移除结构损失:从总损失中删除 L s L_s Ls,重新定义为 L 2 = L p L_2 = L_p L2=Lp(式20)。
    在这里插入图片描述

表 4 列出了 ITFuse 在不同损失函数下的定量比较结果。可见,完整损失函数 L I T F u s e L_{ITFuse} LITFuse Q T E Q_{TE} QTE、PSNR、MS-SSIM 和 MSE 所有指标上均取得最佳性能,验证了像素损失与结构损失协同优化对融合质量的关键作用。

3.5 泛化实验

为评估 ITFuse 的跨数据集泛化能力,将训练好的模型直接应用于未微调的其他红外-可见光融合数据集(KAIST [44] 和 M3FD [20]),验证其普适性。

3.5.1 在 KAIST 数据集上的泛化实验

KAIST 数据集包含红外与可见光图像对,用于验证模型对不同场景的适应性。表 5 展示了 11 种对比方法在 KAIST 数据集 200 个测试样本上的五个指标平均性能(最佳/次佳方法分别用粗体/下划线标记)。

在这里插入图片描述

  • ITFuse 能有效保留源图像的互补信息(红外的热辐射与可见光的纹理细节),融合结果信息丰富;
  • 其他方法(如 FusionGAN、U2Fusion 等)存在低对比度、整体变暗问题,导致场景细节丢失。

3.5.2 在 M3FD 数据集上的泛化实验

M3FD 数据集包含灰度红外图像与 RGB 可见光图像,需解决通道不匹配问题。实验中采用 RGB-YUV-RGB 颜色变换 [4] 处理输入,表 6 列出了 11 种方法的平均性能(最佳/次佳方法标记同上)。

在这里插入图片描述

图 17 和图 18 展示了两组源图像对的融合结果:

  • 所有方法均能生成合理融合结果,但 ITFuse 在细节保留与对比度上更优;
  • FusionGAN 融合模糊,GTF、U2Fusion 等存在低对比度问题,UNFusion 热信息保留不足,而 ITFuse 综合表现最佳。

3.6 效率比较

在这里插入图片描述

为全面对比方法性能,计算各方法在主流数据集上的平均运行时间(表7)。ITFuse 在 KAIST 数据集融合两幅图像仅需约 0.1 秒,M3FD 数据集约 0.2 秒,显著优于多数 SOTA 方法(如 GTF、U2Fusion 等),表现出实时融合能力。

4. 结论

本文提出了一种用于红外-可见光融合(IVIF)的交互式 Transformer 模型 ITFuse。与现有图像级或特征级融合模型不同,ITFuse 同时考虑多模态的内在与共同特征,通过以下创新实现高性能:

  1. 特征交互模块:设计特征交互模块(FIM),通过同质/异质属性统一提取与交互注意力(ITA)聚合,动态挖掘上下文信息;
  2. 特征重建模块:采用跨模态注意力(CMA)整合多分支特征,残差注意力块(RAB)保留重要特征,Transformer 块(TRB)构建长距离依赖;
  3. 无监督训练策略:结合像素损失 L p L_p Lp(保留像素属性)与结构损失 L s L_s Ls(约束结构分布),无需 ground truth 即可有效训练。

主流数据库(TNO、RoadScene)实验表明,ITFuse 在定量指标( Q T E Q_{TE} QTE、PSNR、MS-SSIM、MSE)与定性评价(视觉保真度、细节保留)上均优于 GTF、FusionGAN 等 10 余种 SOTA 方法。泛化实验(KAIST、M3FD 数据集)验证了其跨场景适应性,效率实验(运行时间)证明其实时性优势。

use 在 KAIST 数据集融合两幅图像仅需约 0.1 秒,M3FD 数据集约 0.2 秒,显著优于多数 SOTA 方法(如 GTF、U2Fusion 等),表现出实时融合能力。

4. 结论

本文提出了一种用于红外-可见光融合(IVIF)的交互式 Transformer 模型 ITFuse。与现有图像级或特征级融合模型不同,ITFuse 同时考虑多模态的内在与共同特征,通过以下创新实现高性能:

  1. 特征交互模块:设计特征交互模块(FIM),通过同质/异质属性统一提取与交互注意力(ITA)聚合,动态挖掘上下文信息;
  2. 特征重建模块:采用跨模态注意力(CMA)整合多分支特征,残差注意力块(RAB)保留重要特征,Transformer 块(TRB)构建长距离依赖;
  3. 无监督训练策略:结合像素损失 L p L_p Lp(保留像素属性)与结构损失 L s L_s Ls(约束结构分布),无需 ground truth 即可有效训练。

主流数据库(TNO、RoadScene)实验表明,ITFuse 在定量指标( Q T E Q_{TE} QTE、PSNR、MS-SSIM、MSE)与定性评价(视觉保真度、细节保留)上均优于 GTF、FusionGAN 等 10 余种 SOTA 方法。泛化实验(KAIST、M3FD 数据集)验证了其跨场景适应性,效率实验(运行时间)证明其实时性优势。

未来工作将聚焦配准-融合集成模型,解决现实场景中图像未完全配准的问题,进一步推动 IVIF 在 3D 工程等领域的应用。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐