模块来源

在这里插入图片描述

模块整体简介

FDFAM(Frequency Domain Feature Aggregation Module)是 FreDFT 中负责跨模态深度融合的核心模块,用于将经过 LFEM 与 CGMM 处理后的 RGB/IR 特征进一步聚合。论文指出,现有方法大多在空间域用 Transformer 建模模态互补性,却忽略了频域在纹理细节与热结构信息解耦方面的优势。为此,FDFAM 由多模态频域注意力 MFDA 和频域前馈层 FDFFL 构成:前者借助 FFT/IFFT 与逐元素乘法建模跨模态相关性,后者通过多尺度频谱分块重组增强全局表征。其核心价值在于,以频域变换替代部分空间域高复杂度关系计算,提升 RGB-IR 互补信息挖掘质量与融合鲁棒性。

模块结构展示

  • 论文位置:Section III-D “Frequency domain feature aggregation module”,对应论文第 4-6 页;结构图为 Fig. 5,见论文第 5 页。
  • 模块组成
    • 输入:来自 CGMM 的两路特征 E^G_RGBE^G_IR
    • 子模块 1:MFDA
      • 对每个模态生成 Q/K/V
      • Q/KFFT
      • 在频域中逐元素乘法计算相关性
      • IFFT 回到空间域并归一化
      • 与另一模态 V 做逐元素交互
      • 1×1 Conv + Residual 得到两路增强特征
    • 子模块 2:FDFFL
      • 每路特征经 LayerNorm
      • 进入 3×3 / 5×5 / 7×7 三个深度卷积分支
      • 各分支结果经 FFT
      • 在通道维分块后做跨尺度重组拼接
      • IFFT 恢复
      • 各分支再做对应深度卷积
      • 拼接并经 1×1 Conv 降维,再残差加回输入
    • 输出融合:
      • 将 RGB 与 IR 经 FDFFL 后的结果拼接
      • 1×1 Conv + ReLU
      • 输出最终融合特征 X_f

设计出发点

  • 现有方法缺陷
    • 论文指出多数 RGB-IR 检测方法采用空间域 Transformer 建模跨模态关系,忽略了频域信息的互补价值。
    • 在复杂场景下,仅靠空间域交互容易受噪声、遮挡、模态异质性影响,限制特征判别性。
    • 空间域 cross-attention 通常依赖矩阵乘法,计算复杂度较高。
  • 模块解决的核心问题
    • 如何更有效地挖掘 RGB 纹理与 IR 热结构之间的互补关系。
    • 如何在跨模态融合中引入多尺度频域表征,增强复杂环境下的稳健性。
  • 设计初衷与创新思路
    • 用频域中的逐元素乘法近似替代空间域中的相关性计算,构造 MFDA
    • 用混合尺度频谱重组策略设计 FDFFL,将不同感受野的频域信息进行交叉混合。
    • 将“跨模态相关性建模”和“多尺度频域增强”整合到一个统一融合模块 FDFAM 中。

适用场景与效果

  • 适用场景
    • 可见光-红外目标检测
    • 低照度、雾天、雨天、烟雾、遮挡等复杂环境
    • 需要利用纹理与热目标互补性的跨模态视觉任务
  • 性能表现
    • 在 FLIR 上,完整 FreDFT 达到 83.5 mAP50 / 42.6 mAP,表 V 显示仅基线为 79.8 / 41.7,加入 FDFAM 后提升到 81.7 / 42.3,说明 FDFAM 单独已有显著增益。
    • 表 VI 比较了空间域注意力与频域注意力:
      • MSDA: 82.4 / 42.2
      • MFDA: 83.5 / 42.6
      • 说明 FDFAM 内的 MFDA 优于空间域版本。
    • 表 VII 比较前馈层设计:
      • standard MLP: 87.6 / 59.3
      • FDFFL: 88.4 / 59.7
      • 说明频域前馈层优于标准 MLP。
  • 局限性
    • 论文表 IV 显示 FreDFT 的 FLOPs 为 464.5G,高于多种对比方法,说明频域模块虽有效,但计算开销不低。
    • 论文 Discussion 指出,在小目标、重遮挡、低光模糊区域仍存在误检与漏检。
    • FDFAM 依赖前置 LFEM/CGMM 提供较高质量特征,其效果并非独立于整体框架存在。

代码实现与模块对应解析

  1. 核心代码片段
class FDFFN(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FDFFN, self).__init__()

        self.patch_size = 2
        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features, kernel_size=1, bias=bias)

        self.dwconv3x3 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1,
                                   groups=hidden_features, bias=bias)
        self.dwconv5x5 = nn.Conv2d(hidden_features, hidden_features, kernel_size=5, stride=1, padding=2,
                                   groups=hidden_features, bias=bias)
        self.dwconv7x7 = nn.Conv2d(hidden_features, hidden_features, kernel_size=7, stride=1, padding=3,
                                   groups=hidden_features, bias=bias)
        self.relu3 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.relu7 = nn.ReLU()

        self.dwconv3x3_1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=3, stride=1, padding=1,
                                     groups=hidden_features, bias=bias)
        self.dwconv5x5_1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=5, stride=1, padding=2,
                                     groups=hidden_features, bias=bias)
        self.dwconv7x7_1 = nn.Conv2d(hidden_features, hidden_features, kernel_size=7, stride=1, padding=3,
                                     groups=hidden_features, bias=bias)

        self.relu3_1 = nn.ReLU()
        self.relu5_1 = nn.ReLU()
        self.relu7_1 = nn.ReLU()

        self.project_out = nn.Conv2d(hidden_features * 3, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x3 = self.relu3(self.dwconv3x3(x))
        x5 = self.relu5(self.dwconv5x5(x))
        x7 = self.relu7(self.dwconv7x7(x))

        x3_patch_fft = torch.fft.rfft2(x3.float())
        x5_patch_fft = torch.fft.rfft2(x5.float())
        x7_patch_fft = torch.fft.rfft2(x7.float())

        x1_3, x2_3, x3_3 = x3_patch_fft.chunk(3, dim=1)
        x1_5, x2_5, x3_5 = x5_patch_fft.chunk(3, dim=1)
        x1_7, x2_7, x3_7 = x7_patch_fft.chunk(3, dim=1)

        x3_patch_fft = torch.cat([x1_3, x1_5, x1_7], dim=1)
        x5_patch_fft = torch.cat([x2_3, x2_5, x2_7], dim=1)
        x7_patch_fft = torch.cat([x3_3, x3_5, x3_7], dim=1)

        x3 = torch.fft.irfft2(x3_patch_fft, s=(x3.shape[2], x3.shape[3]))
        x5 = torch.fft.irfft2(x5_patch_fft, s=(x5.shape[2], x5.shape[3]))
        x7 = torch.fft.irfft2(x7_patch_fft, s=(x7.shape[2], x7.shape[3]))

        x3 = self.relu3_1(self.dwconv3x3_1(x3))
        x5 = self.relu5_1(self.dwconv5x5_1(x5))
        x7 = self.relu7_1(self.dwconv7x7_1(x7))

        x = torch.cat([x3, x5, x7], dim=1)
        x = self.project_out(x)
        return x


class FDCA(nn.Module):
    def __init__(self, dim, bias):
        super(FDCA, self).__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type='WithBias')

        self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)
        self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, stride=1, padding=1, groups=dim * 6, bias=bias)

        self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)

        self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')

        self.patch_size = 2

    def forward(self, x):
        rgb_fea = x[0]
        ir_fea = x[1]

        rgb_fea_norm = self.norm1(rgb_fea)
        ir_fea_norm = self.norm1(ir_fea)

        rgb_hidden = self.to_hidden(rgb_fea_norm)
        ir_hidden = self.to_hidden(ir_fea_norm)

        rgb_q, rgb_k, rgb_v = self.to_hidden_dw(rgb_hidden).chunk(3, dim=1)
        ir_q, ir_k, ir_v = self.to_hidden_dw(ir_hidden).chunk(3, dim=1)

        rgb_q_fft = torch.fft.rfft2(rgb_q.float())
        rgb_k_fft = torch.fft.rfft2(rgb_k.float())

        ir_q_fft = torch.fft.rfft2(ir_q.float())
        ir_k_fft = torch.fft.rfft2(ir_k.float())

        rgb_out = rgb_q_fft * rgb_k_fft
        rgb_out = torch.fft.irfft2(rgb_out, s=(rgb_q.shape[2], rgb_q.shape[3]))

        ir_out = ir_q_fft * ir_k_fft
        ir_out = torch.fft.irfft2(ir_out, s=(ir_q.shape[2], ir_q.shape[3]))

        rgb_out = self.norm(rgb_out)
        ir_out = self.norm(ir_out)

        rgb_output = ir_v * rgb_out
        ir_output = rgb_v * ir_out

        rgb_output = self.project_out(rgb_output)
        ir_output = self.project_out(ir_output)

        return rgb_output, ir_output


class FDFTM(nn.Module):
    def __init__(self, dim):
        super(FDFTM, self).__init__()
        bias = False
        LayerNorm_type = 'WithBias'
        ffn_expansion_factor = 3

        self.attn = FDCA(dim, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FDFFN(dim, ffn_expansion_factor, bias)

        self.concat = Concat(dimension=1)
        self.conv = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)
        self.relu = nn.ReLU()

    def forward(self, x):
        rgb_fea = x[0]
        ir_fea = x[1]

        rgb_fea_out, ir_fea_out = self.attn([rgb_fea, ir_fea])
        rgb_att_out = rgb_fea + rgb_fea_out
        ir_att_out = ir_fea + ir_fea_out
        rgb_fea = rgb_att_out + self.ffn(self.norm2(rgb_att_out))
        ir_fea = ir_att_out + self.ffn(self.norm2(ir_att_out))

        fea_cat = self.concat([rgb_fea, ir_fea])
        out_fea = self.relu(self.conv(fea_cat))

        return out_fea
  1. 参数说明以及使用实例
  • FDFTM(dim)
    • dim:输入 RGB/IR 特征通道数,也是输出融合特征通道数。
  • 内部子模块
    • FDCA(dim, bias):对应论文的 MFDA
    • FDFFN(dim, ffn_expansion_factor=3, bias=False):对应论文的 FDFFL
  • 输入
    • x[0]:RGB 特征,形状通常为 B×C×H×W
    • x[1]:IR 特征,形状通常为 B×C×H×W
  • 输出
    • out_fea:融合后的单路特征,形状为 B×C×H×W
  • 调用方式
fdfam = FDFTM(dim=256)
rgb = torch.randn(2, 256, 80, 80)
ir = torch.randn(2, 256, 80, 80)
y = fdfam([rgb, ir])   # 输出形状: [2, 256, 80, 80]
  • 整体运行流程
    • 两路输入先进入 FDCA 做频域跨模态注意力增强
    • 各自残差相加
    • 再分别进入 FDFFN 做多尺度频域前馈增强
    • 最后拼接两模态结果,经 1×1 Conv + ReLU 得到融合输出
  1. 代码-论文原理对应说明
  • class FDFTM 对应论文 FDFAM 总体

    • 论文定义:FDFAM 由 MFDA + 两个 FDFFL + concat + conv 组成。
    • 代码中:
      • self.attn = FDCA(dim, bias):对应 MFDA
      • self.ffn = FDFFN(dim, ffn_expansion_factor, bias):对应 FDFFL
      • self.concatself.convself.relu:对应公式 (4) 的最终融合输出
    • 注意:论文图 5 画的是每个模态各有一个 FDFFL 分支;代码中复用同一个 self.ffn 分别处理 RGB 和 IR,两次调用共享参数。这一点是代码实现细节,论文原文未明确说明“是否共享参数”。
  • class FDCA 对应论文 MFDA

    • self.norm1 = LayerNorm(dim, LayerNorm_type='WithBias')
      • 对应论文公式 (1)LN(E^G_RGB)LN(E^G_IR)
    • self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)
      • 先做 1×1 Conv 投影,生成隐藏表示。
    • self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, ..., groups=dim * 6, bias=bias)
      • 深度卷积后 chunk(3, dim=1) 得到 Q/K/V
      • 对应论文“标准 1×1 卷积层和 3×3 depth-wise convolution 生成 Q,K,V”。
    • rgb_q_fft = torch.fft.rfft2(rgb_q.float())
    • rgb_k_fft = torch.fft.rfft2(rgb_k.float())
      • 对应论文对 QK 进行 FFT
    • rgb_out = rgb_q_fft * rgb_k_fft
      • 对应论文中“在频域使用逐元素乘法估计相关性”。
      • 同理 ir_out = ir_q_fft * ir_k_fft
    • rgb_out = torch.fft.irfft2(rgb_out, s=(rgb_q.shape[2], rgb_q.shape[3]))
      • 对应 IFFT 恢复回空间域。
    • rgb_out = self.norm(rgb_out)
      • 对应论文中的归一化步骤 LN(...)
    • rgb_output = ir_v * rgb_out
      • 对应论文公式 (1)LN(...) ⊙ V_IR,即用另一模态的 V 做交互。
    • ir_output = rgb_v * ir_out
      • 对应另一方向交互。
    • rgb_output = self.project_out(rgb_output)
      • 对应论文中的 f^{1×1}_{conv}(...)
    • 与论文的一个实现差异
      • 论文公式写的是 F(Q_RGB) ⊙ F(K_RGB)F(Q_IR) ⊙ F(K_IR),再分别与另一模态 V 交互。
      • 代码严格遵循这一点。
      • 但论文表述称 MFDA“捕获不同模态间相关性”,从代码看,这种相关性不是直接 Q_RGBK_IR 相乘,而是先各模态内部生成频域响应,再通过与对方 V 相乘实现跨模态注入。
  • class FDFFN 对应论文 FDFFL

    • hidden_features = int(dim * ffn_expansion_factor)
      • 对应论文先扩展通道,构建前馈层的隐藏维度。
    • self.project_in = nn.Conv2d(dim, hidden_features, kernel_size=1, bias=bias)
      • 对应公式 (2)fconv^{1×1}(X_l)
    • self.dwconv3x3 / 5x5 / 7x7
      • 对应论文三条多尺度分支 3×3 / 5×5 / 7×7 depth-wise conv
    • x3 = self.relu3(self.dwconv3x3(x))
      • 对应公式 (2) 中卷积后跟 ReLU
    • torch.fft.rfft2(...)
      • 对应各分支映射到频域。
    • x1_3, x2_3, x3_3 = x3_patch_fft.chunk(3, dim=1)
      • 对应论文中“将每个频域特征沿通道维切成三块”。
    • x3_patch_fft = torch.cat([x1_3, x1_5, x1_7], dim=1)
      • 对应论文“以 mutual mixing 方式跨尺度重组拼接”,这是 FDFFL 最关键的实现点。
    • torch.fft.irfft2(...)
      • 对应公式 (3) 中 IFFT 回到空间域。
    • self.dwconv3x3_1 / 5x5_1 / 7x7_1
      • 对应论文中 IFFT 后分别再接不同尺度的 depth-wise conv。
    • x = torch.cat([x3, x5, x7], dim=1)
    • x = self.project_out(x)
      • 对应公式 (3)[X3, X5, X7] 拼接后再 1×1 Conv 降维。
    • 与论文的一个实现差异
      • 论文公式 (3) 明确给出 \bar{X} = X + fconv^{1×1}([X3, X5, X7]),即 FDFFL 内部含残差。
      • 代码 FDFFN.forward() 只返回变换结果 x,并未在模块内部加回输入。
      • 残差是在 FDFTM.forward() 中外部完成:
        • rgb_fea = rgb_att_out + self.ffn(self.norm2(rgb_att_out))
        • ir_fea = ir_att_out + self.ffn(self.norm2(ir_att_out))
      • 功能上等价于“Pre-Norm + 残差式 FFN”,只是残差位置放在外层封装里。
  • FDFTM.forward() 对应论文完整流程

    • rgb_fea_out, ir_fea_out = self.attn([rgb_fea, ir_fea])
      • 先做 MFDA。
    • rgb_att_out = rgb_fea + rgb_fea_out
    • ir_att_out = ir_fea + ir_fea_out
      • 对应论文公式 (1) 的残差输出 Fout_RGBFout_IR
    • rgb_fea = rgb_att_out + self.ffn(self.norm2(rgb_att_out))
    • ir_fea = ir_att_out + self.ffn(self.norm2(ir_att_out))
      • 对应论文图 5 中两路特征各自通过 FDFFL。
    • fea_cat = self.concat([rgb_fea, ir_fea])
    • out_fea = self.relu(self.conv(fea_cat))
      • 对应论文公式 (4)X_f = σ(fconv^{1×1}([X_RGB, X_IR]))
    • return out_fea
      • 输出最终融合特征。
  • 辅助归一化代码对应论文 LayerNorm

    • LayerNormWithBias_LayerNormBiasFree_LayerNorm
      • 是 FDFAM 内 LN 的实现基础。
    • to_3d / to_4d
      • B×C×H×W 转成 B×HW×C 做 LayerNorm,再还原。
      • 这是代码层面对论文 LN 的具体落地方式。

总结一下代码与论文的一致性:

  • 高度一致部分:
    • FDFAM 的总结构
    • MFDA 的 Q/K/V -> FFT -> 逐元素乘 -> IFFT -> LN -> 跨模态V交互
    • FDFFL 的多尺度频域混合策略
    • 最终 concat + 1×1 conv + activation 融合输出
  • 需要注意的实现细节:
    • 代码中模块名分别为 FDFTM / FDCA / FDFFN,而论文中名称是 FDFAM / MFDA / FDFFL
    • 代码里的 FDFFN 残差在外层 FDFTM 中实现,而不是在 FDFFN 内部
    • RGB/IR 两路 FDFFL 在代码中共享同一个 self.ffn 实例,这属于实现选择,论文未展开说明
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐