【AI核心原理30讲】-Transformer架构(一)
Self-Attention:Transformer 的第一个核心部分详细拆解
前置知识:建议先阅读 01-Transformer架构 获得整体认识。
🔔 本专栏《AI核心原理30讲》:专注AI核心原理,回归技术本质。
一、开篇:为什么 Self-Attention 是革命性的?
1.1 RNN 的困境
要理解 Self-Attention 的价值,先看它的对手——RNN(循环神经网络)。
RNN 处理序列的方式:
时间步 1 时间步 2 时间步 3 时间步 4 时间步 5
↓ ↓ ↓ ↓ ↓
"The" → "cat" → "sat" → "on" → "the"
↓ ↓ ↓ ↓ ↓
h₁ h₂ h₃ h₄ h₅
RNN 的信息传递是链式的:
- h₁ 包含 “The” 的信息
- h₂ = f(h₁, “cat”),这时 “The” 的信息被编码进了 h₂
- h₃ = f(h₂, “sat”),“The” 和 “cat” 的信息继续传递,但已经有所损失
- 以此类推……
问题在哪?
当序列很长时(比如1000个词),序列开头的 “The” 要经过1000次传递才能影响最后一个词的表示。每次传递都可能有信息损失,到达末端时,早期信息已经被稀释得几乎看不见。
这就是所谓的长期依赖问题(Long-Range Dependency Problem)。
1.2 Self-Attention 如何破局?
Self-Attention 的核心思想:让任意两个位置之间的信息直接交互,不经过任何中间传递。
Self-Attention 的连接方式(完全并行):
┌─────────────────────────┐
"The" ─────────────┼──→ "cat" │
└──→ ─ ─ ─ ─ ──┼──────────→ "sat" │
│ ↓ │
│ ↓ │
│ ↓ │
└──────────────→ "on" ←───┘
(所有位置两两直接相连)
关键对比:
| 特性 | RNN | Self-Attention |
|---|---|---|
| 信息传递路径长度 | O(n) | O(1) |
| 两个位置间的依赖 | 必须顺序传递 | 直接建模 |
| 并行化能力 | 差(必须顺序计算) | 强(完全并行) |
| 长距离信息保留 | 差(逐层稀释) | 强(直接连接) |
二、Self-Attention 的完整数学推导
2.1 从输入到 Q、K、V
假设输入序列是 ["The", "cat", "sat"],每个词已经经过 embedding 得到向量:
输入 X = [x₁, x₂, x₃] 形状: (3, d_model)
x₁ = embedding("The") → (d_model,)
x₂ = embedding("cat") → (d_model,)
x₃ = embedding("sat") → (d_model,)
然后,通过三个独立的线性变换,将每个词的 embedding 投影到 Q、K、V 空间:
Q = X · W_Q 形状: (3, d_k)
K = X · W_K 形状: (3, d_k)
V = X · W_V 形状: (3, d_v)
其中:
W_Q,W_K,W_V是可学习的权重矩阵,形状均为(d_model, d_k)或(d_model, d_v)- 通常
d_k = d_v = d_model / num_heads
为什么要投影?
直接用原始 embedding 做 attention 不是不行,但投影后的 Q/K/V 能学习到更有意义的表示。每个投影空间让模型能够关注不同的"方面"。
2.2 注意力分数计算
对于序列中的每个词,计算它对所有词的注意力分数:
scores = Q · K^T / √d_k 形状: (3, 3)
展开来看:
scores[i,j] = q_i · k_j / √d_k
其中:
- q_i = x_i · W_Q (第 i 个词的 query)
- k_j = x_j · W_K (第 j 个词的 key)
- √d_k 是缩放因子
为什么要除以 √d_k?
假设 Q 和 K 的每个分量是均值为0、方差为1的独立随机变量,那么 Q·K^T 的方差会是 d_k。当 d_k 较大时,点积的值会很大,导致 softmax 函数进入饱和区,梯度接近于零。
除以 √d_k 后,点积的方差恢复到 1,softmax 的输出分布更加均匀,梯度也更稳定。
2.3 Softmax 归一化
attention_weights = softmax(scores, dim=-1) 形状: (3, 3)
Softmax 将每一行转换为概率分布:
softmax(x_i) = exp(x_i) / Σ exp(x_j)
每行的所有值加和为 1,表示当前位置对序列中各个位置的"关注程度"。
2.4 加权求和得到输出
output = attention_weights · V 形状: (3, d_v)
这步操作用注意力权重对 V 向量做加权平均:
output[i] = Σ attention_weights[i,j] · v_j
意思是:对于第 i 个词,我根据它对其他词的关注程度,取其他词的 value 向量的加权平均。
2.5 完整公式
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
三、Multi-Head Attention 详解
3.1 为什么需要多个注意力头?
一个注意力头只能学到一种"匹配模式"。但语言的复杂性需要多种类型的关注:
以句子 “The animal didn’t cross the street because it was too tired” 为例:
- 指代消解:“it” 应该关注 “animal”(不是 “street”)
- 因果关系:“because” 连接了 “didn’t cross” 和 “tired”
- 位置关系:“street” 和 “cross” 紧密相连
单一 attention 头难以同时捕捉这些关系。Multi-Head 让模型能在不同的子空间并行学习不同的关系。
3.2 Multi-Head 的计算
MultiHead(Q, K, V) = Concat(head_1, head_2, ..., head_h) · W_O
其中:
head_i = Attention(Q · W_Q_i, K · W_K_i, V · W_V_i)
张量形状变化:
输入 Q: (batch, seq_len, d_model)
↓ 投影 (h 个头)
Q_i: (batch, seq_len, d_k) 每个头 i = 1..h
K_i: (batch, seq_len, d_k)
V_i: (batch, seq_len, d_v)
↓ 分头 (view 操作)
Q_i: (batch, num_heads, seq_len, d_k)
K_i: (batch, num_heads, seq_len, d_k)
V_i: (batch, num_heads, seq_len, d_v)
↓ 注意力计算
head_i: (batch, num_heads, seq_len, d_v)
↓ 拼接
Concat: (batch, seq_len, h * d_v) = (batch, seq_len, d_model)
↓ 输出投影
output: (batch, seq_len, d_model)
3.3 论文中的标准配置
| 参数 | Base 模型 | Large 模型 |
|---|---|---|
| d_model | 512 | 1024 |
| num_heads | 8 | 16 |
| d_k = d_v | 64 | 64 |
| FFN 维度 | 2048 | 4096 |
四、代码实现:从理论到代码
4.1 最简版本(纯 Python / NumPy)
import numpy as np
def softmax(x, axis=-1):
"""Numerically stable softmax"""
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def self_attention(Q, K, V, d_k):
"""
简化版 Self-Attention(无 batch 维度)
参数:
Q: (seq_len, d_k) 查询矩阵
K: (seq_len, d_k) 键矩阵
V: (seq_len, d_v) 值矩阵
d_k: 缩放因子
返回:
output: (seq_len, d_v) 注意力输出
weights: (seq_len, seq_len) 注意力权重
"""
# Step 1: 计算点积注意力分数
scores = np.dot(Q, K.T) / np.sqrt(d_k)
# Step 2: Softmax 归一化
attention_weights = softmax(scores, axis=-1)
# Step 3: 加权求和
output = np.dot(attention_weights, V)
return output, attention_weights
# 测试
d_k = 64
seq_len = 5
d_v = 64
# 模拟输入
Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)
output, weights = self_attention(Q, K, V, d_k)
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")
print(f"权重验证(每行和=1): {weights.sum(axis=1)}")
4.2 PyTorch 完整实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
"""Multi-Head Attention 的完整 PyTorch 实现"""
def __init__(self, d_model, num_heads, dropout=0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个注意力头的维度
# 四个线性变换层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
"""
参数:
query: (batch, seq_len, d_model)
key: (batch, seq_len, d_model)
value: (batch, seq_len, d_model)
mask: (batch, seq_len, seq_len) 或 (batch, 1, seq_len, seq_len)
返回:
output: (batch, seq_len, d_model)
attention_weights: (batch, num_heads, seq_len, seq_len)
"""
batch_size = query.size(0)
seq_len = query.size(1)
# ========== Step 1: 线性投影 + 分头 ==========
# Q, K, V: (batch, seq_len, d_model)
Q = self.W_q(query)
K = self.W_k(key)
V = self.W_v(value)
# 视图分割:最后一项分成 num_heads × d_k
# (batch, seq_len, num_heads, d_k) 然后转置
# → (batch, num_heads, seq_len, d_k)
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# ========== Step 2: 计算注意力分数 ==========
# scores: (batch, num_heads, seq_len, seq_len)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# ========== Step 3: 应用 Mask ==========
if mask is not None:
# 支持两种 mask 格式
if mask.dim() == 2: # (seq_len, seq_len)
mask = mask.unsqueeze(0).unsqueeze(0) # → (1, 1, seq_len, seq_len)
elif mask.dim() == 3: # (batch, seq_len, seq_len)
mask = mask.unsqueeze(1) # → (batch, 1, seq_len, seq_len)
scores = scores.masked_fill(mask == 0, float('-inf'))
# ========== Step 4: Softmax + Dropout ==========
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# ========== Step 5: 加权求和 ==========
# context: (batch, num_heads, seq_len, d_k)
context = torch.matmul(attention_weights, V)
# ========== Step 6: 合并多头 ==========
# (batch, num_heads, seq_len, d_k) → (batch, seq_len, num_heads, d_k)
context = context.transpose(1, 2).contiguous()
# → (batch, seq_len, d_model)
context = context.view(batch_size, seq_len, self.d_model)
# ========== Step 7: 最终线性投影 ==========
output = self.W_o(context)
return output, attention_weights
4.3 使用示例
# 创建一个 Multi-Head Attention 层
d_model = 512
num_heads = 8
attention = MultiHeadAttention(d_model, num_heads)
# 模拟输入
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, d_model)
# 前向传播
output, attn_weights = attention(x, x, x)
print(f"输出形状: {output.shape}") # (2, 10, 512)
print(f"注意力权重形状: {attn_weights.shape}") # (2, 8, 10, 10)
# 可视化第 1 个样本第 1 个头的注意力权重
# attn_weights[0, 0] 形状是 (10, 10)
五、Mask 的作用与实现
5.1 为什么需要 Mask?
在 Transformer 中,Mask 主要有两种用途:
用途 1:Padding Mask
输入序列长度不一,需要 padding 到统一长度。但 padding 位置不应该参与注意力计算。
原始句子: ["The", "cat", "sat"]
Padding后: ["The", "cat", "sat", "[PAD]", "[PAD]"]
注意力权重应该是:
The cat sat [PAD] [PAD]
The [ 0.4 0.3 0.2 0.0 0.0 ]
cat [ 0.3 0.4 0.2 0.0 0.0 ]
sat [ 0.2 0.2 0.4 0.0 0.0 ]
[PAD] [ 0.0 0.0 0.0 0.0 0.0 ]
[PAD] [ 0.0 0.0 0.0 0.0 0.0 ]
↑
padding 位置权重为 0
用途 2:因果掩码(Causal Mask / Look-Ahead Mask)
在解码器中,预测第 N 个词时不能看到第 N 个词之后的任何信息。
目标序列: ["The", "cat", "sat", "[EOS]"]
允许的注意力连接(✓ 表示可见):
Step1 Step2 Step3 Step4
Step1 → ✓ ✗ ✗ ✗
Step2 → ✓ ✓ ✗ ✗
Step3 → ✓ ✓ ✓ ✗
Step4 → ✓ ✓ ✓ ✓
解码器中真正的注意力权重:
Step1 [ 1.0 0.0 0.0 0.0 ]
Step2 [ 0.5 0.5 0.0 0.0 ]
Step3 [ 0.3 0.3 0.4 0.0 ]
Step4 [ 0.2 0.2 0.2 0.4 ]
5.2 Mask 的 PyTorch 实现
def create_padding_mask(seq, pad_idx=0):
"""
创建 Padding Mask
参数:
seq: (batch, seq_len) token IDs
pad_idx: padding 的 token ID,默认为 0
返回:
mask: (batch, 1, 1, seq_len) True 表示有效位置,False 表示需要 mask
"""
mask = (seq != pad_idx).unsqueeze(1).unsqueeze(2)
return mask # (batch, 1, 1, seq_len) → 广播后 (batch, num_heads, seq_len, seq_len)
def create_causal_mask(seq_len):
"""
创建因果掩码(上三角 mask)
返回:
mask: (1, 1, seq_len, seq_len) 上三角为 False(不可见)
"""
# torch.triu: 上三角(不含对角线)为 1
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
# → (seq_len, seq_len)
# 对角线及以上为 True(需要 mask),对角线以下为 False(可见)
mask = mask.unsqueeze(0).unsqueeze(0) # → (1, 1, seq_len, seq_len)
return ~mask # 取反:True 表示可见,False 表示 mask
# 使用示例
batch_size = 2
seq_len = 5
# Padding Mask
seq = torch.tensor([[1, 2, 3, 0, 0], # 句子1: [PAD]=0
[1, 2, 0, 0, 0]]) # 句子2: [PAD]=0
padding_mask = create_padding_mask(seq)
print("Padding Mask:")
print(padding_mask[0]) # 句子1的 mask
# 因果 Mask
causal_mask = create_causal_mask(seq_len)
print("\n因果 Mask:")
print(causal_mask[0, 0])
六、注意力权重的可视化
6.1 典型的注意力模式
不同层的注意力头会捕捉不同类型的依赖:
Layer 1 Head 1(捕捉局部关系):
The cat sat on mat
The [ 0.5 0.3 0.1 0.05 0.05 ]
cat [ 0.3 0.4 0.2 0.05 0.05 ]
sat [ 0.1 0.2 0.4 0.2 0.1 ]
on [ 0.05 0.05 0.2 0.4 0.3 ]
mat [ 0.05 0.05 0.1 0.3 0.5 ]
↑ 局部性:关注相邻词
Layer 3 Head 5(捕捉语法关系):
The cat sat on mat
The [ 0.4 0.1 0.1 0.1 0.1 ]
cat [ 0.1 0.6 0.1 0.1 0.1 ] ← "cat" 关注自身(主语)
sat [ 0.1 0.1 0.6 0.1 0.1 ] ← "sat" 关注自身(谓语)
on [ 0.1 0.1 0.1 0.6 0.1 ]
mat [ 0.1 0.1 0.1 0.1 0.6 ]
↑ 语法:每个词更关注自己(完整实体表示)
6.2 可视化代码
import matplotlib.pyplot as plt
import seaborn as sns
def plot_attention_weights(weights, tokens, save_path=None):
"""
可视化注意力权重热力图
参数:
weights: (num_heads, seq_len, seq_len) 或 (seq_len, seq_len)
tokens: list of strings,词元列表
save_path: 可选,保存路径
"""
if weights.dim() == 4:
# 取第一个样本的第一个头
weights = weights[0, 0].detach().numpy()
else:
weights = weights.detach().numpy()
plt.figure(figsize=(10, 8))
sns.heatmap(weights,
xticklabels=tokens,
yticklabels=tokens,
cmap='Blues',
annot=False,
fmt='.2f')
plt.xlabel('Key 位置')
plt.ylabel('Query 位置')
plt.title('Self-Attention 权重热力图')
plt.tight_layout()
if save_path:
plt.savefig(save_path)
plt.show()
# 使用示例
tokens = ["The", "cat", "sat", "on", "the", "mat"]
# 假设已经得到 attention_weights
# plot_attention_weights(attention_weights, tokens)
七、Self-Attention 的复杂度分析
7.1 时间复杂度
Self-Attention 的主要计算:
Step 1: Q, K, V 投影 O(n · d_model · d_k) × 3
Step 2: 计算 QK^T O(n² · d_k)
Step 3: Softmax O(n²)
Step 4: 加权求和 O(n² · d_v)
总计: O(n² · d_model)
其中 n 是序列长度,d_model 是模型维度。
关键点:Self-Attention 的时间复杂度是序列长度的平方(O(n²))。这是 Transformer 的主要瓶颈。
7.2 空间复杂度
存储 Q, K, V: O(n · d_model) × 3
存储注意力权重矩阵: O(n²)
存储输出: O(n · d_model)
总计: O(n² + n · d_model)
注意力权重矩阵 O(n²) 是最大的开销。
7.3 复杂度对比
| 模型 | 注意力复杂度 |
|---|---|
| RNN/LSTM | O(n · d) |
| Self-Attention | O(n² · d) |
| 局部注意力 (窗口= k) | O(n · k · d) |
| Linformer | O(n · k) |
| Reformer | O(n · log n) |
这就是为什么 Long Context 是一个热门研究方向——标准 Self-Attention 无法直接处理超长序列。
八、Self-Attention 在 Transformer 中的位置
┌─────────────────────────────────────────────────────────────┐
│ Encoder Layer │
│ │
│ Input X ──→ Multi-Head Self-Attention ──→ Add & Norm ──┐ │
│ │ │
│ ↑ │ │
│ │ │ │
│ └──────── Feed Forward ────────────────────────┘ │
│ │
│ × N 层 │
└─────────────────────────────────────────────────────────────┘
每层内部:
x_input
│
├──→ Multi-Head Self-Attention ──→ Add(x_input, attention_output)
│ │
│ ↓
│ LayerNorm
│ │
│ ↓ (sublayer_1 output)
│ │
└──→ Feed Forward Network ─────────→ Add(sublayer_1, ffn_output)
│
↓
LayerNorm
│
↓
x_output (传给下一层)
Self-Attention 的核心作用:
- Encoder Self-Attention:让每个位置能看到序列中所有其他位置,学习输入的上下文表示
- Decoder Self-Attention:类似,但有 Mask,确保自回归生成时不泄露未来信息
- Cross Attention:Decoder 层中,Q 来自 Decoder,K/V 来自 Encoder 输出,实现跨模块交互
九、关键设计选择的原因
| 设计选择 | 选择 | 原因 |
|---|---|---|
| 缩放因子 √d_k | QK^T / √d_k | 防止 d_k 较大时 softmax 梯度消失 |
| 多头注意力 | h 个独立头 | 不同头学习不同类型的依赖关系 |
| Q=K=V 输入 | 自注意力 | 让序列内部进行自我比较,学习内部结构 |
| Linear 投影 | 学习的 W_Q, W_K, W_V | 增加模型表达能力,让投影空间更有意义 |
| 拼接后投影 | Concat → W_O | 合并多头的不同子空间表示 |
十、总结与延伸
10.1 核心要点回顾
-
Self-Attention 通过直接建模任意位置间的依赖,解决了 RNN 的长距离依赖问题
-
Q/K/V 三元组让每个词既能"问问题"(Query)也能"回答问题"(Key/Value)
-
缩放因子 √d_k 是关键细节,防止大维度下的梯度消失
-
Multi-Head 扩展了模型的表示能力,不同头学习不同类型的关系
-
Mask 机制让 Transformer 能处理变长序列和控制信息流动
10.2 延伸阅读方向
| 方向 | 关键技术 | 适用场景 |
|---|---|---|
| 高效注意力 | Flash Attention, Sparse Attention | 长上下文 |
| 位置编码 | RoPE, ALiBi,绝对位置编码 | 位置感知 |
| 注意力变体 | Grouped Query Attention | 高效推理 |
| 跨模态 | Cross Attention | 多模态融合 |
参考资料
- Vaswani et al., “Attention Is All You Need”, NeurIPS 2017
- The Illustrated Transformer: http://nlp.seas.harvard.edu/2018/04/03/attention.html
- Lilian Weng, “Attention? Attention!”, Lil’Log
- PyTorch Transformer Documentation: https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html
📢 本专栏持续更新,下一篇:《Feed Forward Network:注意力之外的另一条腿》
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)