【学习记录】手撕 Transformer 核心组件:多头注意力 + 层归一化(PyTorch 逐行详解)

本文从零实现 Transformer 的两大核心模块:多头注意力(Multi-Head Attention)层归一化(Layer Normalization)。包含完整的 PyTorch 代码、形状变化图解、数学原理以及常见陷阱。理解这些代码,就掌握了现代大模型的底层基石。


📌 题目清单

模块 核心考点
多头注意力 缩放点积、因果掩码、多头拆分与合并
层归一化 特征维度标准化、可学习参数

一、多头注意力(Multi-Head Attention)

1.1 原理解析

多头注意力的核心公式:

Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

多头则是在 d_model 维度上切分成 n_head 个头,每个头独立计算注意力,最后拼接并通过线性变换输出。

作用:让模型从不同子空间同时关注信息,增强表达能力。

1.2 代码逐行详解

import torch
from torch import nn
import math

# 生成模拟输入:batch_size=16, 序列长度=64, 特征维度=512
X = torch.randn(16, 64, 512)
d_model = 512
n_head = 8
定义多头注意力类
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        
        # Q/K/V 线性映射层
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)   # 输出映射
        
        self.softmax = nn.Softmax(dim=-1)

参数说明

  • d_model:模型隐藏维度(512)
  • n_head:注意力头数(8)
  • d_head = d_model // n_head = 64,每个头的维度。
前向传播
    def forward(self, q, k, v, mask=None):
        B, T, D = q.shape          # B=16, T=64, D=512
        n_d = self.d_model // self.n_head   # 64

步骤 1:线性映射并拆分多头

        q = self.w_q(q)   # (B, T, 512)
        k = self.w_k(k)
        v = self.w_v(v)
        
        # view 拆分最后一维,transpose 交换序列维和头维
        q = q.view(B, T, self.n_head, n_d).transpose(1, 2)  # (B, n_head, T, n_d)
        k = k.view(B, T, self.n_head, n_d).transpose(1, 2)
        v = v.view(B, T, self.n_head, n_d).transpose(1, 2)

形状变化图解

输入: (B, T, 512)
view: (B, T, 8, 64)
transpose: (B, 8, T, 64)  ← 每个头独立处理序列

步骤 2:缩放点积注意力

        score = q @ k.transpose(-2, -1) / math.sqrt(n_d)   # (B, n_head, T, T)
  • q @ k.transpose(-2,-1):矩阵乘法 (B,8,T,64) @ (B,8,64,T)(B,8,T,T)
  • 除以 √64 = 8 防止 softmax 梯度饱和。

步骤 3:因果掩码(可选,用于自回归)

        # 生成下三角掩码(只允许看到当前及之前的位置)
        causal_mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
        # 将未来位置填充为极小值
        score = score.masked_fill(causal_mask == 0, -1e9)

步骤 4:Softmax 与加权求和

        score = self.softmax(score)      # 注意力权重,形状不变
        out = score @ v                  # (B, n_head, T, n_d)

步骤 5:合并多头并输出

        # 形状: (B, n_head, T, n_d) -> (B, T, n_head, n_d) -> (B, T, d_model)
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        out = self.w_o(out)              # (B, T, 512)
        return out

1.3 测试代码

attn = MultiHeadAttention(d_model, n_head)
Y = attn(X, X, X)        # 自注意力
print(Y.shape)           # torch.Size([16, 64, 512])

1.4 复杂度分析

操作 时间复杂度 空间复杂度
Q/K/V 线性映射 O(B × T × d_model²) O(B × T × d_model)
拆分多头 O(B × T × d_model) O(B × n_head × T × d_head)
注意力分数计算 O(B × n_head × T² × d_head) O(B × n_head × T²)
Softmax & 加权求和 O(B × n_head × T²) O(B × n_head × T²)
合并输出 O(B × T × d_model²) O(B × T × d_model)
总体 O(B × n_head × T² × d_head) O(B × n_head × T²)

其中 d_head = d_model / n_head,复杂度也可写作 O(B × T² × d_model)


二、层归一化(Layer Normalization)

2.1 原理解析

层归一化对每个样本的特征维度进行标准化:

LN(x)=γ⋅x−μσ2+ϵ+β \text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta LN(x)=γσ2+ϵ xμ+β

  • μσ² 沿特征维度(最后一维)计算。
  • γ(缩放)和 β(偏移)是可学习参数,恢复模型表达能力。

与 BatchNorm 的区别:BatchNorm 沿 batch 维标准化,依赖 batch size;LayerNorm 独立于 batch,更适合序列模型。

2.2 代码实现

class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta  = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        # 沿最后一维(特征维)计算均值和方差
        mean = x.mean(-1, keepdim=True)        # (B, T, 1)
        var  = x.var(-1, unbiased=False, keepdim=True)  # (B, T, 1)
        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out

2.3 测试代码

X = torch.randn(2, 5, 512)    # (batch, seq_len, d_model)
ln = LayerNorm(512)
Y = ln(X)
print(Y.shape)                 # torch.Size([2, 5, 512])
print(Y.mean(dim=-1))          # 接近 0
print(Y.std(dim=-1))           # 接近 1

2.4 复杂度分析

  • 时间复杂度:O(B × T × d_model),仅需一次遍历计算均值和方差。
  • 空间复杂度:O(B × T × d_model)(输出与输入同大小),外加两个可学习参数(O(d_model))。

三、完整代码整合(可直接运行)

import torch
from torch import nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, q, k, v):
        B, T, D = q.shape
        n_d = self.d_model // self.n_head
        
        q = self.w_q(q).view(B, T, self.n_head, n_d).transpose(1, 2)
        k = self.w_k(k).view(B, T, self.n_head, n_d).transpose(1, 2)
        v = self.w_v(v).view(B, T, self.n_head, n_d).transpose(1, 2)
        
        score = q @ k.transpose(-2, -1) / math.sqrt(n_d)
        # 因果掩码
        mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
        score = score.masked_fill(mask == 0, -1e9)
        score = self.softmax(score)
        out = score @ v
        
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.w_o(out)

class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps
        
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta

# 测试
if __name__ == "__main__":
    X = torch.randn(16, 64, 512)
    attn = MultiHeadAttention(512, 8)
    Y = attn(X, X, X)
    print("Attention output shape:", Y.shape)
    
    ln = LayerNorm(512)
    Z = ln(X)
    print("LayerNorm output shape:", Z.shape)
    print("mean near zero:", torch.allclose(Z.mean(dim=-1), torch.zeros_like(Z.mean(dim=-1)), atol=1e-5))

四、常见问题与改进建议

问题 说明 改进
因果掩码重复生成 每次 forward 都调用 torch.tril __init__ 中用 register_buffer 预存
未使用 Dropout 标准 Transformer 注意力后需要 dropout 添加 self.dropout = nn.Dropout(p)
未加残差连接 注意力输出通常与输入相加 在外层使用 x + attn(x)
未处理 padding mask 只做了因果掩码,没有 padding 掩码 接受 mask 参数并与因果掩码合并
使用 -1e9 而非 -inf 虽然有效,但不如 float('-inf') 严谨 可用 float('-inf')(注意 softmax 处理)

五、总结

本文手写实现了 Transformer 的两大核心组件:

模块 核心操作 输入/输出形状 复杂度
多头注意力 拆分头 → 缩放点积 → 掩码 → 加权求和 → 合并 (B,T,D) → (B,T,D) O(B·T²·D)
层归一化 特征维标准化 + 可学习仿射 (B,T,D) → (B,T,D) O(B·T·D)

通过逐行代码和形状变化图解,应该能完全理解这两个模块的底层实现。下一步可以尝试添加位置编码前馈网络(FFN)残差连接,构建完整的 Transformer 编码器。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐