Self-Attention:Transformer 的第一个核心部分详细拆解

前置知识:建议先阅读 01-Transformer架构 获得整体认识。

🔔 本专栏《AI核心原理30讲》:专注AI核心原理,回归技术本质。


一、开篇:为什么 Self-Attention 是革命性的?

1.1 RNN 的困境

要理解 Self-Attention 的价值,先看它的对手——RNN(循环神经网络)。

RNN 处理序列的方式

时间步 1       时间步 2       时间步 3       时间步 4       时间步 5
  ↓            ↓            ↓            ↓            ↓
"The"   →    "cat"    →    "sat"   →    "on"    →    "the"
  ↓            ↓            ↓            ↓            ↓
 h₁           h₂           h₃           h₄           h₅

RNN 的信息传递是链式的:

  • h₁ 包含 “The” 的信息
  • h₂ = f(h₁, “cat”),这时 “The” 的信息被编码进了 h₂
  • h₃ = f(h₂, “sat”),“The” 和 “cat” 的信息继续传递,但已经有所损失
  • 以此类推……

问题在哪?

当序列很长时(比如1000个词),序列开头的 “The” 要经过1000次传递才能影响最后一个词的表示。每次传递都可能有信息损失,到达末端时,早期信息已经被稀释得几乎看不见

这就是所谓的长期依赖问题(Long-Range Dependency Problem)

1.2 Self-Attention 如何破局?

Self-Attention 的核心思想:让任意两个位置之间的信息直接交互,不经过任何中间传递。

Self-Attention 的连接方式(完全并行):
                    ┌─────────────────────────┐
"The"  ─────────────┼──→ "cat"                 │
     └──→ ─ ─ ─ ─ ──┼──────────→ "sat"        │
                    │              ↓          │
                    │              ↓          │
                    │              ↓          │
                    └──────────────→ "on" ←───┘
                      (所有位置两两直接相连)

关键对比

特性 RNN Self-Attention
信息传递路径长度 O(n) O(1)
两个位置间的依赖 必须顺序传递 直接建模
并行化能力 差(必须顺序计算) 强(完全并行)
长距离信息保留 差(逐层稀释) 强(直接连接)

二、Self-Attention 的完整数学推导

2.1 从输入到 Q、K、V

假设输入序列是 ["The", "cat", "sat"],每个词已经经过 embedding 得到向量:

输入 X = [x₁, x₂, x₃]  形状: (3, d_model)
       x₁ = embedding("The")  → (d_model,)
       x₂ = embedding("cat")   → (d_model,)
       x₃ = embedding("sat")   → (d_model,)

然后,通过三个独立的线性变换,将每个词的 embedding 投影到 Q、K、V 空间:

Q = X · W_Q   形状: (3, d_k)
K = X · W_K   形状: (3, d_k)
V = X · W_V   形状: (3, d_v)

其中:

  • W_Q, W_K, W_V 是可学习的权重矩阵,形状均为 (d_model, d_k)(d_model, d_v)
  • 通常 d_k = d_v = d_model / num_heads

为什么要投影?

直接用原始 embedding 做 attention 不是不行,但投影后的 Q/K/V 能学习到更有意义的表示。每个投影空间让模型能够关注不同的"方面"。

2.2 注意力分数计算

对于序列中的每个词,计算它对所有词的注意力分数:

scores = Q · K^T / √d_k   形状: (3, 3)

展开来看:

scores[i,j] = q_i · k_j / √d_k

其中:
- q_i = x_i · W_Q  (第 i 个词的 query)
- k_j = x_j · W_K  (第 j 个词的 key)
- √d_k 是缩放因子

为什么要除以 √d_k?

假设 Q 和 K 的每个分量是均值为0、方差为1的独立随机变量,那么 Q·K^T 的方差会是 d_k。当 d_k 较大时,点积的值会很大,导致 softmax 函数进入饱和区,梯度接近于零。

除以 √d_k 后,点积的方差恢复到 1,softmax 的输出分布更加均匀,梯度也更稳定。

2.3 Softmax 归一化

attention_weights = softmax(scores, dim=-1)   形状: (3, 3)

Softmax 将每一行转换为概率分布:

softmax(x_i) = exp(x_i) / Σ exp(x_j)

每行的所有值加和为 1,表示当前位置对序列中各个位置的"关注程度"。

2.4 加权求和得到输出

output = attention_weights · V   形状: (3, d_v)

这步操作用注意力权重对 V 向量做加权平均:

output[i] = Σ attention_weights[i,j] · v_j

意思是:对于第 i 个词,我根据它对其他词的关注程度,取其他词的 value 向量的加权平均。

2.5 完整公式

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V


三、Multi-Head Attention 详解

3.1 为什么需要多个注意力头?

一个注意力头只能学到一种"匹配模式"。但语言的复杂性需要多种类型的关注:

以句子 “The animal didn’t cross the street because it was too tired” 为例:

  • 指代消解:“it” 应该关注 “animal”(不是 “street”)
  • 因果关系:“because” 连接了 “didn’t cross” 和 “tired”
  • 位置关系:“street” 和 “cross” 紧密相连

单一 attention 头难以同时捕捉这些关系。Multi-Head 让模型能在不同的子空间并行学习不同的关系。

3.2 Multi-Head 的计算

MultiHead(Q, K, V) = Concat(head_1, head_2, ..., head_h) · W_O

其中:
  head_i = Attention(Q · W_Q_i, K · W_K_i, V · W_V_i)

张量形状变化

输入 Q: (batch, seq_len, d_model)

      ↓ 投影 (h 个头)
      
Q_i: (batch, seq_len, d_k)  每个头 i = 1..h
K_i: (batch, seq_len, d_k)
V_i: (batch, seq_len, d_v)

      ↓ 分头 (view 操作)
      
Q_i: (batch, num_heads, seq_len, d_k)
K_i: (batch, num_heads, seq_len, d_k)
V_i: (batch, num_heads, seq_len, d_v)

      ↓ 注意力计算
       
head_i: (batch, num_heads, seq_len, d_v)

      ↓ 拼接
       
Concat: (batch, seq_len, h * d_v) = (batch, seq_len, d_model)

      ↓ 输出投影
      
output: (batch, seq_len, d_model)

3.3 论文中的标准配置

参数 Base 模型 Large 模型
d_model 512 1024
num_heads 8 16
d_k = d_v 64 64
FFN 维度 2048 4096

四、代码实现:从理论到代码

4.1 最简版本(纯 Python / NumPy)

import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax"""
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def self_attention(Q, K, V, d_k):
    """
    简化版 Self-Attention(无 batch 维度)
    
    参数:
        Q: (seq_len, d_k) 查询矩阵
        K: (seq_len, d_k) 键矩阵
        V: (seq_len, d_v) 值矩阵
        d_k: 缩放因子
    
    返回:
        output: (seq_len, d_v) 注意力输出
        weights: (seq_len, seq_len) 注意力权重
    """
    # Step 1: 计算点积注意力分数
    scores = np.dot(Q, K.T) / np.sqrt(d_k)
    
    # Step 2: Softmax 归一化
    attention_weights = softmax(scores, axis=-1)
    
    # Step 3: 加权求和
    output = np.dot(attention_weights, V)
    
    return output, attention_weights

# 测试
d_k = 64
seq_len = 5
d_v = 64

# 模拟输入
Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)

output, weights = self_attention(Q, K, V, d_k)
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
print(f"权重验证(每行和=1): {weights.sum(axis=1)}")

4.2 PyTorch 完整实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Multi-Head Attention 的完整 PyTorch 实现"""
    
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个注意力头的维度
        
        # 四个线性变换层
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask=None):
        """
        参数:
            query: (batch, seq_len, d_model)
            key:   (batch, seq_len, d_model)
            value: (batch, seq_len, d_model)
            mask:  (batch, seq_len, seq_len) 或 (batch, 1, seq_len, seq_len)
        
        返回:
            output: (batch, seq_len, d_model)
            attention_weights: (batch, num_heads, seq_len, seq_len)
        """
        batch_size = query.size(0)
        seq_len = query.size(1)
        
        # ========== Step 1: 线性投影 + 分头 ==========
        # Q, K, V: (batch, seq_len, d_model)
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        
        # 视图分割:最后一项分成 num_heads × d_k
        # (batch, seq_len, num_heads, d_k) 然后转置
        # → (batch, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        # ========== Step 2: 计算注意力分数 ==========
        # scores: (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        # ========== Step 3: 应用 Mask ==========
        if mask is not None:
            # 支持两种 mask 格式
            if mask.dim() == 2:  # (seq_len, seq_len)
                mask = mask.unsqueeze(0).unsqueeze(0)  # → (1, 1, seq_len, seq_len)
            elif mask.dim() == 3:  # (batch, seq_len, seq_len)
                mask = mask.unsqueeze(1)  # → (batch, 1, seq_len, seq_len)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # ========== Step 4: Softmax + Dropout ==========
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        
        # ========== Step 5: 加权求和 ==========
        # context: (batch, num_heads, seq_len, d_k)
        context = torch.matmul(attention_weights, V)
        
        # ========== Step 6: 合并多头 ==========
        # (batch, num_heads, seq_len, d_k) → (batch, seq_len, num_heads, d_k)
        context = context.transpose(1, 2).contiguous()
        # → (batch, seq_len, d_model)
        context = context.view(batch_size, seq_len, self.d_model)
        
        # ========== Step 7: 最终线性投影 ==========
        output = self.W_o(context)
        
        return output, attention_weights

4.3 使用示例

# 创建一个 Multi-Head Attention 层
d_model = 512
num_heads = 8
attention = MultiHeadAttention(d_model, num_heads)

# 模拟输入
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, d_model)

# 前向传播
output, attn_weights = attention(x, x, x)

print(f"输出形状: {output.shape}")              # (2, 10, 512)
print(f"注意力权重形状: {attn_weights.shape}")    # (2, 8, 10, 10)

# 可视化第 1 个样本第 1 个头的注意力权重
# attn_weights[0, 0] 形状是 (10, 10)

五、Mask 的作用与实现

5.1 为什么需要 Mask?

在 Transformer 中,Mask 主要有两种用途:

用途 1:Padding Mask

输入序列长度不一,需要 padding 到统一长度。但 padding 位置不应该参与注意力计算。

原始句子: ["The", "cat", "sat"]
Padding后: ["The", "cat", "sat", "[PAD]", "[PAD]"]

注意力权重应该是:
          The   cat   sat   [PAD] [PAD]
  The    [ 0.4   0.3   0.2    0.0   0.0 ]
  cat    [ 0.3   0.4   0.2    0.0   0.0 ]
  sat    [ 0.2   0.2   0.4    0.0   0.0 ]
  [PAD]  [ 0.0   0.0   0.0    0.0   0.0 ]
  [PAD]  [ 0.0   0.0   0.0    0.0   0.0 ]
                ↑
           padding 位置权重为 0

用途 2:因果掩码(Causal Mask / Look-Ahead Mask)

在解码器中,预测第 N 个词时不能看到第 N 个词之后的任何信息

目标序列: ["The", "cat", "sat", "[EOS]"]

允许的注意力连接(✓ 表示可见):
              Step1  Step2  Step3  Step4
  Step1   →   ✓      ✗      ✗      ✗
  Step2   →   ✓      ✓      ✗      ✗
  Step3   →   ✓      ✓      ✓      ✗
  Step4   →   ✓      ✓      ✓      ✓
  
解码器中真正的注意力权重:
  Step1   [ 1.0   0.0   0.0   0.0 ]
  Step2   [ 0.5   0.5   0.0   0.0 ]
  Step3   [ 0.3   0.3   0.4   0.0 ]
  Step4   [ 0.2   0.2   0.2   0.4 ]

5.2 Mask 的 PyTorch 实现

def create_padding_mask(seq, pad_idx=0):
    """
    创建 Padding Mask
    
    参数:
        seq: (batch, seq_len) token IDs
        pad_idx: padding 的 token ID,默认为 0
    
    返回:
        mask: (batch, 1, 1, seq_len) True 表示有效位置,False 表示需要 mask
    """
    mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
    return mask  # (batch, 1, 1, seq_len) → 广播后 (batch, num_heads, seq_len, seq_len)

def create_causal_mask(seq_len):
    """
    创建因果掩码(上三角 mask)
    
    返回:
        mask: (1, 1, seq_len, seq_len) 上三角为 False(不可见)
    """
    # torch.triu: 上三角(不含对角线)为 1
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    # → (seq_len, seq_len)
    # 对角线及以上为 True(需要 mask),对角线以下为 False(可见)
    
    mask = mask.unsqueeze(0).unsqueeze(0)  # → (1, 1, seq_len, seq_len)
    return ~mask  # 取反:True 表示可见,False 表示 mask


# 使用示例
batch_size = 2
seq_len = 5

# Padding Mask
seq = torch.tensor([[1, 2, 3, 0, 0],  # 句子1: [PAD]=0
                    [1, 2, 0, 0, 0]]) # 句子2: [PAD]=0
padding_mask = create_padding_mask(seq)
print("Padding Mask:")
print(padding_mask[0])  # 句子1的 mask

# 因果 Mask
causal_mask = create_causal_mask(seq_len)
print("\n因果 Mask:")
print(causal_mask[0, 0])

六、注意力权重的可视化

6.1 典型的注意力模式

不同层的注意力头会捕捉不同类型的依赖:

Layer 1 Head 1(捕捉局部关系):
              The   cat   sat   on   mat
     The    [ 0.5   0.3   0.1   0.05 0.05 ]
     cat    [ 0.3   0.4   0.2   0.05 0.05 ]
     sat    [ 0.1   0.2   0.4   0.2  0.1  ]
     on     [ 0.05  0.05  0.2  0.4  0.3  ]
     mat    [ 0.05  0.05  0.1  0.3  0.5  ]
     ↑ 局部性:关注相邻词

Layer 3 Head 5(捕捉语法关系):
              The   cat   sat   on   mat
     The    [ 0.4   0.1   0.1   0.1  0.1  ]
     cat    [ 0.1   0.6   0.1   0.1  0.1  ]  ← "cat" 关注自身(主语)
     sat    [ 0.1   0.1   0.6   0.1  0.1  ]  ← "sat" 关注自身(谓语)
     on     [ 0.1   0.1   0.1   0.6  0.1  ]  
     mat    [ 0.1   0.1   0.1   0.1  0.6  ]  
     ↑ 语法:每个词更关注自己(完整实体表示)

6.2 可视化代码

import matplotlib.pyplot as plt
import seaborn as sns

def plot_attention_weights(weights, tokens, save_path=None):
    """
    可视化注意力权重热力图
    
    参数:
        weights: (num_heads, seq_len, seq_len) 或 (seq_len, seq_len)
        tokens: list of strings,词元列表
        save_path: 可选,保存路径
    """
    if weights.dim() == 4:
        # 取第一个样本的第一个头
        weights = weights[0, 0].detach().numpy()
    else:
        weights = weights.detach().numpy()
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(weights, 
                xticklabels=tokens, 
                yticklabels=tokens,
                cmap='Blues',
                annot=False,
                fmt='.2f')
    plt.xlabel('Key 位置')
    plt.ylabel('Query 位置')
    plt.title('Self-Attention 权重热力图')
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path)
    plt.show()

# 使用示例
tokens = ["The", "cat", "sat", "on", "the", "mat"]
# 假设已经得到 attention_weights
# plot_attention_weights(attention_weights, tokens)

七、Self-Attention 的复杂度分析

7.1 时间复杂度

Self-Attention 的主要计算:

Step 1: Q, K, V 投影        O(n · d_model · d_k) × 3
Step 2: 计算 QK^T           O(n² · d_k)
Step 3: Softmax             O(n²)
Step 4: 加权求和            O(n² · d_v)

总计: O(n² · d_model)

其中 n 是序列长度,d_model 是模型维度。

关键点:Self-Attention 的时间复杂度是序列长度的平方(O(n²))。这是 Transformer 的主要瓶颈。

7.2 空间复杂度

存储 Q, K, V:          O(n · d_model) × 3
存储注意力权重矩阵:    O(n²)
存储输出:              O(n · d_model)

总计: O(n² + n · d_model)

注意力权重矩阵 O(n²) 是最大的开销。

7.3 复杂度对比

模型 注意力复杂度
RNN/LSTM O(n · d)
Self-Attention O(n² · d)
局部注意力 (窗口= k) O(n · k · d)
Linformer O(n · k)
Reformer O(n · log n)

这就是为什么 Long Context 是一个热门研究方向——标准 Self-Attention 无法直接处理超长序列。


八、Self-Attention 在 Transformer 中的位置

┌─────────────────────────────────────────────────────────────┐
│                      Encoder Layer                           │
│                                                              │
│  Input X ──→ Multi-Head Self-Attention ──→ Add & Norm ──┐  │
│                                                              │  │
│              ↑                                              │  │
│              │                                              │  │
│              └──────── Feed Forward ────────────────────────┘  │
│                                                              │
│                         × N 层                                │
└─────────────────────────────────────────────────────────────┘

每层内部:

  x_input
    │
    ├──→ Multi-Head Self-Attention ──→ Add(x_input, attention_output)
    │                                        │
    │                                        ↓
    │                                    LayerNorm
    │                                        │
    │                                        ↓ (sublayer_1 output)
    │                                        │
    └──→ Feed Forward Network ─────────→ Add(sublayer_1, ffn_output)
                                             │
                                             ↓
                                         LayerNorm
                                             │
                                             ↓
                                       x_output (传给下一层)

Self-Attention 的核心作用

  1. Encoder Self-Attention:让每个位置能看到序列中所有其他位置,学习输入的上下文表示
  2. Decoder Self-Attention:类似,但有 Mask,确保自回归生成时不泄露未来信息
  3. Cross Attention:Decoder 层中,Q 来自 Decoder,K/V 来自 Encoder 输出,实现跨模块交互

九、关键设计选择的原因

设计选择 选择 原因
缩放因子 √d_k QK^T / √d_k 防止 d_k 较大时 softmax 梯度消失
多头注意力 h 个独立头 不同头学习不同类型的依赖关系
Q=K=V 输入 自注意力 让序列内部进行自我比较,学习内部结构
Linear 投影 学习的 W_Q, W_K, W_V 增加模型表达能力,让投影空间更有意义
拼接后投影 Concat → W_O 合并多头的不同子空间表示

十、总结与延伸

10.1 核心要点回顾

  1. Self-Attention 通过直接建模任意位置间的依赖,解决了 RNN 的长距离依赖问题

  2. Q/K/V 三元组让每个词既能"问问题"(Query)也能"回答问题"(Key/Value)

  3. 缩放因子 √d_k 是关键细节,防止大维度下的梯度消失

  4. Multi-Head 扩展了模型的表示能力,不同头学习不同类型的关系

  5. Mask 机制让 Transformer 能处理变长序列和控制信息流动

10.2 延伸阅读方向

方向 关键技术 适用场景
高效注意力 Flash Attention, Sparse Attention 长上下文
位置编码 RoPE, ALiBi,绝对位置编码 位置感知
注意力变体 Grouped Query Attention 高效推理
跨模态 Cross Attention 多模态融合

参考资料

  1. Vaswani et al., “Attention Is All You Need”, NeurIPS 2017
  2. The Illustrated Transformer: http://nlp.seas.harvard.edu/2018/04/03/attention.html
  3. Lilian Weng, “Attention? Attention!”, Lil’Log
  4. PyTorch Transformer Documentation: https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html

📢 本专栏持续更新,下一篇:《Feed Forward Network:注意力之外的另一条腿》

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐