背景与动机

FM 的局限

FM 的二阶交互:

y = w₀ + Σᵢ wᵢxᵢ + Σᵢⱼ vᵢ·vⱼ xᵢxⱼ

问题: 所有特征对的交互项权重相同,无法区分重要性。

例子:

用户A × 广告1: 交互强度 0.8
用户A × 广告2: 交互强度 0.9
用户B × 广告1: 交互强度 0.1

FM 认为所有交互同等重要!

AFM 的创新

引入注意力机制:

y = w₀ + Σᵢ wᵢxᵢ + Σᵢⱼ aᵢⱼ (vᵢ·vⱼ) xᵢxⱼ
                            ↑
                      注意力权重

注意力权重: 学习每个特征交互的重要性


核心创新:注意力机制

注意力网络 (Attention Network)

输入: [v₁·v₂, v₁·v₃, v₂·v₃, ...]  # 所有特征对的交互
         ↓
    Attention Network
         ↓
    注意力权重: [a₁₂, a₁₃, a₂₃, ...]

注意力权重计算

Softmax 池和池化:

aᵢⱼ = softmax(Attention-Net(vᵢ, vⱼ))

作用:
1. 计算每个特征对的注意分数
2. 通过 softmax 归一化
3. 加权平均,突出重要的交互

为什么需要注意力?

情景 FM AFM
重要交互 无法区分 高权重
噪声交互 同等处理 低权重
解释性 强(注意力可视化)

AFM vs FM/DeepFM 对比

核心区别

维度 FM AFM DeepFM
二阶交互 固定权重 注意力加权 囟定权重
学习机制 线性组合 注意力机制 DNN 隐式
参数量 稍多
可解释性 强(注意力)

公式对比

FM:

y = Σᵢⱼ (vᵢ·vⱼ) xᵢxⱼ

AFM:

y = Σᵢⱼ aᵢⱼ (vᵢ·vⱼ) xᵢxⱼ

其中: aᵢⱼ = Attention-Net(vᵢ, vⱼ)

模型架构

整体结构

              输入特征 (离散索引)
                       ↓
              Embedding 层
                       ↓
              二阶交互计算
               (vᵢ · vⱼ)
                       ↓
            ┌──────┴──────┐
            ↓              ↓
       AFM 部分     Linear 部分
    (注意力加权)    (线性影响)
            └──────┬──────┘
                       ↓
                   最终输出

组件详解

1. Embedding 层

将离散特征映射到稠密向量

用户 123 → [0.23, 0.15, ...]
广告 45  → [0.67, 0.32, ...]
2. 注意力网络 (Attention Network)

结构:

输入: vᵢ, vⱼ  # 两个特征向量
  ↓
 Concat: [vᵢ, vⱼ]  # 拼接
  ↓
 MLP: ReLU 层
  ↓
输出: 注意力分数  # 标量

代码实现:

# 注意力网络
att_net = nn.Sequential(
    nn.Linear(2 * k, t),  # k: embedding 维度, t: 隐层维度
    nn.ReLU(),
    nn.Linear(t, 1)
)
3. 注意力计算

步骤:

  1. 计算所有特征对的交互

    pairwise = [v₁·v₂, v₁·v₃, v₂·v₃, ...]
    
  2. 计算注意力分数

    attention_scores = att_net([vᵢ, vⱼ])
    
  3. Softmax 归一化

    attention_weights = softmax(attention_scores)
    
  4. 加权求和

    output = Σᵢⱼ attention_weights × pairwise
    

代码实现

PyTorch 实现

import torch
import torch.nn as nn


class AttentionalFM(nn.Module):
    """
    Attentional Factorization Machine (AFM)

    核心创新:
        在 FM 的二阶交互基础上引入注意力机制

    公式:
        y = w₀ + Σᵢ wᵢxᵢ + Σᵢⱼ aᵢⱼ (vᵢ·vⱼ) xᵢxⱼ

    其中: aᵢⱼ = Attention-Net(vᵢ, vⱼ)
    """

    def __init__(self, feature_dims, embedding_dim=8, attention_hidden_dim=8):
        """
        Args:
            feature_dims: 每个特征的可能取值数
            embedding_dim: embedding 向量维度
            attention_hidden_dim: 注意力网络的隐藏层维度
        """
        super().__init__()

        self.feature_dims = feature_dims
        self.num_features = len(feature_dims)
        self.embedding_dim = embedding_dim

        # ==================== Embedding 层 ====================
        self.embeddings = nn.ModuleList([
            nn.Embedding(dim, embedding_dim) for dim in feature_dims
        ])

        # ==================== FM 一阶部分 ====================
        self.linear = nn.Linear(self.num_features, 1)

        # ==================== 注意力网络 ====================
        # 输入: 两个 embedding 向量拼接
        # 输出: 注意力分数
        self.attention_net = nn.Sequential(
            nn.Linear(2 * embedding_dim, attention_hidden_dim),
            nn.ReLU(),
            nn.Linear(attention_hidden_dim, 1)
        )

    def forward(self, x):
        """
        Args:
            x: (batch_size, num_features) 离散特征索引

        Returns:
            logits: (batch_size, 1) 预测分数
        """
        batch_size = x.shape[0]

        # ==================== Embedding ====================
        embedded_features = []
        for i, emb in enumerate(self.embeddings):
            emb_i = emb(x[:, i])
            embedded_features.append(emb_i)

        all_embeddings = torch.cat(embedded_features, dim=1)
        all_embeddings = all_embeddings.view(
            batch_size, self.num_features, self.embedding_dim
        )

        # ==================== FM 一阶 ====================
        linear_part = self.linear(x.float())

        # ==================== 二阶交互(带注意力) ====================
        pairwise_interactions = []

        # 计算所有特征对的交互和注意力
        for i in range(self.num_features):
            for j in range(i + 1, self.num_features):
                # 1. 计算特征交互: vᵢ · vⱼ
                interaction = torch.sum(
                    all_embeddings[:, i, :] * all_embeddings[:, j, :],
                    dim=1,
                    keepdim=True
                )

                # 2. 计算注意力: Attention-Net(vᵢ, vⱼ)
                pair_embeddings = torch.cat([
                    all_embeddings[:, i, :],
                    all_embeddings[:, j, :]
                ], dim=1)

                attention_score = self.attention_net(pair_embeddings)

                # 3. 归一化: sigmoid(单注意力)
                attention_weight = torch.sigmoid(attention_score)

                # 4. 加权
                weighted_interaction = attention_weight * interaction

                pairwise_interactions.append(weighted_interaction)

        # 汇总所有交互项
        second_order = torch.sum(torch.stack(pairwise_interactions, dim=0), dim=0)

        # ==================== 输出 ====================
        output = linear_part + second_order

        return output


# ==================== 使用示例 ====================
if __name__ == '__main__':
    # 特征定义
    feature_dims = [1000, 500, 5, 4, 10]

    # 创建 AFM 模型
    model = AttentionalFM(
        feature_dims=feature_dims,
        embedding_dim=8,
        attention_hidden_dim=8
    )

    print('=== AFM 模型结构 ===')
    print(model)

    # 参数量分析
    total_params = sum(p.numel() for p in model.parameters())
    embedding_params = sum(p.numel() for p in model.embeddings.parameters())
    linear_params = sum(p.numel() for p in model.linear.parameters())
    attention_params = sum(p.numel() for p in model.attention_net.parameters())

    print(f'\n参数量分析:')
    print(f'  总参数量: {total_params:,}')
    print(f'  Embedding: {embedding_params:,}')
    print(f'  Linear 部分: {linear_params:,}')
    print(f'  注意力网络: {attention_params:,}')

    # 生成训练数据
    batch_size = 32
    x = torch.tensor([
        [torch.randint(0, dim, size=(1,)).item() for dim in feature_dims]
        for _ in range(batch_size)
    ])
    y = torch.randint(0, 2, (batch_size, 1), dtype=torch.float32)

    # 训练
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print(f'\n=== 开始训练 ===')
    for epoch in range(100):
        pred = model(x)
        loss = criterion(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            print(f'Epoch {epoch + 1:3d}, Loss: {loss.item():.6f}')

    # 预测
    model.eval()
    with torch.no_grad():
        test_x = torch.tensor([[
            torch.randint(0, dim, size=(1,)).item() for dim in feature_dims
        ]])

        logits = model(test_x)
        click_prob = torch.sigmoid(logits)

        print(f'\n=== 预测结果 ===')
        print(f'模型输出 (logits): {logits.item():.4f}')
        print(f'点击概率 (sigmoid): {click_prob.item():.4f}')

面试常见问题

Q1: AFM 和 FM 的核心区别是什么?

A:

  • FM: 所有特征对的交互权重相同
  • AFM: 每个特征对有独立的注意力权重

Q2: 注意力权重是如何计算的?

A:

1. Attention-Net(vᵢ, vⱼ) → 注意力分数
2. sigmoid(分数) → 归一化权重
3. 权重 × 交互 → 加权交互

Q3: AFM 为什么使用 sigmoid 而不是 softmax?

A:

  • sigmoid: 每个特征对独立归一化 (0-1)
  • softmax: 所有特征对竞争,和为 1

AFM 使用 sigmoid 是因为每个交互项相对独立。

Q4: AFM 的参数量如何?

A:

Embedding: n × k
Linear: n
Attention: (2k × h + h × 1) × 每对数

增加的参数主要来自注意力网络

Q5: AFM 的适用场景?

A:

场景 是否推荐 原因
特征交互重要性不同 ✅ 推荐 注意力突出重要交互
需要解释性 ✅ 推荐 可视化注意力权重
数据量大 ✅ 推荐 能学到注意力模式
延迟敏感 ❌ 不推荐 注意力增加计算

Q6: AFM vs xDeepFM 如何选择?

A:

模型 优势 劣势
AFM 参数少,可解释 只能加成二阶交互
xDeepFM 显式高阶 参数多,复杂

选择建议:

  • 需要可解释性 → AFM
  • 需要高阶交互 → xDeepFM

模型演进

FM (2010)
   ↓ 固定权重二阶交互
DeepFM (2017)
   ↓ 添加 DNN 高阶
xDeepFM (2018)
   ↓ CIN 显式高阶
AFM (2017) ⭐
   ↓ 注意力机制
DIN / DIEN (2018-2019)
   ↓ 序列化注意力
AutoInt (2020)
   ↓ 自动交互阶数

快速检查清单

理解 AFM,你应该能回答:

  • 解释 AFM 和 FM 的区别
  • 说明注意力网络的作用
  • 计算注意力权重
  • 了解 sigmoid vs softmax 的区别
  • 比较 AFM vs DeepFM/xDeepFM
  • 知道 AFM 的参数量计算
  • 能从零实现简单的 AFM

参考资料

  • AFM 原始论文: https://arxiv.org/abs/1708.04621
  • 注意力机制: https://distill.pub/2016/a-fundamental-approach-to-neural-attention
  • 推荐系统模型库: https://github.com/shenweichen/DeepCTR

实操案例(CTR预测)

import torch
import torch.nn as nn


class AttentionalFM(nn.Module):
    """
    Attentional Factorization Machine (AFM)

    核心创新:
        在 FM 的二阶交互基础上引入注意力机制

    公式:
        y = w₀ + Σᵢ wᵢxᵢ + Σᵢⱼ aᵢⱼ (vᵢ·vⱼ) xᵢxⱼ

    其中: aᵢⱼ = Attention-Net(vᵢ, vⱼ) 是注意力权重
    """

    def __init__(self, feature_dims, embedding_dim=8, attention_hidden_dim=8):
        """
        Args:
            feature_dims: 每个特征的可能取值数
            embedding_dim: embedding 向量维度
            attention_hidden_dim: 注意力网络的隐藏层维度
        """
        super().__init__()

        self.feature_dims = feature_dims
        self.num_features = len(feature_dims)
        self.embedding_dim = embedding_dim

        # ==================== Embedding 层 ====================
        self.embeddings = nn.ModuleList([
            nn.Embedding(dim, embedding_dim) for dim in feature_dims
        ])

        # ==================== FM 一阶部分 ====================
        self.linear = nn.Linear(self.num_features, 1)

        # ==================== 注意力网络 ====================
        # 输入: 两个 embedding 向量拼接
        # 输出: 注意力分数
        self.attention_net = nn.Sequential(
            nn.Linear(2 * embedding_dim, attention_hidden_dim),
            nn.ReLU(),
            nn.Linear(attention_hidden_dim, 1)
        )

    def forward(self, x):
        """
        Args:
            x: (batch_size, num_features) 离散特征索引

        Returns:
            logits: (batch_size, 1) 预测分数
        """
        batch_size = x.shape[0]

        # ==================== Embedding ====================
        embedded_features = []
        for i, emb in enumerate(self.embeddings):
            emb_i = emb(x[:, i])
            embedded_features.append(emb_i)

        all_embeddings = torch.cat(embedded_features, dim=1)
        all_embeddings = all_embeddings.view(
            batch_size, self.num_features, self.embedding_dim
        )

        # ==================== FM 一阶 ====================
        linear_part = self.linear(x.float())

        # ==================== 二阶交互(带注意力) ====================
        pairwise_interactions = []

        # 计算所有特征对的交互和注意力
        for i in range(self.num_features):
            for j in range(i + 1, self.num_features):
                # 1. 计算特征交互: vᵢ · vⱼ
                interaction = torch.sum(
                    all_embeddings[:, i, :] * all_embeddings[:, j, :],
                    dim=1,
                    keepdim=True
                )

                # 2. 计算注意力: Attention-Net(vᵢ, vⱼ)
                pair_embeddings = torch.cat([
                    all_embeddings[:, i, :],
                    all_embeddings[:, j, :]
                ], dim=1)

                attention_score = self.attention_net(pair_embeddings)

                # 3. 归一化: sigmoid(单注意力)
                attention_weight = torch.sigmoid(attention_score)

                # 4. 加权
                weighted_interaction = attention_weight * interaction

                pairwise_interactions.append(weighted_interaction)

        # 汇总所有交互项
        second_order = torch.sum(torch.stack(pairwise_interactions, dim=0), dim=0)

        # ==================== 输出 ====================
        output = linear_part + second_order

        return output


# ==================== 使用示例 ====================
if __name__ == '__main__':
    # 特征定义
    # 特征: [用户ID, 广告ID, 设备类型, 时间段, 位置]
    feature_dims = [1000, 500, 5, 4, 10]

    # 创建 AFM 模型
    model = AttentionalFM(
        feature_dims=feature_dims,
        embedding_dim=8,
        attention_hidden_dim=8
    )

    print('=== AFM 模型结构 ===')
    print(model)

    # ==================== 参数量分析 ====================
    total_params = sum(p.numel() for p in model.parameters())

    embedding_params = sum(p.numel() for p in model.embeddings.parameters())
    linear_params = sum(p.numel() for p in model.linear.parameters())
    attention_params = sum(p.numel() for p in model.attention_net.parameters())

    print(f'\n参数量分析:')
    print(f'  总参数量: {total_params:,}')
    print(f'  Embedding: {embedding_params:,}')
    print(f'  Linear 部分: {linear_params:,}')
    print(f'  注意力网络: {attention_params:,}')

    # ==================== 生成训练数据 ====================
    batch_size = 32

    # 生成随机训练样本
    x = torch.tensor([
        [torch.randint(0, dim, size=(1,)).item() for dim in feature_dims]
        for _ in range(batch_size)
    ])

    # 生成随机标签
    y = torch.randint(0, 2, (batch_size, 1), dtype=torch.float32)

    # ==================== 训练配置 ====================
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print(f'\n=== 开始训练 ===')
    print(f'batch_size: {batch_size}')
    print(f'特征数: {model.num_features}')

    # ==================== 训练循环 ====================
    for epoch in range(12000):
        pred = model(x)
        loss = criterion(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            print(f'Epoch {epoch + 1:3d}, Loss: {loss.item():.6f}')

    # ==================== 预测示例 ====================
    model.eval()

    with torch.no_grad():
        # 生成测试样本
        test_x = torch.tensor([[
            torch.randint(0, dim, size=(1,)).item() for dim in feature_dims
        ]])

        logits = model(test_x)
        click_prob = torch.sigmoid(logits)

        print(f'\n=== 预测结果 ===')
        print(f'模型输出 (logits): {logits.item():.4f}')
        print(f'点击概率 (sigmoid): {click_prob.item():.4f}')

    # ==================== 注意力机制说明 ====================
    print(f'\n=== 注意力机制工作原理 ===')
    print('1. 计算特征交互: vᵢ · vⱼ (内积)')
    print('2. 输入注意力网络: Attention-Net([vᵢ, vⱼ])')
    print('3. 输出注意力分数: sigmoid(Attention-Net(...))')
    print('4. 加权: 注意力分数 × 交互项')
    print('\n优势: 突出重要的特征交互,抑制不重要的')

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐