手撕 Transformer 核心组件:多头注意力 + 层归一化(PyTorch 逐行详解)
【学习记录】手撕 Transformer 核心组件:多头注意力 + 层归一化(PyTorch 逐行详解)
本文从零实现 Transformer 的两大核心模块:多头注意力(Multi-Head Attention) 和 层归一化(Layer Normalization)。包含完整的 PyTorch 代码、形状变化图解、数学原理以及常见陷阱。理解这些代码,就掌握了现代大模型的底层基石。
📌 题目清单
| 模块 | 核心考点 |
|---|---|
| 多头注意力 | 缩放点积、因果掩码、多头拆分与合并 |
| 层归一化 | 特征维度标准化、可学习参数 |
一、多头注意力(Multi-Head Attention)
1.1 原理解析
多头注意力的核心公式:
Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
多头则是在 d_model 维度上切分成 n_head 个头,每个头独立计算注意力,最后拼接并通过线性变换输出。
作用:让模型从不同子空间同时关注信息,增强表达能力。
1.2 代码逐行详解
import torch
from torch import nn
import math
# 生成模拟输入:batch_size=16, 序列长度=64, 特征维度=512
X = torch.randn(16, 64, 512)
d_model = 512
n_head = 8
定义多头注意力类
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_head):
super().__init__()
self.n_head = n_head
self.d_model = d_model
# Q/K/V 线性映射层
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model) # 输出映射
self.softmax = nn.Softmax(dim=-1)
参数说明:
d_model:模型隐藏维度(512)n_head:注意力头数(8)d_head = d_model // n_head = 64,每个头的维度。
前向传播
def forward(self, q, k, v, mask=None):
B, T, D = q.shape # B=16, T=64, D=512
n_d = self.d_model // self.n_head # 64
步骤 1:线性映射并拆分多头
q = self.w_q(q) # (B, T, 512)
k = self.w_k(k)
v = self.w_v(v)
# view 拆分最后一维,transpose 交换序列维和头维
q = q.view(B, T, self.n_head, n_d).transpose(1, 2) # (B, n_head, T, n_d)
k = k.view(B, T, self.n_head, n_d).transpose(1, 2)
v = v.view(B, T, self.n_head, n_d).transpose(1, 2)
形状变化图解:
输入: (B, T, 512)
view: (B, T, 8, 64)
transpose: (B, 8, T, 64) ← 每个头独立处理序列
步骤 2:缩放点积注意力
score = q @ k.transpose(-2, -1) / math.sqrt(n_d) # (B, n_head, T, T)
q @ k.transpose(-2,-1):矩阵乘法(B,8,T,64) @ (B,8,64,T)→(B,8,T,T)- 除以
√64 = 8防止 softmax 梯度饱和。
步骤 3:因果掩码(可选,用于自回归)
# 生成下三角掩码(只允许看到当前及之前的位置)
causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
# 将未来位置填充为极小值
score = score.masked_fill(causal_mask == 0, -1e9)
步骤 4:Softmax 与加权求和
score = self.softmax(score) # 注意力权重,形状不变
out = score @ v # (B, n_head, T, n_d)
步骤 5:合并多头并输出
# 形状: (B, n_head, T, n_d) -> (B, T, n_head, n_d) -> (B, T, d_model)
out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
out = self.w_o(out) # (B, T, 512)
return out
1.3 测试代码
attn = MultiHeadAttention(d_model, n_head)
Y = attn(X, X, X) # 自注意力
print(Y.shape) # torch.Size([16, 64, 512])
1.4 复杂度分析
| 操作 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| Q/K/V 线性映射 | O(B × T × d_model²) | O(B × T × d_model) |
| 拆分多头 | O(B × T × d_model) | O(B × n_head × T × d_head) |
| 注意力分数计算 | O(B × n_head × T² × d_head) | O(B × n_head × T²) |
| Softmax & 加权求和 | O(B × n_head × T²) | O(B × n_head × T²) |
| 合并输出 | O(B × T × d_model²) | O(B × T × d_model) |
| 总体 | O(B × n_head × T² × d_head) | O(B × n_head × T²) |
其中
d_head = d_model / n_head,复杂度也可写作O(B × T² × d_model)。
二、层归一化(Layer Normalization)
2.1 原理解析
层归一化对每个样本的特征维度进行标准化:
LN(x)=γ⋅x−μσ2+ϵ+β \text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LN(x)=γ⋅σ2+ϵx−μ+β
μ和σ²沿特征维度(最后一维)计算。γ(缩放)和β(偏移)是可学习参数,恢复模型表达能力。
与 BatchNorm 的区别:BatchNorm 沿 batch 维标准化,依赖 batch size;LayerNorm 独立于 batch,更适合序列模型。
2.2 代码实现
class LayerNorm(nn.Module):
def __init__(self, d_model, eps=1e-12):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model))
self.beta = nn.Parameter(torch.zeros(d_model))
self.eps = eps
def forward(self, x):
# 沿最后一维(特征维)计算均值和方差
mean = x.mean(-1, keepdim=True) # (B, T, 1)
var = x.var(-1, unbiased=False, keepdim=True) # (B, T, 1)
out = (x - mean) / torch.sqrt(var + self.eps)
out = self.gamma * out + self.beta
return out
2.3 测试代码
X = torch.randn(2, 5, 512) # (batch, seq_len, d_model)
ln = LayerNorm(512)
Y = ln(X)
print(Y.shape) # torch.Size([2, 5, 512])
print(Y.mean(dim=-1)) # 接近 0
print(Y.std(dim=-1)) # 接近 1
2.4 复杂度分析
- 时间复杂度:O(B × T × d_model),仅需一次遍历计算均值和方差。
- 空间复杂度:O(B × T × d_model)(输出与输入同大小),外加两个可学习参数(O(d_model))。
三、完整代码整合(可直接运行)
import torch
from torch import nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_head):
super().__init__()
self.n_head = n_head
self.d_model = d_model
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
self.softmax = nn.Softmax(dim=-1)
def forward(self, q, k, v):
B, T, D = q.shape
n_d = self.d_model // self.n_head
q = self.w_q(q).view(B, T, self.n_head, n_d).transpose(1, 2)
k = self.w_k(k).view(B, T, self.n_head, n_d).transpose(1, 2)
v = self.w_v(v).view(B, T, self.n_head, n_d).transpose(1, 2)
score = q @ k.transpose(-2, -1) / math.sqrt(n_d)
# 因果掩码
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
score = score.masked_fill(mask == 0, -1e9)
score = self.softmax(score)
out = score @ v
out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
return self.w_o(out)
class LayerNorm(nn.Module):
def __init__(self, d_model, eps=1e-12):
super().__init__()
self.gamma = nn.Parameter(torch.ones(d_model))
self.beta = nn.Parameter(torch.zeros(d_model))
self.eps = eps
def forward(self, x):
mean = x.mean(-1, keepdim=True)
var = x.var(-1, unbiased=False, keepdim=True)
return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
# 测试
if __name__ == "__main__":
X = torch.randn(16, 64, 512)
attn = MultiHeadAttention(512, 8)
Y = attn(X, X, X)
print("Attention output shape:", Y.shape)
ln = LayerNorm(512)
Z = ln(X)
print("LayerNorm output shape:", Z.shape)
print("mean near zero:", torch.allclose(Z.mean(dim=-1), torch.zeros_like(Z.mean(dim=-1)), atol=1e-5))
四、常见问题与改进建议
| 问题 | 说明 | 改进 |
|---|---|---|
| 因果掩码重复生成 | 每次 forward 都调用 torch.tril |
在 __init__ 中用 register_buffer 预存 |
| 未使用 Dropout | 标准 Transformer 注意力后需要 dropout | 添加 self.dropout = nn.Dropout(p) |
| 未加残差连接 | 注意力输出通常与输入相加 | 在外层使用 x + attn(x) |
| 未处理 padding mask | 只做了因果掩码,没有 padding 掩码 | 接受 mask 参数并与因果掩码合并 |
使用 -1e9 而非 -inf |
虽然有效,但不如 float('-inf') 严谨 |
可用 float('-inf')(注意 softmax 处理) |
五、总结
本文手写实现了 Transformer 的两大核心组件:
| 模块 | 核心操作 | 输入/输出形状 | 复杂度 |
|---|---|---|---|
| 多头注意力 | 拆分头 → 缩放点积 → 掩码 → 加权求和 → 合并 | (B,T,D) → (B,T,D) | O(B·T²·D) |
| 层归一化 | 特征维标准化 + 可学习仿射 | (B,T,D) → (B,T,D) | O(B·T·D) |
通过逐行代码和形状变化图解,应该能完全理解这两个模块的底层实现。下一步可以尝试添加位置编码、前馈网络(FFN) 和 残差连接,构建完整的 Transformer 编码器。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)