【TMM 2026 】 FDFAM(频域特征聚合模块):频域注意力、多尺度频谱混合、跨模态融合增强
·
模块来源

- Paper:FreDFT: Frequency Domain Fusion Transformer for Visible-Infrared Object Detection
- Code:https://github.com/WenCongWu/FreDFT
模块整体简介
FDFAM(Frequency Domain Feature Aggregation Module)是 FreDFT 中负责跨模态深度融合的核心模块,用于将经过 LFEM 与 CGMM 处理后的 RGB/IR 特征进一步聚合。论文指出,现有方法大多在空间域用 Transformer 建模模态互补性,却忽略了频域在纹理细节与热结构信息解耦方面的优势。为此,FDFAM 由多模态频域注意力 MFDA 和频域前馈层 FDFFL 构成:前者借助 FFT/IFFT 与逐元素乘法建模跨模态相关性,后者通过多尺度频谱分块重组增强全局表征。其核心价值在于,以频域变换替代部分空间域高复杂度关系计算,提升 RGB-IR 互补信息挖掘质量与融合鲁棒性。
模块结构展示
- 论文位置:Section III-D “Frequency domain feature aggregation module”,对应论文第 4-6 页;结构图为
Fig. 5,见论文第 5 页。 - 模块组成:
- 输入:来自 CGMM 的两路特征
E^G_RGB与E^G_IR - 子模块 1:MFDA
- 对每个模态生成
Q/K/V - 对
Q/K做FFT - 在频域中逐元素乘法计算相关性
IFFT回到空间域并归一化- 与另一模态
V做逐元素交互 1×1 Conv + Residual得到两路增强特征
- 对每个模态生成
- 子模块 2:FDFFL
- 每路特征经
LayerNorm - 进入
3×3 / 5×5 / 7×7三个深度卷积分支 - 各分支结果经
FFT - 在通道维分块后做跨尺度重组拼接
- 经
IFFT恢复 - 各分支再做对应深度卷积
- 拼接并经
1×1 Conv降维,再残差加回输入
- 每路特征经
- 输出融合:
- 将 RGB 与 IR 经 FDFFL 后的结果拼接
1×1 Conv + ReLU- 输出最终融合特征
X_f
- 输入:来自 CGMM 的两路特征
设计出发点
- 现有方法缺陷:
- 论文指出多数 RGB-IR 检测方法采用空间域 Transformer 建模跨模态关系,忽略了频域信息的互补价值。
- 在复杂场景下,仅靠空间域交互容易受噪声、遮挡、模态异质性影响,限制特征判别性。
- 空间域 cross-attention 通常依赖矩阵乘法,计算复杂度较高。
- 模块解决的核心问题:
- 如何更有效地挖掘 RGB 纹理与 IR 热结构之间的互补关系。
- 如何在跨模态融合中引入多尺度频域表征,增强复杂环境下的稳健性。
- 设计初衷与创新思路:
- 用频域中的逐元素乘法近似替代空间域中的相关性计算,构造 MFDA。
- 用混合尺度频谱重组策略设计 FDFFL,将不同感受野的频域信息进行交叉混合。
- 将“跨模态相关性建模”和“多尺度频域增强”整合到一个统一融合模块 FDFAM 中。
适用场景与效果
- 适用场景:
- 可见光-红外目标检测
- 低照度、雾天、雨天、烟雾、遮挡等复杂环境
- 需要利用纹理与热目标互补性的跨模态视觉任务
- 性能表现:
- 在 FLIR 上,完整 FreDFT 达到
83.5 mAP50 / 42.6 mAP,表 V 显示仅基线为79.8 / 41.7,加入 FDFAM 后提升到81.7 / 42.3,说明 FDFAM 单独已有显著增益。 - 表 VI 比较了空间域注意力与频域注意力:
MSDA: 82.4 / 42.2MFDA: 83.5 / 42.6- 说明 FDFAM 内的 MFDA 优于空间域版本。
- 表 VII 比较前馈层设计:
standard MLP: 87.6 / 59.3FDFFL: 88.4 / 59.7- 说明频域前馈层优于标准 MLP。
- 在 FLIR 上,完整 FreDFT 达到
- 局限性:
- 论文表 IV 显示 FreDFT 的 FLOPs 为
464.5G,高于多种对比方法,说明频域模块虽有效,但计算开销不低。 - 论文 Discussion 指出,在小目标、重遮挡、低光模糊区域仍存在误检与漏检。
- FDFAM 依赖前置 LFEM/CGMM 提供较高质量特征,其效果并非独立于整体框架存在。
- 论文表 IV 显示 FreDFT 的 FLOPs 为
代码实现与模块对应解析
- 核心代码片段
class FDFFN(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FDFFN, self).__init__()
self.patch_size = 2
hidden_features = int(dim * ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features, kernel_size=1, bias=bias)
self.dwconv3x3 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1,
groups=hidden_features, bias=bias)
self.dwconv5x5 = nn.Conv2d(hidden_features, hidden_features, kernel_size=5, stride=1, padding=2,
groups=hidden_features, bias=bias)
self.dwconv7x7 = nn.Conv2d(hidden_features, hidden_features, kernel_size=7, stride=1, padding=3,
groups=hidden_features, bias=bias)
self.relu3 = nn.ReLU()
self.relu5 = nn.ReLU()
self.relu7 = nn.ReLU()
self.dwconv3x3_1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1,
groups=hidden_features, bias=bias)
self.dwconv5x5_1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=5, stride=1, padding=2,
groups=hidden_features, bias=bias)
self.dwconv7x7_1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=7, stride=1, padding=3,
groups=hidden_features, bias=bias)
self.relu3_1 = nn.ReLU()
self.relu5_1 = nn.ReLU()
self.relu7_1 = nn.ReLU()
self.project_out = nn.Conv2d(hidden_features * 3, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x3 = self.relu3(self.dwconv3x3(x))
x5 = self.relu5(self.dwconv5x5(x))
x7 = self.relu7(self.dwconv7x7(x))
x3_patch_fft = torch.fft.rfft2(x3.float())
x5_patch_fft = torch.fft.rfft2(x5.float())
x7_patch_fft = torch.fft.rfft2(x7.float())
x1_3, x2_3, x3_3 = x3_patch_fft.chunk(3, dim=1)
x1_5, x2_5, x3_5 = x5_patch_fft.chunk(3, dim=1)
x1_7, x2_7, x3_7 = x7_patch_fft.chunk(3, dim=1)
x3_patch_fft = torch.cat([x1_3, x1_5, x1_7], dim=1)
x5_patch_fft = torch.cat([x2_3, x2_5, x2_7], dim=1)
x7_patch_fft = torch.cat([x3_3, x3_5, x3_7], dim=1)
x3 = torch.fft.irfft2(x3_patch_fft, s=(x3.shape[2], x3.shape[3]))
x5 = torch.fft.irfft2(x5_patch_fft, s=(x5.shape[2], x5.shape[3]))
x7 = torch.fft.irfft2(x7_patch_fft, s=(x7.shape[2], x7.shape[3]))
x3 = self.relu3_1(self.dwconv3x3_1(x3))
x5 = self.relu5_1(self.dwconv5x5_1(x5))
x7 = self.relu7_1(self.dwconv7x7_1(x7))
x = torch.cat([x3, x5, x7], dim=1)
x = self.project_out(x)
return x
class FDCA(nn.Module):
def __init__(self, dim, bias):
super(FDCA, self).__init__()
self.norm1 = LayerNorm(dim, LayerNorm_type='WithBias')
self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)
self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, stride=1, padding=1, groups=dim * 6, bias=bias)
self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)
self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')
self.patch_size = 2
def forward(self, x):
rgb_fea = x[0]
ir_fea = x[1]
rgb_fea_norm = self.norm1(rgb_fea)
ir_fea_norm = self.norm1(ir_fea)
rgb_hidden = self.to_hidden(rgb_fea_norm)
ir_hidden = self.to_hidden(ir_fea_norm)
rgb_q, rgb_k, rgb_v = self.to_hidden_dw(rgb_hidden).chunk(3, dim=1)
ir_q, ir_k, ir_v = self.to_hidden_dw(ir_hidden).chunk(3, dim=1)
rgb_q_fft = torch.fft.rfft2(rgb_q.float())
rgb_k_fft = torch.fft.rfft2(rgb_k.float())
ir_q_fft = torch.fft.rfft2(ir_q.float())
ir_k_fft = torch.fft.rfft2(ir_k.float())
rgb_out = rgb_q_fft * rgb_k_fft
rgb_out = torch.fft.irfft2(rgb_out, s=(rgb_q.shape[2], rgb_q.shape[3]))
ir_out = ir_q_fft * ir_k_fft
ir_out = torch.fft.irfft2(ir_out, s=(ir_q.shape[2], ir_q.shape[3]))
rgb_out = self.norm(rgb_out)
ir_out = self.norm(ir_out)
rgb_output = ir_v * rgb_out
ir_output = rgb_v * ir_out
rgb_output = self.project_out(rgb_output)
ir_output = self.project_out(ir_output)
return rgb_output, ir_output
class FDFTM(nn.Module):
def __init__(self, dim):
super(FDFTM, self).__init__()
bias = False
LayerNorm_type = 'WithBias'
ffn_expansion_factor = 3
self.attn = FDCA(dim, bias)
self.norm2 = LayerNorm(dim, LayerNorm_type)
self.ffn = FDFFN(dim, ffn_expansion_factor, bias)
self.concat = Concat(dimension=1)
self.conv = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)
self.relu = nn.ReLU()
def forward(self, x):
rgb_fea = x[0]
ir_fea = x[1]
rgb_fea_out, ir_fea_out = self.attn([rgb_fea, ir_fea])
rgb_att_out = rgb_fea + rgb_fea_out
ir_att_out = ir_fea + ir_fea_out
rgb_fea = rgb_att_out + self.ffn(self.norm2(rgb_att_out))
ir_fea = ir_att_out + self.ffn(self.norm2(ir_att_out))
fea_cat = self.concat([rgb_fea, ir_fea])
out_fea = self.relu(self.conv(fea_cat))
return out_fea
- 参数说明以及使用实例
FDFTM(dim)dim:输入 RGB/IR 特征通道数,也是输出融合特征通道数。
- 内部子模块
FDCA(dim, bias):对应论文的 MFDAFDFFN(dim, ffn_expansion_factor=3, bias=False):对应论文的 FDFFL
- 输入
x[0]:RGB 特征,形状通常为B×C×H×Wx[1]:IR 特征,形状通常为B×C×H×W
- 输出
out_fea:融合后的单路特征,形状为B×C×H×W
- 调用方式
fdfam = FDFTM(dim=256)
rgb = torch.randn(2, 256, 80, 80)
ir = torch.randn(2, 256, 80, 80)
y = fdfam([rgb, ir]) # 输出形状: [2, 256, 80, 80]
- 整体运行流程
- 两路输入先进入
FDCA做频域跨模态注意力增强 - 各自残差相加
- 再分别进入
FDFFN做多尺度频域前馈增强 - 最后拼接两模态结果,经
1×1 Conv + ReLU得到融合输出
- 两路输入先进入
- 代码-论文原理对应说明
-
class FDFTM对应论文 FDFAM 总体- 论文定义:FDFAM 由
MFDA + 两个 FDFFL + concat + conv组成。 - 代码中:
self.attn = FDCA(dim, bias):对应 MFDAself.ffn = FDFFN(dim, ffn_expansion_factor, bias):对应 FDFFLself.concat、self.conv、self.relu:对应公式(4)的最终融合输出
- 注意:论文图 5 画的是每个模态各有一个 FDFFL 分支;代码中复用同一个
self.ffn分别处理 RGB 和 IR,两次调用共享参数。这一点是代码实现细节,论文原文未明确说明“是否共享参数”。
- 论文定义:FDFAM 由
-
class FDCA对应论文 MFDAself.norm1 = LayerNorm(dim, LayerNorm_type='WithBias')- 对应论文公式
(1)中LN(E^G_RGB)和LN(E^G_IR)。
- 对应论文公式
self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)- 先做
1×1 Conv投影,生成隐藏表示。
- 先做
self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, ..., groups=dim * 6, bias=bias)- 深度卷积后
chunk(3, dim=1)得到Q/K/V。 - 对应论文“标准
1×1卷积层和3×3 depth-wise convolution生成Q,K,V”。
- 深度卷积后
rgb_q_fft = torch.fft.rfft2(rgb_q.float())rgb_k_fft = torch.fft.rfft2(rgb_k.float())- 对应论文对
Q、K进行FFT。
- 对应论文对
rgb_out = rgb_q_fft * rgb_k_fft- 对应论文中“在频域使用逐元素乘法估计相关性”。
- 同理
ir_out = ir_q_fft * ir_k_fft。
rgb_out = torch.fft.irfft2(rgb_out, s=(rgb_q.shape[2], rgb_q.shape[3]))- 对应
IFFT恢复回空间域。
- 对应
rgb_out = self.norm(rgb_out)- 对应论文中的归一化步骤
LN(...)。
- 对应论文中的归一化步骤
rgb_output = ir_v * rgb_out- 对应论文公式
(1)中LN(...) ⊙ V_IR,即用另一模态的V做交互。
- 对应论文公式
ir_output = rgb_v * ir_out- 对应另一方向交互。
rgb_output = self.project_out(rgb_output)- 对应论文中的
f^{1×1}_{conv}(...)。
- 对应论文中的
- 与论文的一个实现差异:
- 论文公式写的是
F(Q_RGB) ⊙ F(K_RGB)和F(Q_IR) ⊙ F(K_IR),再分别与另一模态V交互。 - 代码严格遵循这一点。
- 但论文表述称 MFDA“捕获不同模态间相关性”,从代码看,这种相关性不是直接
Q_RGB与K_IR相乘,而是先各模态内部生成频域响应,再通过与对方V相乘实现跨模态注入。
- 论文公式写的是
-
class FDFFN对应论文 FDFFLhidden_features = int(dim * ffn_expansion_factor)- 对应论文先扩展通道,构建前馈层的隐藏维度。
self.project_in = nn.Conv2d(dim, hidden_features, kernel_size=1, bias=bias)- 对应公式
(2)中fconv^{1×1}(X_l)。
- 对应公式
self.dwconv3x3 / 5x5 / 7x7- 对应论文三条多尺度分支
3×3 / 5×5 / 7×7 depth-wise conv。
- 对应论文三条多尺度分支
x3 = self.relu3(self.dwconv3x3(x))等- 对应公式
(2)中卷积后跟ReLU。
- 对应公式
torch.fft.rfft2(...)- 对应各分支映射到频域。
x1_3, x2_3, x3_3 = x3_patch_fft.chunk(3, dim=1)等- 对应论文中“将每个频域特征沿通道维切成三块”。
x3_patch_fft = torch.cat([x1_3, x1_5, x1_7], dim=1)等- 对应论文“以 mutual mixing 方式跨尺度重组拼接”,这是 FDFFL 最关键的实现点。
torch.fft.irfft2(...)- 对应公式
(3)中 IFFT 回到空间域。
- 对应公式
self.dwconv3x3_1 / 5x5_1 / 7x7_1- 对应论文中 IFFT 后分别再接不同尺度的 depth-wise conv。
x = torch.cat([x3, x5, x7], dim=1)x = self.project_out(x)- 对应公式
(3)中[X3, X5, X7]拼接后再1×1 Conv降维。
- 对应公式
- 与论文的一个实现差异:
- 论文公式
(3)明确给出\bar{X} = X + fconv^{1×1}([X3, X5, X7]),即 FDFFL 内部含残差。 - 代码
FDFFN.forward()只返回变换结果x,并未在模块内部加回输入。 - 残差是在
FDFTM.forward()中外部完成:rgb_fea = rgb_att_out + self.ffn(self.norm2(rgb_att_out))ir_fea = ir_att_out + self.ffn(self.norm2(ir_att_out))
- 功能上等价于“Pre-Norm + 残差式 FFN”,只是残差位置放在外层封装里。
- 论文公式
-
FDFTM.forward()对应论文完整流程rgb_fea_out, ir_fea_out = self.attn([rgb_fea, ir_fea])- 先做 MFDA。
rgb_att_out = rgb_fea + rgb_fea_outir_att_out = ir_fea + ir_fea_out- 对应论文公式
(1)的残差输出Fout_RGB、Fout_IR。
- 对应论文公式
rgb_fea = rgb_att_out + self.ffn(self.norm2(rgb_att_out))ir_fea = ir_att_out + self.ffn(self.norm2(ir_att_out))- 对应论文图 5 中两路特征各自通过 FDFFL。
fea_cat = self.concat([rgb_fea, ir_fea])out_fea = self.relu(self.conv(fea_cat))- 对应论文公式
(4):X_f = σ(fconv^{1×1}([X_RGB, X_IR]))
- 对应论文公式
return out_fea- 输出最终融合特征。
-
辅助归一化代码对应论文 LayerNorm
LayerNorm、WithBias_LayerNorm、BiasFree_LayerNorm- 是 FDFAM 内 LN 的实现基础。
to_3d/to_4d- 将
B×C×H×W转成B×HW×C做 LayerNorm,再还原。 - 这是代码层面对论文 LN 的具体落地方式。
- 将
总结一下代码与论文的一致性:
- 高度一致部分:
- FDFAM 的总结构
- MFDA 的
Q/K/V -> FFT -> 逐元素乘 -> IFFT -> LN -> 跨模态V交互 - FDFFL 的多尺度频域混合策略
- 最终
concat + 1×1 conv + activation融合输出
- 需要注意的实现细节:
- 代码中模块名分别为
FDFTM / FDCA / FDFFN,而论文中名称是FDFAM / MFDA / FDFFL - 代码里的 FDFFN 残差在外层
FDFTM中实现,而不是在FDFFN内部 - RGB/IR 两路 FDFFL 在代码中共享同一个
self.ffn实例,这属于实现选择,论文未展开说明
- 代码中模块名分别为
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)