PR 2024 | ITFuse:一种用于红外与可见光图像融合的交互式_Transformer

题目:ITFuse: An interactive transformer for infrared and visible image fusion
会议:Pattern Recognition(PR)
论文:https://www.sciencedirect.com/science/article/pii/S0031320324005739?via%3Dihub
代码:https://github.com/tthinking/ITFuse
年份:2024
1.摘要和引言
近年来,一些基于Transformer的融合模型被提出用于全局特征提取。然而,这些方法要么是特征级融合模型,要么是图像级融合模型。这些基于Transformer的融合模型均未考虑信息交互,导致对互补信息的挖掘不足。
在本文中,我们提出一种新颖的用于红外与可见光图像融合(IVIF)的端到端无监督交互式Transformer,称为ITFuse。
它由**特征交互模块(FIMs)和特征重建模块(FRM)**组成,用于交替捕获和融合重要信息。
具体而言,为了充分捕获不同模态的同质信息,我们设计了**残差注意力块(RAB)用于特征表示。此外,为了融合各自源图像中存在的独特属性,我们利用交互式注意力(ITA)**来聚合互补特征,实现充分的特征交互与整合。
另外,我们设计了跨模态注意力(CMA)和Transformer块(TRB),用于合并和重建所提取的独特特征与共同特征,从而得到信息丰富的融合结果。
此外,我们还设计了像素损失和结构损失,以无监督方式训练ITFuse,从而实现性能的进一步提升。
主要贡献:
- 我们提出了一种用于 IVIF 的端到端交互式 Transformer(名为 ITFuse),其中设计了交互注意力(ITA)和残差注意力块(RAB),以选择性地提取个体特征并保留共同信息,从而实现充分的互补特征保留。
- 我们提出了跨模态注意力(CMA)和 Transformer 块(TRB),以充分整合所提取的特征并构建多模态长距离关系,从而获得良好的性能。
- 我们设计了像素损失和结构损失,以无监督方式训练 ITFuse,从而进一步提升性能。
2.方法
2.1 问题表述
考虑到 IVIF(红外-可见光融合)的基本目标,即提取互补信息以生成信息丰富的合成融合图像,我们的模型基于以交互方式保留来自各个模态的同质和异质特征,以实现充分的重要信息挖掘和交互。具体而言:
-
红外图像提供的信息表示为:
F i r = F i r u + F c \mathbb{F}_{ir}=\mathbb{F}_{ir}^{u}+\mathbb{F}_{c} Fir=Firu+Fc
其中 F i r u \mathbb{F}_{ir}^{u} Firu 表示红外图像的独特特征(反映物体的热辐射信息), F c \mathbb{F}_{c} Fc 表示红外与可见光图像共享的共同特征(如形状、边缘)。 -
可见光图像包含的信息表示为:
F v i = F v i u + F c \mathbb{F}_{vi}=\mathbb{F}_{vi}^{u}+\mathbb{F}_{c} Fvi=Fviu+Fc
其中 F v i u \mathbb{F}_{vi}^{u} Fviu 表示可见光图像的独特特征(如场景的纹理细节)。 -
红外与可见光图像融合得到的合成图像(同时具有清晰的物体和明确的场景)表示为:
F f = F i r u + F v i u + F c \mathbb{F}_{f}=\mathbb{F}_{ir}^{u}+\mathbb{F}_{vi}^{u}+\mathbb{F}_{c} Ff=Firu+Fviu+Fc
因此,从不同模态中充分挖掘重要特征是关键。通过研究不同模态图像的内在特性,我们设计了交互式 Transformer 以解决这一挑战。
2.2 框架概述
所提出的 ITFuse 框架如图 3 所示。输入图像 I 1 I_{1} I1 和 I 2 I_{2} I2 被送入 N N N 个特征交互模块(FIMs)进行互补信息提取与整合,随后通过特征重建模块(FRM)生成融合结果 I f I_{f} If。具体流程如下:
- 特征提取:为捕获不同模态的独特特征, I 1 I_{1} I1 和 I 2 I_{2} I2 分别输入 FIM。FIM 可准确保留各模态的独特特征。
- 特征交互:对 I 1 I_{1} I1 和 I 2 I_{2} I2 进行通道拼接后输入 FIM,以挖掘输入图像的共同属性。共同特征与独特特征并非单独提取,而是通过交互注意力(ITA)动态交互,从而挖掘更多上下文信息(而非直接构建二维特征图的自注意力)。
2.3 特征交互模块
IVIF 的核心目标是聚合多模态重要特征,因此特征提取对融合结果至关重要。针对多模态特性,我们提出由 I 1 I_{1} I1 分支、 I c I_{c} Ic 分支(共同特征提取)和 I 2 I_{2} I2 分支组成的特征交互模块,以交互方式挖掘共同特征与独特特征(而非单独或统一提取)。
2.3.1 共同特征提取分支( I c I_{c} Ic 分支)
为充分挖掘输入图像的同质属性, I 1 I_{1} I1 和 I 2 I_{2} I2 经通道拼接后输入 I c I_{c} Ic 分支(由卷积块 CB 和残差注意力块 RAB 组成):
卷积块(CB)
用于浅层特征提取,包含两个连续卷积层,公式为:
F c b o = R ( B ( C 3 C → C ( R ( B ( C 3 2 → C ( F c b i ) ) ) ) ) ) \mathbb{F}_{cb}^{o}=\mathcal{R}\left(\mathcal{B}\left(C_{3}^{C \to C}\left(\mathcal{R}\left(\mathcal{B}\left(C_{3}^{2 \to C}\left(\mathbb{F}_{cb}^{i}\right)\right)\right)\right)\right)\right) Fcbo=R(B(C3C→C(R(B(C32→C(Fcbi))))))
其中 F c b i \mathbb{F}_{cb}^{i} Fcbi 和 F c b o \mathbb{F}_{cb}^{o} Fcbo 分别为 CB 的输入和输出特征; C k c 1 → c 2 ( ⋅ ) C_{k}^{c_{1} \to c_{2}}(\cdot) Ckc1→c2(⋅) 表示 k × k k \times k k×k 卷积(输入通道 c 1 c_{1} c1,输出通道 c 2 c_{2} c2); B ( ⋅ ) \mathcal{B}(\cdot) B(⋅) 为批归一化, R ( ⋅ ) \mathcal{R}(\cdot) R(⋅) 为整流线性单元(ReLU)。
cat=torch.cat((ir , vi ), 1) # 将原始红外与可见光图在通道维度拼接
conv1_1=self.conv2_16( cat) # 将双通道输入映射到16通道融合特征
conv1_2 = self.conv16_16(conv1_1) # 对融合特征进行卷积增强
残差注意力块 (RAB)

用于强调红外与可见光共享的重要特征,包含两个协同注意力(CA)残差排列,以非平凡方式提取特征。RAB 从水平和垂直方向构建注意力图:
- 对 CA 输入特征 F c a i ∈ R H × W × C \mathbb{F}_{ca}^{i} \in \mathbb{R}^{H \times W \times C} Fcai∈RH×W×C 同时进行 X 方向(池化核 ( H , 1 ) (H, 1) (H,1))和 Y 方向(池化核 ( 1 , W ) (1, W) (1,W))平均池化;
- 拼接双向池化特征后输入卷积层,输出为:
F c l o = N ( B ( C 1 C → C / R ( F c l i ) ) ) \mathbb{F}_{cl}^{o}=\mathcal{N}\left(\mathcal{B}\left(C_{1}^{C \to C / R}\left(\mathbb{F}_{cl}^{i}\right)\right)\right) Fclo=N(B(C1C→C/R(Fcli)))
其中 F c l i ∈ R ( H + W ) × 1 × C \mathbb{F}_{cl}^{i} \in \mathbb{R}^{(H+W) \times 1 \times C} Fcli∈R(H+W)×1×C(输入), F c l o ∈ R ( H + W ) × 1 × C / R \mathbb{F}_{cl}^{o} \in \mathbb{R}^{(H+W) \times 1 \times C/R} Fclo∈R(H+W)×1×C/R(输出), R R R 为缩减比, N ( ⋅ ) \mathcal{N}(\cdot) N(⋅) 为非线性层(计算为 F n l o = F n l i × S ( F n l i ) \mathbb{F}_{nl}^{o}=\mathbb{F}_{nl}^{i} \times S\left(\mathbb{F}_{nl}^{i}\right) Fnlo=Fnli×S(Fnli), S ( ⋅ ) S(\cdot) S(⋅) 为 sigmoid 激活函数)。
随后, F n l o \mathbb{F}_{nl}^{o} Fnlo 分割为 F x ∈ R H × 1 × C / R \mathbb{F}_{x} \in \mathbb{R}^{H \times 1 \times C/R} Fx∈RH×1×C/R 和 F y ∈ R 1 × W × C / R \mathbb{F}_{y} \in \mathbb{R}^{1 \times W \times C/R} Fy∈R1×W×C/R,分别输入卷积层得到 F x o = S ( C 1 C / R → C ( F x ) ) \mathbb{F}_{x}^{o}=S\left(C_{1}^{C/R \to C}\left(\mathbb{F}_{x}\right)\right) Fxo=S(C1C/R→C(Fx)) 和 F y o = S ( C 1 C / R → C ( F y ) ) \mathbb{F}_{y}^{o}=S\left(C_{1}^{C/R \to C}\left(\mathbb{F}_{y}\right)\right) Fyo=S(C1C/R→C(Fy))(与 F c a i \mathbb{F}_{ca}^{i} Fcai 通道数一致)。最终, F x o \mathbb{F}_{x}^{o} Fxo、 F y o \mathbb{F}_{y}^{o} Fyo 与 F c a i \mathbb{F}_{ca}^{i} Fcai 相乘得到 F c a o \mathbb{F}_{ca}^{o} Fcao。
coorAtt1_1=self.CoordAtt( conv1_2) # 第一次坐标注意力,捕获方向敏感信息,x
coorAtt1_2 = self.CoordAtt( coorAtt1_1) # 第二次坐标注意力,进一步强化空间响应,y
class CoordAtt(nn.Module): # 坐标注意力模块
def __init__(self, inp, oup, reduction=4):
super(CoordAtt, self).__init__()
self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) # 竖着看:只保留每一列的平均
self.pool_w = nn.AdaptiveAvgPool2d((1, None)) # 横着看:只保留每一行的平均
mip = inp // reduction # 先把通道降一点,省算力
self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0) # 卷积
self.bn1 = nn.BatchNorm2d(mip) # 归一化
self.act = h_swish() # 一个激活函数
self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
def forward(self, x):
identity = x # 先拷一份原始输入,最后要乘回去
b, c, h, w = x.size()
# 竖直方向的全局平均:保留高度 h,宽度压成1
x_h = self.pool_h(x) # [B, C, H, 1]
# 水平方向的全局平均:保留宽度 w,高度压成1
x_w = self.pool_w(x).permute(0,1,3,2) # [B, C, 1, W] → [B, C, W, 1]
y = torch.cat([x_h, x_w], dim=2) # 把“竖着看”和“横着看”的结果拼起来
y = self.conv1(y) # 通道降维
y = self.bn1(y)
y = self.act(y)
x_h, x_w = torch.split(y, [h, w], dim=2)
x_w = x_w.permute(0, 1, 3, 2)
a_h = self.conv_h(x_h).sigmoid() # 得到“竖直方向”的注意力图
a_w = self.conv_w(x_w).sigmoid() # 得到“水平方向”的注意力图
out = identity * a_w * a_h # 原图 × 横向权重 × 纵向权重
return out
2.3.2 独特特征提取分支( I 1 I_{1} I1、 I 2 I_{2} I2 分支)
RB
为挖掘各模态独特特征, I 1 I_{1} I1 和 I 2 I_{2} I2 分别输入独立分支(含残差块 RB 和交互注意力 ITA)。RB 用于挖掘浅层内在属性,如图 5 所示,由三个带跳跃连接的卷积层组成(每个层含 3×3 卷积、BN 和 ReLU)。

把最开始那层 layerA1_1 和最后那层 layerA1_3 相加,形成一个残差结构: “输出 = 原始 + 深层卷积后的结果”,这样有利于保留原始信息,又加一点“加工”。
layerA1_1 = self.conv16(ir) # 对红外输入提取初始16通道特征
layerA1_2 = self.conv16_16(layerA1_1) # 红外分支第一次残差卷积
layerA1_3 = self.conv16_16(layerA1_2) # 红外分支第二次残差卷积
resA1=layerA1_1+layerA1_3 # 形成红外分支的第一阶段残差输出
layerB1_1 = self.conv16(vi) # 对可见光输入提取初始16通道特征
layerB1_2= self.conv16_16(layerB1_1) # 可见分支第一次残差卷积
layerB1_3 = self.conv16_16(layerB1_2) # 可见分支第二次残差卷积
resB1 = layerB1_1 + layerB1_3 # 形成可见分支的第一阶段残差输出
ITA
源图像的典型特征通过 ITA 动态交互(而非单独利用),设计为动态关注上下文信息(而非直接构建二维特征图的自注意力)。

对 F K \mathbb{F}_{K} FK(对应 I 1 I_{1} I1 分支的 F 1 i \mathbb{F}_{1}^{i} F1i 和 I 2 I_{2} I2 分支的 F 2 i \mathbb{F}_{2}^{i} F2i)采用 3×3 卷积编码为单模态静态上下文信息;
F V \mathbb{F}_{V} FV(表示 I c I_{c} Ic 分支的 F c i \mathbb{F}_{c}^{i} Fci)通过 1×1 卷积编码为共同特征。
thbb{F}_{K}$ 与 F C o \mathbb{F}_{Co} FCo( I 1 I_{1} I1 分支的 F 2 i \mathbb{F}_{2}^{i} F2i 和 I 2 I_{2} I2 分支的 F 1 i \mathbb{F}_{1}^{i} F1i)元素级求和,经两个 1×1 卷积构建动态注意力矩阵 F D \mathbb{F}_{D} FD。
F V \mathbb{F}_{V} FV 与 F D \mathbb{F}_{D} FD 相乘实现多模态综合表示,最终 ITA 输出 F i t a \mathbb{F}_{ita} Fita 通过 F K \mathbb{F}_{K} FK 与相乘特征的加和整合静态与动态信息: F i t a o = ( F D × F V ) + F K \mathbb{F}_{ita}^{o}=\left(\mathbb{F}_{D} × \mathbb{F}_{V}\right)+\mathbb{F}_{K} Fitao=(FD×FV)+FK
def forward(self, v, k, q):
k = self.key_embed(k) # 先对键 k 做卷积编码(增强结构信息)
qk = q + k # 将查询 q 和键 k 相加(融合两路模态信息)
w = self.embed(qk) # 用 embed 生成单通道权重图 w(注意力权重)
v = self.conv1x1(v) # 对值 v 做1x1卷积,调整通道
mul = w * v # 用权重图 w 逐点调制值特征 v
out = mul + k # 与键特征 k 做残差相加,得到输出
return out # 输出交互后的特征
这里对应关系是:
v1:两张图一起算出来的共享特征(公共 Value)- 红外这条线:
Q_A1 = resA1:红外自己的观点K_A1 = resB1:可见给红外看的“参考”
- 可见这条线:
Q_B1 = resB1:可见自己的观点K_B1 = resA1:红外给可见看的“参考”
2.4 特征重建模块
为将潜在空间的特征重建为融合结果,设计特征重建模块(FRM):包含跨模态注意力(CMA)、残差注意力块(RAB)、Transformer 块(TRB)和卷积块(CB)。
CMA

整合 I 1 I_{1} I1、 I c I_{c} Ic、 I 2 I_{2} I2 分支的挖掘特征( F 1 n \mathbb{F}_{1}^{n} F1n、 F c n \mathbb{F}_{c}^{n} Fcn、 F 2 n \mathbb{F}_{2}^{n} F2n)。 F 1 n \mathbb{F}_{1}^{n} F1n 和 F 2 n \mathbb{F}_{2}^{n} F2n 经 3×3 卷积编码为独特静态特征,相加后经两个卷积层生成独特动态属性; F c n \mathbb{F}_{c}^{n} Fcn 经 1×1 卷积保留共同特征。最终,独特动态特征与共同特征相乘得到整合特征 F c m a o \mathbb{F}_{cma}^{o} Fcmao(保留重要特征,过滤冗余)。
class EqualAtt(nn.Module): # 平衡红外与可见注意力贡献的模块
def __init__(self, dim):
super(EqualAtt, self).__init__()
self.key_embed = Convlutioanl(dim, dim) # 对两个分支的特征做卷积编码
factor = 8 # 通道压缩比
self.embed = nn.Sequential(
nn.Conv2d(dim, dim // factor, 1, bias=False), # 压缩通道
nn.BatchNorm2d(dim // factor), # 归一化
nn.ReLU(inplace=True), # 激活
nn.Conv2d(dim // factor, 1, kernel_size=1), # 生成单通道权重图 w
nn.BatchNorm2d(1)
)
self.conv1x1 = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, bias=False), # 调整值 v 的通道
nn.BatchNorm2d(dim)
)
def forward(self, v, k, q):
k = self.key_embed(k) # 对“键”做卷积编码
q = self.key_embed(q) # 对“查询”做卷积编码
qk = q + k # 把两个模态的特征加在一起
w = self.embed(qk) # 从 q+k 里算出单通道权重图 w
v = self.conv1x1(v) # 对公共特征 v 做 1×1 卷积调整通道
mul = w * v # 用 w 对 v 做逐点加权
return mul # 返回融合后的特征
RAB
对 CMA 整合的特征从水平和垂直方向重建表示。
coorAtt4_1 = self.CoordAtt(att) # 对平衡后的特征施加坐标注意力
coorAtt4_2 = self.CoordAtt(coorAtt4_1) # 再次加强坐标注意力
nlb4 = att + coorAtt4_2 # 残差融合得到最终共享特征
TRB

把经过 CMA + RAB 的特征丢进 TRB 里,让 Transformer 再整体“咀嚼”一遍。让 整幅图里远处的区域也能互相影响,把细节、结构和整体关系再梳理一遍。
MLP = 多层感知机 = 两个全连接层 + 一个激活函数。
为什么要用“窗口”来算注意力?普通 Transformer 的自注意力,是 整张图里所有位置互相看,太费算力了。先把图按窗口切成一块一块,每个窗口内部自己玩自注意力,下一层换一种切法(移位窗口),再让“原来不同窗口的点”有机会见面。
由 L L L 个基本单元(BUs)组成,每个 BU 包含两个加法操作:
先做一次“窗口自注意力 + 残差”
再做一次“MLP 前馈网络 + 残差”
-
第一个操作: F A = M ( Y ( F b u i ) ) + F b u i \mathbb{F}_{A}=\mathcal{M}\left(\mathcal{Y}\left(\mathbb{F}_{bu}^{i}\right)\right)+\mathbb{F}_{bu}^{i} FA=M(Y(Fbui))+Fbui ( M ( ⋅ ) \mathcal{M}(\cdot) M(⋅) 为多头自注意力, Y ( ⋅ ) \mathcal{Y}(\cdot) Y(⋅) 为层归一化);
-
第二个操作: F b u o = P ( Y ( F A ) ) + F A \mathbb{F}_{bu}^{o}=\mathcal{P}\left(\mathcal{Y}\left(\mathbb{F}_{A}\right)\right)+\mathbb{F}_{A} Fbuo=P(Y(FA))+FA( P ( ⋅ ) \mathcal{P}(\cdot) P(⋅) 为多层感知机)。
TRB 通过上述操作构建长距离依赖关系。
class SwinTransformerBlock(nn.Module): # Swin Transformer基本模块
def __init__(self, dim, input_resolution, num_heads, window_size=1, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim # 特征通道数
self.input_resolution = input_resolution # 输入特征的空间尺寸(H, W)
self.num_heads = num_heads # 多头数量
self.window_size = window_size # 注意力窗口尺寸
self.shift_size = shift_size # Shift大小,实现交错窗口
self.mlp_ratio = mlp_ratio # MLP隐藏层倍率
if min(self.input_resolution) <= self.window_size: # 若输入尺寸小于窗口则取消移位
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" # 保证参数合法
self.norm1 = norm_layer(dim) # 第一层LayerNorm
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) # 窗口注意力
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # 随机深度
self.norm2 = norm_layer(dim) # 第二层LayerNorm
mlp_hidden_dim = int(dim * mlp_ratio) # 计算MLP隐藏维度
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop) # 前馈网络
#移位窗口:为了让不同块的信息能逐层互相传递;
#遮罩(mask):防止那些原本“不相邻的远处点”,因为移位“假装在同一窗口里”,结果乱互相关注。
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution) # 预计算移位窗口遮罩
else:
attn_mask = None # 普通窗口无需遮罩
self.register_buffer("attn_mask", attn_mask) # 保存遮罩以便前向使用
def calculate_mask(self, x_size): # 构造移位窗口的注意力遮罩
H, W = x_size # 输入尺寸
img_mask = torch.zeros((1, H, W, 1)) # 初始化遮罩图
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)) # 垂直方向切片
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None)) # 水平方向切片
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt # 对每个区域赋予不同标识
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # 划分窗口
mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 展平窗口
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 计算区域差异
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)) \
.masked_fill(attn_mask == 0, float(0.0)) # 不同块之间给极小值
return attn_mask # 返回注意力遮罩
def forward(self, x, x_size): # Swin模块前向
B, C, H, W = x.shape # 当前特征的批次与空间尺寸
x = x.view(B, H, W, C) # 调整为(B,H,W,C)便于窗口划分
shortcut = x # 保存残差分支
shape = x.view(H * W * B, C) # 展平成序列
x = self.norm1(shape) # 首先做LayerNorm
x = x.view(B, H, W, C) # 再恢复回空间布局
# 1. 处理“移位窗口”
if self.shift_size > 0:
shifted_x = torch.roll(
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) # 平移窗口实现交错
else:
shifted_x = x # 不移位直接使用
# 2. 划分窗口,做窗口注意力
x_windows = window_partition(shifted_x, self.window_size) # 划分成窗口
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # 每窗展平为序列
# 根据尺寸选择合适的mask
if self.input_resolution == x_size:
attn_windows = self.attn(x_windows, mask=self.attn_mask) # 使用预计算遮罩
else:
attn_windows = self.attn(
x_windows, mask=self.calculate_mask(x_size).to(x.device)) # 尺寸变了就重算遮罩
# 3. 再把每个窗口结果拼回整幅图
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # 拼接回特征图
# 把前面的平移操作“挪回去”
if self.shift_size > 0:
x = torch.roll(
shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) # 恢复位置
else:
x = shifted_x # 无需恢复
# 4. 第一次残差:注意力分支
x = shortcut + self.drop_path(x) # 与残差分支相加,并应用随机深度
#随机深度:训练时让某些层偶尔“休息”,不工作。
# 5. 第二次残差:MLP 分支
x = self.norm2(x.view(B * H * W, C))
x = x.view(B, H, W, C)
x = x + self.drop_path(self.mlp(x)) # 通过 MLP 再做一次残差
B, H, W, C = x.shape # 更新尺寸
x = x.view(B, C, H, W) # 转回 (B,C,H,W) 形式
return x # 返回 Swin 块输出
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0 # 计算FLOPs
H, W = self.input_resolution
flops += self.dim * H * W # LayerNorm等线性操作
nW = H * W / self.window_size / self.window_size # 窗口数量
flops += nW * self.attn.flops(self.window_size * self.window_size) # 注意力部分计算量
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio # MLP两层计算量
flops += self.dim * H * W # 其他线性操作
return flops
encode_size_DTRM1 = (nlb4.shape[2],nlb4.shape[3]) # 记录输入Swin层的空间尺寸
swinTransformer_DTRM1 = self.basicLayer(nlb4 , encode_size_DTRM1) # 通过Swin Transformer进一步建模长程依赖
out1 = self.conv16_16(swinTransformer_DTRM1 ) # 对Transformer输出做卷积细化
out2 =self.convolutional_out (out1) # 压缩为单通道融合结果并归一化
CB
对 TRB 输出特征经 3×3 卷积和 1×1 卷积生成最终融合图像 I f I_{f} If:
I f = S ( C 1 C → 1 ( R ( B ( C 3 C → C ( F t r b ) ) ) ) ) I_{f}=S\left(C_{1}^{C \to 1}\left(\mathcal{R}\left(\mathcal{B}\left(C_{3}^{C \to C}\left(\mathbb{F}_{trb}\right)\right)\right)\right)\right) If=S(C1C→1(R(B(C3C→C(Ftrb)))))
2.5 损失函数
由于 IVIF 无 ground truth,设计**像素损失 L p L_{p} Lp 和结构损失 L s L_{s} Ls **以无监督方式训练模型,总损失为:
L I T F u s e = L p + L s L_{ITFuse}=L_{p}+L_{s} LITFuse=Lp+Ls
-
像素损失 L p L_{p} Lp:确保融合结果与源图像像素属性相似,计算为:
L p = L c p 1 + L c p 2 L_{p}=L_{cp}^{1}+L_{cp}^{2} Lp=Lcp1+Lcp2
其中 L c p z L_{cp}^{z} Lcpz 为 Charbonnier 惩罚函数:
L c p z = ( I f − I z ) 2 + ε 2 L_{cp}^{z}=\sqrt{\left(I_{f}-I_{z}\right)^{2}+\varepsilon^{2}} Lcpz=(If−Iz)2+ε2( ε \varepsilon ε 为惩罚系数)。 -
结构损失 L s L_{s} Ls:引导融合图像与输入图像结构分布相似,计算为:
L s = L s 1 + L s 2 L_{s}=L_{s}^{1}+L_{s}^{2} Ls=Ls1+Ls2
其中 L s z L_{s}^{z} Lsz 为结构相似性度量:
L s z = 1 − ( 2 μ z μ f + C 1 ) ( 2 σ z f + C 2 ) ( μ z 2 + μ f 2 + C 1 ) ( σ z 2 + σ f 2 + C 2 ) L_{s}^{z}=1-\frac{\left(2 \mu_{z} \mu_{f}+C_{1}\right)\left(2 \sigma_{z f}+C_{2}\right)}{\left(\mu_{z}^{2}+\mu_{f}^{2}+C_{1}\right)\left(\sigma_{z}^{2}+\sigma_{f}^{2}+C_{2}\right)} Lsz=1−(μz2+μf2+C1)(σz2+σf2+C2)(2μzμf+C1)(2σzf+C2)
( μ z \mu_{z} μz、 μ f \mu_{f} μf 为平均强度; σ z f \sigma_{zf} σzf 为协方差; σ z 2 \sigma_{z}^{2} σz2、 σ f 2 \sigma_{f}^{2} σf2 为方差; C 1 C_{1} C1、 C 2 C_{2} C2 为常数)。
总损失可重表示为:
L I T F u s e = α ⋅ L c p 1 + β ⋅ L c p 2 + γ ⋅ L s 1 + λ ⋅ L s 2 L_{ITFuse}=\alpha \cdot L_{cp}^{1}+\beta \cdot L_{cp}^{2}+\gamma \cdot L_{s}^{1}+\lambda \cdot L_{s}^{2} LITFuse=α⋅Lcp1+β⋅Lcp2+γ⋅Ls1+λ⋅Ls2
( α \alpha α、 β \beta β、 γ \gamma γ、 λ \lambda λ 为权重参数,控制各项权衡)。
3. 实验
3.1 数据集和实现细节
在本文中,训练、测试和验证数据集基于 TNO [1]和 RoadScene [11] 数据库构建。具体而言,下载了 348 对源图像,并随机分为包含 288 对图像的训练集、包含 40 对图像的测试集和包含 20 对图像的验证集。为获得足够的训练样本,将训练集中的图像裁剪为 120×120 像素(裁剪步长为 20),最终获得 58708 对图像块,并归一化到 [0,1]。
所有实验均在配备 NVIDIA GeForce RTX 3090 GPU 的设备上使用 PyTorch 框架实现,采用 Adam 优化器 [37] 更新参数(学习率 0.0003,批处理大小 64,总 epoch 数 20)。
超参数设置如下: N = 3 N=3 N=3(特征交互模块数量)、 C = 16 C=16 C=16(卷积块通道数)、 R = 4 R=4 R=4(残差注意力块缩减比)、 L = 3 L=3 L=3(Transformer 块数量), ε = 0.001 \varepsilon=0.001 ε=0.001(Charbonnier 惩罚系数)。对于 3×3 卷积,步长和填充均为 1,采用 “replicate” 边界模式。权重参数 α = 1 \alpha=1 α=1、 β = 1 \beta=1 β=1、 γ = 4 \gamma=4 γ=4、 λ = 4 \lambda=4 λ=4(控制像素损失、结构损失的权衡)。
3.2 对比方法和评估指标
将提出的 ITFuse 与以下方法对比:
- 传统方法:GTF [5](基于多尺度变换);
- 深度学习方法:FusionGAN [6](生成对抗网络)、U2Fusion [11](多模态融合网络)、GANMcC [42](基于生成对抗网络的融合)、RFN-Nest [17](残差特征网络)、AUIF [43](自适应交互融合)、TarDAL [20](基于Transformer的跨模态对齐)、UNFusion [18](无监督融合);
- Transformer方法:SwinFuse [7](基于Swin Transformer)、CMTFusion [30](基于交叉注意力机制)。
所有对比算法的代码均为公开可用或由作者提供,参数设置遵循原始文献以保公平性。
定量评估采用五种指标:
- Tsallis 熵 Q T E Q_{TE} QTE [38]:衡量源图像与融合结果的依赖程度,聚焦边缘信息保留,值越高表示融合结果边缘保留越好;
- Chen-Varshney 指标 Q C V Q_{CV} QCV [39]:符合人类视觉系统,通过纹理和细节评估视觉自然度,值越低表示融合结果越自然;
- 峰值信噪比 PSNR [40]:衡量像素级失真,值越高表示图像质量越好;
- 多尺度结构相似性指数 MS-SSIM [41]:从亮度、对比度、结构三方面评估相似性(扩展自 SSIM),值越高表示结构保留越好;
- 均方误差 MSE [36]:反映融合结果与源图像的细节差异,值越低表示细节保留越完整。
指标特性总结: Q T E Q_{TE} QTE、PSNR、MS-SSIM 越高越好; Q C V Q_{CV} QCV、MSE 越低越好。
3.3 结果与讨论
3.3.1 在 TNO 数据集上的结果

表 1 显示了 ITFuse 与其他十种算法的定量对比。ITFuse 在 Q T E Q_{TE} QTE、PSNR、MS-SSIM 和 MSE 上均取得最佳性能(仅 Q C V Q_{CV} QCV 第二)。图 9 和图 10 展示了 TNO 数据库的融合结果(源图像对及 11 种方法的融合图像):
- ITFuse 能有效保留源图像的互补信息(红外的热辐射与可见光的纹理细节),生成信息丰富的融合结果;
- GTF 融合结果模糊,丢失部分结构细节(图 9 第二个放大框);
- FusionGAN、GANMcC、RFN-Nest 场景细节提取较好但仍有模糊;
- AUIF、SwinFuse 存在过度强调柱子、背景过暗等不自然问题(图 9 放大框);
- U2Fusion 颜色失真(图 10 右侧特写);
- TarDAL、UNFusion、CMTFusion 生成过亮或不真实图像(图 10 右侧特写)。
主观与客观结果均表明,ITFuse 在视觉保真度和定量评估上优于现有方法。
3.3.2 在 RoadScene 数据集上的结果

表 2 展示了 ITFuse 与其他方法的客观评估结果。ITFuse 在 Q T E Q_{TE} QTE、PSNR、MSE 上优势显著,表明其能有效保留输入图像的重要特征(高散度、低失真)。图 11 和图 12 展示了 RoadScene 数据库的融合结果:
- 融合目标需同时揭示红外的热信息与可见光的场景细节,避免不真实特征;
- GTF、TarDAL 未能充分提取红外特征,目标不清晰(图 11 第一个特写);
- FusionGAN 图像过暗、对比度低(图 12 放大框);
- GANMcC、RFN-Nest、UNFusion、CMTFusion 可见光特征保留较好但对比度不足(图 12 第一个特写);
- U2Fusion、AUIF、SwinFuse 生成旗杆、道路标记过突出等不真实结果(图 11 放大框)。
ITFuse 在挖掘红外热辐射与可见光纹理细节、减少冗余属性上表现最佳。
3.4 消融研究
3.4.1 关于网络架构的消融研究
为验证交互式 Transformer 网络的有效性,在验证集上对 ITFuse 进行五种架构修改的消融实验:
- Uniformly Input:移除 I 1 I_{1} I1、 I 2 I_{2} I2 分支,FIM 仅保留 I c I_{c} Ic 分支;
- Separately Input:删除 I c I_{c} Ic 分支, I 1 I_{1} I1、 I 2 I_{2} I2 分别输入 FIM(不使用 ITA,因 ITA 需三输入);
- w/o RAB:用 CB 替换残差注意力块(RAB);
- w/o CMA:用 CB 替换跨模态注意力(CMA)( F 1 n \mathbb{F}_{1}^{n} F1n、 F c n \mathbb{F}_{c}^{n} Fcn、 F 2 n \mathbb{F}_{2}^{n} F2n 相加后输入 CB);
- w/o TRB:用 CB 替换 Transformer 块(TRB)。
表 3 列出不同架构的定量结果:完整模型在所有指标上最优。图 13 展示主观融合性能:
- 退化模型(如 Uniformly Input、Separately Input、w/o TRB)丢失可见光纹理细节,融合结果模糊;
- w/o RAB、w/o CMA 存在严重伪影,生成不理想图像。
3.4.2 关于损失函数的消融研究
由于 IVIF 无 ground truth,损失函数对融合结果的生成至关重要。为验证所设计的像素损失 L p L_p Lp 和结构损失 L s L_s Ls 的重要性,在验证集上开展了消融实验:
- 移除像素损失:从总损失 L I T F u s e L_{ITFuse} LITFuse 中删除 L p L_p Lp,新损失函数定义为 L 1 = L s L_1 = L_s L1=Ls(式19);
- 移除结构损失:从总损失中删除 L s L_s Ls,重新定义为 L 2 = L p L_2 = L_p L2=Lp(式20)。

表 4 列出了 ITFuse 在不同损失函数下的定量比较结果。可见,完整损失函数 L I T F u s e L_{ITFuse} LITFuse 在 Q T E Q_{TE} QTE、PSNR、MS-SSIM 和 MSE 所有指标上均取得最佳性能,验证了像素损失与结构损失协同优化对融合质量的关键作用。
3.5 泛化实验
为评估 ITFuse 的跨数据集泛化能力,将训练好的模型直接应用于未微调的其他红外-可见光融合数据集(KAIST [44] 和 M3FD [20]),验证其普适性。
3.5.1 在 KAIST 数据集上的泛化实验
KAIST 数据集包含红外与可见光图像对,用于验证模型对不同场景的适应性。表 5 展示了 11 种对比方法在 KAIST 数据集 200 个测试样本上的五个指标平均性能(最佳/次佳方法分别用粗体/下划线标记)。

- ITFuse 能有效保留源图像的互补信息(红外的热辐射与可见光的纹理细节),融合结果信息丰富;
- 其他方法(如 FusionGAN、U2Fusion 等)存在低对比度、整体变暗问题,导致场景细节丢失。
3.5.2 在 M3FD 数据集上的泛化实验
M3FD 数据集包含灰度红外图像与 RGB 可见光图像,需解决通道不匹配问题。实验中采用 RGB-YUV-RGB 颜色变换 [4] 处理输入,表 6 列出了 11 种方法的平均性能(最佳/次佳方法标记同上)。

图 17 和图 18 展示了两组源图像对的融合结果:
- 所有方法均能生成合理融合结果,但 ITFuse 在细节保留与对比度上更优;
- FusionGAN 融合模糊,GTF、U2Fusion 等存在低对比度问题,UNFusion 热信息保留不足,而 ITFuse 综合表现最佳。
3.6 效率比较

为全面对比方法性能,计算各方法在主流数据集上的平均运行时间(表7)。ITFuse 在 KAIST 数据集融合两幅图像仅需约 0.1 秒,M3FD 数据集约 0.2 秒,显著优于多数 SOTA 方法(如 GTF、U2Fusion 等),表现出实时融合能力。
4. 结论
本文提出了一种用于红外-可见光融合(IVIF)的交互式 Transformer 模型 ITFuse。与现有图像级或特征级融合模型不同,ITFuse 同时考虑多模态的内在与共同特征,通过以下创新实现高性能:
- 特征交互模块:设计特征交互模块(FIM),通过同质/异质属性统一提取与交互注意力(ITA)聚合,动态挖掘上下文信息;
- 特征重建模块:采用跨模态注意力(CMA)整合多分支特征,残差注意力块(RAB)保留重要特征,Transformer 块(TRB)构建长距离依赖;
- 无监督训练策略:结合像素损失 L p L_p Lp(保留像素属性)与结构损失 L s L_s Ls(约束结构分布),无需 ground truth 即可有效训练。
主流数据库(TNO、RoadScene)实验表明,ITFuse 在定量指标( Q T E Q_{TE} QTE、PSNR、MS-SSIM、MSE)与定性评价(视觉保真度、细节保留)上均优于 GTF、FusionGAN 等 10 余种 SOTA 方法。泛化实验(KAIST、M3FD 数据集)验证了其跨场景适应性,效率实验(运行时间)证明其实时性优势。
use 在 KAIST 数据集融合两幅图像仅需约 0.1 秒,M3FD 数据集约 0.2 秒,显著优于多数 SOTA 方法(如 GTF、U2Fusion 等),表现出实时融合能力。
4. 结论
本文提出了一种用于红外-可见光融合(IVIF)的交互式 Transformer 模型 ITFuse。与现有图像级或特征级融合模型不同,ITFuse 同时考虑多模态的内在与共同特征,通过以下创新实现高性能:
- 特征交互模块:设计特征交互模块(FIM),通过同质/异质属性统一提取与交互注意力(ITA)聚合,动态挖掘上下文信息;
- 特征重建模块:采用跨模态注意力(CMA)整合多分支特征,残差注意力块(RAB)保留重要特征,Transformer 块(TRB)构建长距离依赖;
- 无监督训练策略:结合像素损失 L p L_p Lp(保留像素属性)与结构损失 L s L_s Ls(约束结构分布),无需 ground truth 即可有效训练。
主流数据库(TNO、RoadScene)实验表明,ITFuse 在定量指标( Q T E Q_{TE} QTE、PSNR、MS-SSIM、MSE)与定性评价(视觉保真度、细节保留)上均优于 GTF、FusionGAN 等 10 余种 SOTA 方法。泛化实验(KAIST、M3FD 数据集)验证了其跨场景适应性,效率实验(运行时间)证明其实时性优势。
未来工作将聚焦配准-融合集成模型,解决现实场景中图像未完全配准的问题,进一步推动 IVIF 在 3D 工程等领域的应用。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)