环境声明

  • Python版本: Python 3.10+
  • PyTorch版本: PyTorch 2.0+
  • 开发工具: PyCharm / VS Code / Jupyter Notebook
  • 硬件要求: 建议配备NVIDIA GPU(显存8GB以上)用于训练
  • 操作系统: Windows / macOS / Linux(通用)

学习目标和摘要

学习目标:

  1. 深入理解Transformer的核心架构和工作原理
  2. 掌握自注意力机制、多头注意力的数学推导
  3. 理解BERT、GPT、T5三大模型的设计差异
  4. 学会使用PyTorch实现完整的Transformer模型
  5. 了解Transformer的高效变体和最新进展

摘要: 本章将全面解析Transformer架构,从2017年Google提出的原始论文出发,深入探讨编码器-解码器结构、自注意力机制、位置编码等核心组件。同时对比分析BERT、GPT、T5三大经典模型,并介绍Mamba、RetNet等2025年最新变体,最后提供完整的PyTorch实现代码。


1. Transformer整体架构

1.1 从RNN到Transformer的演进

在Transformer出现之前,序列建模主要依赖RNN及其变体(LSTM、GRU)。但RNN存在两个致命缺陷:

  1. 顺序计算瓶颈: 必须按时间步顺序处理,无法并行化
  2. 长距离依赖问题: 随着序列增长,早期信息难以传递到后期

Transformer通过自注意力机制彻底解决了这两个问题,实现了完全的并行计算和全局依赖建模。

1.2 编码器-解码器架构

Transformer采用经典的编码器-解码器(Encoder-Decoder)架构:

输入序列 → [编码器] → 上下文表示 → [解码器] → 输出序列

编码器(Encoder):

  • 由N个相同的编码器层堆叠而成(原论文中N=6)
  • 每层包含两个子层:多头自注意力 + 前馈神经网络
  • 每个子层后接残差连接和层归一化

解码器(Decoder):

  • 由N个相同的解码器层堆叠而成
  • 每层包含三个子层:掩码多头自注意力 + 编码器-解码器注意力 + 前馈神经网络
  • 同样使用残差连接和层归一化

1.3 自注意力机制的核心思想

自注意力(Self-Attention)允许序列中的每个位置都能"关注"到其他所有位置,从而捕捉全局依赖关系。

计算过程:

  1. 为每个输入向量生成Query、Key、Value三个向量
  2. 计算Query与所有Key的点积,得到注意力分数
  3. 对分数进行缩放和Softmax归一化
  4. 用归一化后的权重对Value进行加权求和

数学公式:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V

其中d_k是Key向量的维度,sqrt(d_k)用于缩放点积,防止Softmax进入梯度饱和区。


2. 位置编码详解

2.1 为什么需要位置编码

与RNN不同,Transformer没有循环结构,无法天然感知序列顺序。因此需要显式注入位置信息。

2.2 正弦位置编码(Sinusoidal)

原论文提出的位置编码使用不同频率的正弦和余弦函数:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

优点:

  • 可以处理任意长度的序列
  • 相对位置可以通过线性变换获得
  • 值域在[-1, 1]之间,与词嵌入兼容

2.3 可学习位置编码(Learnable)

BERT等模型采用可学习的位置嵌入:

position_embedding = nn.Embedding(max_seq_len, d_model)

优点: 模型可以自适应学习最优的位置表示
缺点: 受限于预设的最大序列长度

2.4 旋转位置编码(RoPE)

RoPE(Rotary Position Embedding)是2021年提出的相对位置编码方法,被广泛用于LLaMA等现代大模型。

核心思想:通过旋转矩阵将位置信息编码到Query和Key向量中:

q' = RoPE(q, m)
k' = RoPE(k, n)

其中m、n分别表示query和key的位置索引。

RoPE的优势:

  • 显式建模相对位置关系
  • 外推性好,可以处理比训练时更长的序列
  • 与Flash Attention等优化技术兼容

3. 多头注意力与前馈网络

3.1 多头注意力机制

单一的注意力机制可能只关注特定类型的依赖关系。多头注意力通过多组独立的Q、K、V投影,让模型同时关注不同子空间的信息。

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W^O

where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

原论文中:

  • 头数 h = 8
  • d_model = 512
  • 每个头的维度 d_k = d_v = 64

3.2 前馈神经网络(FFN)

每个编码器和解码器层都包含一个全连接前馈网络:

FFN(x) = max(0, xW_1 + b_1)W_2 + b_2

这是一个两层的MLP:

  • 第一层将维度从d_model扩展到4*d_model
  • 使用ReLU激活函数
  • 第二层映射回d_model维度

FFN的作用是为每个位置独立地添加非线性变换能力。

3.3 层归一化与残差连接

残差连接(Residual Connection):

output = LayerNorm(x + Sublayer(x))

残差连接解决了深层网络的梯度消失问题,使得Transformer可以堆叠很多层。

层归一化(Layer Normalization):
与批归一化不同,层归一化对每个样本的所有特征进行归一化:

LayerNorm(x) = (x - mean) / sqrt(var + eps) * gamma + beta

层归一化在序列模型中更稳定,因为序列长度可能变化。


4. BERT:双向编码器表示

4.1 BERT的设计哲学

BERT(Bidirectional Encoder Representations from Transformers)于2018年由Google提出,核心创新是深度双向预训练

与之前只能从左到右或从右到左的单向模型不同,BERT通过Masked Language Model(MLM)实现了真正的双向上下文建模。

4.2 预训练任务

任务一:Masked Language Model(MLM)

随机遮蔽输入序列中15%的词元,让模型预测被遮蔽的词:

  • 80%的概率用MASK替换
  • 10%的概率用随机词替换
  • 10%的概率保持不变

这种策略强制模型基于双向上下文进行预测。

任务二:Next Sentence Prediction(NSP)

输入两个句子,判断第二个句子是否是第一个句子的下一句:

  • 50%的概率是真实的下一句(IsNext)
  • 50%的概率是随机采样的句子(NotNext)

NSP帮助模型理解句子间的关系,对下游的问答、推理任务很重要。

4.3 BERT的输入表示

BERT的输入由三部分嵌入相加组成:

  1. 词嵌入(Token Embeddings): 词本身的向量表示
  2. 段嵌入(Segment Embeddings): 区分句子A和句子B
  3. 位置嵌入(Position Embeddings): 可学习的位置编码

特殊标记:

4.4 BERT的模型变体

模型 层数 隐藏维度 参数量 适用场景
BERT-Base 12 768 110M 通用任务
BERT-Large 24 1024 340M 高精度需求
DistilBERT 6 768 66M 资源受限环境
RoBERTa 12/24 768/1024 125M/355M 优化版BERT

5. GPT:生成式预训练

5.1 GPT的自回归特性

GPT(Generative Pre-training)采用自左向右的单向语言建模方式。与BERT的双向编码器不同,GPT只使用Transformer的解码器部分(且去除了编码器-解码器注意力)。

5.2 因果掩码(Causal Masking)

为了实现自回归生成,GPT使用三角掩码(Triangular Mask):

[1, 0, 0, 0]
[1, 1, 0, 0]
[1, 1, 1, 0]
[1, 1, 1, 1]

这确保模型在预测第i个词时,只能看到位置0到i-1的信息。

5.3 GPT的预训练与微调

预训练阶段:
使用标准的语言建模目标,最大化似然概率:

L = sum(log P(x_i | x_1, ..., x_{i-1}))

微调阶段:
根据下游任务调整输入格式:

  • 分类任务: [START] 文本 [EXTRACT]
  • 蕴含任务: 前提 [DELIM] 假设 [EXTRACT]
  • 相似度任务: 文本1 [DELIM] 文本2 [EXTRACT]
  • 问答任务: 文档 [DELIM] 问题 [EXTRACT] 答案

5.4 GPT系列演进

版本 年份 参数量 上下文长度 关键改进
GPT-1 2018 117M 512 无监督预训练
GPT-2 2019 1.5B 1024 更大规模数据
GPT-3 2020 175B 2048 上下文学习
GPT-4 2023 未公开 128K 多模态能力
GPT-4o 2024 未公开 128K 原生多模态

6. T5:文本到文本转换

6.1 T5的统一框架

T5(Text-to-Text Transfer Transformer)由Google于2019年提出,核心思想是将所有NLP任务统一为文本生成任务

无论是分类、翻译、摘要还是问答,T5都将其转换为:

输入: "translate English to German: That is good."
输出: "Das ist gut."

6.2 T5的架构特点

T5采用标准的Encoder-Decoder架构:

  • 编码器:双向注意力,处理输入文本
  • 解码器:因果注意力,自回归生成输出

与BERT和GPT的关键区别:

  • BERT: 仅编码器,双向
  • GPT: 仅解码器,单向
  • T5: 编码器-解码器,编码器双向、解码器单向

6.3 T5的预训练策略

T5采用Span Corruption作为预训练目标:

  1. 随机采样文本中的连续片段(span)
  2. 用唯一的哨兵标记(sentinel token)替换每个span
  3. 解码器需要还原被替换的span

例如:

输入: "Thank you [X] me to your party [Y] week."
目标: "[X] for inviting [Y] last"

相比BERT的MLM,Span Corruption更符合生成任务的特点。

6.4 T5的模型规模

模型 参数量 编码器层数 解码器层数 d_model
T5-Small 60M 6 6 512
T5-Base 220M 12 12 768
T5-Large 770M 24 24 1024
T5-3B 3B 24 24 1024
T5-11B 11B 24 24 1024

7. BERT、GPT、T5对比分析

特性 BERT GPT T5
架构类型 仅编码器 仅解码器 编码器-解码器
注意力方向 双向 单向(因果) 编码器双向、解码器单向
预训练任务 MLM + NSP 自回归语言建模 Span Corruption
核心优势 理解任务 生成任务 统一框架
典型应用 分类、NER、问答 文本生成、对话 翻译、摘要、通用NLP
参数量级 110M-340M 117M-175B+ 60M-11B
位置编码 可学习 可学习 相对位置偏置
代表模型 BERT、RoBERTa GPT-2/3/4 T5、FLAN-T5

一句话总结:

  • BERT擅长"理解",像一位细心的读者
  • GPT擅长"生成",像一位作家
  • T5擅长"转换",像一位翻译家

8. Transformer的复杂度分析

8.1 自注意力的计算复杂度

标准自注意力的计算复杂度为O(n^2 * d),其中:

  • n: 序列长度
  • d: 模型维度

具体分解:

  1. Q、K、V投影: O(3 * n * d^2)
  2. QK^T计算: O(n^2 * d)
  3. Softmax与加权: O(n^2 * d)
  4. 输出投影: O(n * d^2)

内存复杂度: O(n^2),主要来自注意力矩阵的存储。

8.2 O(n^2)问题的影响

当序列长度n增加时:

  • n=512: 注意力计算占总计算量的约20%
  • n=2048: 注意力计算占总计算量的约50%
  • n=8192: 注意力计算成为主要瓶颈

这使得Transformer难以处理长文档、基因组序列、高分辨率图像等长序列数据。

8.3 复杂度优化方向

  1. 稀疏注意力: 只计算部分位置的注意力(如Longformer、BigBird)
  2. 低秩近似: 用低秩矩阵近似注意力矩阵(如Linformer)
  3. 核方法: 用核技巧降低复杂度(如Performer)
  4. 线性注意力: 将复杂度降至O(n)(如Linear Transformer、RWKV)
  5. 状态空间模型: 完全替代注意力机制(如Mamba)

9. Transformer高效变体

9.1 Linformer

Linformer通过低秩近似将复杂度从O(n^2)降至O(n):

核心思想:注意力矩阵是低秩的,可以用k个投影向量近似(k << n)。

Attention(Q, K, V) ≈ softmax(Q * (E * K)^T / sqrt(d_k)) * (F * V)

其中E、F是可学习的投影矩阵,将n维投影到k维。

复杂度: O(n * k * d),当k为常数时,复杂度为O(n)。

9.2 Performer

Performer使用**FAVOR+(Fast Attention Via Orthogonal Random Features)**算法:

通过随机特征映射将Softmax核近似为显式特征映射的内积:

exp(QK^T / sqrt(d)) ≈ phi(Q) * phi(K)^T

其中phi是随机特征映射函数。

优势:

  • 理论保证的近似精度
  • 无需存储O(n^2)的注意力矩阵
  • 支持极长序列(如100万token)

9.3 Mamba(2024-2025)

Mamba是基于**状态空间模型(State Space Model, SSM)**的全新架构,由Albert Gu和Tri Dao于2023年底提出,2024-2025年迅速成为研究热点。

核心创新:

  1. 选择性状态空间(Selective SSM):
    传统SSM对所有输入一视同仁,Mamba引入选择机制,让模型可以动态关注或忽略输入信息:

    h_t = A * h_{t-1} + B * x_t  (状态更新)
    y_t = C * h_t                (输出)
    

    其中B、C由输入x动态生成。

  2. 硬件感知算法:
    使用Flash Attention类似的内存优化技术,实现了与Transformer相当的训练速度。

  3. 线性复杂度:
    训练和推理复杂度均为O(n),推理速度比Transformer快5倍。

Mamba的优势:

  • 长序列建模能力强(测试过百万级token)
  • 推理速度快,吞吐量高
  • 内存占用低

Mamba的局限:

  • 对某些需要全局交互的任务可能不如Transformer
  • 2025年CVPR上Google DeepMind提出混合架构,结合Transformer和Mamba的优点

9.4 RetNet

RetNet(Retentive Network)由微软研究院于2023年提出,旨在同时实现训练并行化和推理低成本。

核心机制:Retention机制

RetNet将注意力分解为两个正交维度:

  1. 循环(Recurrent)表示: 用于高效推理
  2. 并行(Parallel)表示: 用于训练并行化

数学形式:

Retention(X) = (Q * K^T * D) * V

其中D是因果衰减矩阵,包含位置相关的衰减因子。

RetNet的特点:

  • 训练时并行计算(类似Transformer)
  • 推理时循环计算(类似RNN),复杂度O(1)
  • 性能与Transformer相当

9.5 其他重要变体

模型 核心思想 复杂度 适用场景
Longformer 局部+全局注意力 O(n) 长文档处理
BigBird 随机+窗口+全局注意力 O(n) 长序列建模
Reformer LSH局部敏感哈希 O(n log n) 内存受限环境
RWKV RNN与Transformer结合 O(n) 长文本生成
Flash Attention IO感知的精确注意力 O(n^2)但更快 通用加速

10. Transformer完整PyTorch实现

以下是完整的Transformer实现,包含多头注意力、位置编码、编码器、解码器等所有组件:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """多头自注意力机制"""
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model必须能被num_heads整除"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Q、K、V的线性投影
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """缩放点积注意力"""
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        return output, attn_weights
    
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # 线性投影并分头
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # 计算注意力
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        
        # 拼接多头结果
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        
        return self.W_o(attn_output)


class PositionwiseFeedForward(nn.Module):
    """位置前馈网络"""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))


class PositionalEncoding(nn.Module):
    """正弦位置编码"""
    def __init__(self, d_model, max_seq_length=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        # 创建位置编码矩阵
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-math.log(10000.0) / d_model)
        )
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class EncoderLayer(nn.Module):
    """Transformer编码器层"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # 自注意力子层
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # 前馈子层
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        
        return x


class DecoderLayer(nn.Module):
    """Transformer解码器层"""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # 掩码自注意力
        attn_output = self.masked_self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # 编码器-解码器注意力
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        
        # 前馈网络
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x


class TransformerEncoder(nn.Module):
    """Transformer编码器"""
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, 
                 max_seq_length=5000, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length, dropout)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)
        
    def forward(self, x, mask=None):
        x = self.embedding(x) * self.scale
        x = self.pos_encoding(x)
        
        for layer in self.layers:
            x = layer(x, mask)
            
        return x


class TransformerDecoder(nn.Module):
    """Transformer解码器"""
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff,
                 max_seq_length=5000, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length, dropout)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(d_model)
        
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        x = self.embedding(x) * self.scale
        x = self.pos_encoding(x)
        
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
            
        return x


class Transformer(nn.Module):
    """完整的Transformer模型"""
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, 
                 num_heads=8, num_encoder_layers=6, num_decoder_layers=6,
                 d_ff=2048, max_seq_length=5000, dropout=0.1):
        super().__init__()
        
        self.encoder = TransformerEncoder(
            src_vocab_size, d_model, num_heads, num_encoder_layers,
            d_ff, max_seq_length, dropout
        )
        
        self.decoder = TransformerDecoder(
            tgt_vocab_size, d_model, num_heads, num_decoder_layers,
            d_ff, max_seq_length, dropout
        )
        
        self.output_layer = nn.Linear(d_model, tgt_vocab_size)
        
        # 参数初始化
        self._init_parameters()
        
    def _init_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def make_src_mask(self, src):
        """创建源序列掩码(用于padding)"""
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        return src_mask
    
    def make_tgt_mask(self, tgt):
        """创建目标序列掩码(因果掩码 + padding掩码)"""
        tgt_pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        tgt_len = tgt.size(1)
        tgt_sub_mask = torch.tril(
            torch.ones(tgt_len, tgt_len, device=tgt.device)
        ).bool()
        tgt_mask = tgt_pad_mask & tgt_sub_mask
        return tgt_mask
    
    def forward(self, src, tgt):
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        
        enc_output = self.encoder(src, src_mask)
        dec_output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        output = self.output_layer(dec_output)
        
        return output


# 使用示例
if __name__ == "__main__":
    # 模型参数
    SRC_VOCAB_SIZE = 10000
    TGT_VOCAB_SIZE = 10000
    D_MODEL = 512
    NUM_HEADS = 8
    NUM_ENCODER_LAYERS = 6
    NUM_DECODER_LAYERS = 6
    D_FF = 2048
    MAX_SEQ_LENGTH = 100
    DROPOUT = 0.1
    BATCH_SIZE = 32
    SEQ_LENGTH = 20
    
    # 创建模型
    model = Transformer(
        src_vocab_size=SRC_VOCAB_SIZE,
        tgt_vocab_size=TGT_VOCAB_SIZE,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        num_encoder_layers=NUM_ENCODER_LAYERS,
        num_decoder_layers=NUM_DECODER_LAYERS,
        d_ff=D_FF,
        max_seq_length=MAX_SEQ_LENGTH,
        dropout=DROPOUT
    )
    
    # 模拟输入数据
    src = torch.randint(1, SRC_VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))
    tgt = torch.randint(1, TGT_VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH))
    
    # 前向传播
    output = model(src, tgt)
    
    print(f"输入形状: {src.shape}")
    print(f"输出形状: {output.shape}")
    print(f"模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

11. 避坑小贴士

11.1 注意力掩码的常见错误

错误1:混淆padding掩码和因果掩码

  • padding掩码用于忽略填充位置(值为0的位置)
  • 因果掩码用于防止看到未来信息(上三角矩阵)
  • 解码器中需要同时使用两种掩码

错误2:掩码值设置不当

# 错误:使用0作为掩码值
scores = scores.masked_fill(mask == 0, 0)

# 正确:使用极大负数
scores = scores.masked_fill(mask == 0, -1e9)

11.2 位置编码的实现陷阱

错误:位置编码没有注册为buffer

# 错误:位置编码会被当作模型参数保存和更新
self.pe = pe

# 正确:注册为buffer,不会参与梯度计算
self.register_buffer('pe', pe)

11.3 维度不匹配问题

多头注意力中常见的维度错误:

# 错误:没有正确分头和拼接
Q = self.W_q(query)  # [batch, seq, d_model]
Q = Q.view(batch_size, self.num_heads, -1, self.d_k)  # 维度错误

# 正确:先view再transpose
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k)
Q = Q.transpose(1, 2)  # [batch, num_heads, seq, d_k]

11.4 梯度消失与爆炸

建议1: 使用Xavier/Glorot初始化

for p in self.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

建议2: 使用学习率预热(Warmup)

# 前几个epoch线性增加学习率,之后按某种策略衰减
warmup_steps = 4000
lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))

11.5 内存优化建议

  1. 使用梯度累积: 当显存不足时,可以累积多个小批量的梯度再更新
  2. 混合精度训练: 使用torch.cuda.amp可以节省约50%显存
  3. 梯度检查点: 用计算换内存,只保存关键层的激活值
  4. Flash Attention: 使用优化的注意力实现,大幅减少显存占用

12. 本章小结和知识点回顾

核心知识点

  1. Transformer架构: 编码器-解码器结构,完全基于注意力机制,摒弃了RNN的循环结构

  2. 自注意力机制: 通过Query、Key、Value计算序列中各位置的相互关系,实现全局依赖建模

  3. 多头注意力: 多组独立的注意力并行计算,捕捉不同类型的依赖关系

  4. 位置编码: 为模型提供序列顺序信息,包括正弦编码、可学习编码、RoPE等

  5. BERT: 双向编码器,通过MLM和NSP预训练,擅长理解任务

  6. GPT: 单向解码器,通过自回归语言建模预训练,擅长生成任务

  7. T5: 编码器-解码器架构,将所有NLP任务统一为文本到文本的转换

  8. 复杂度问题: 标准Transformer的O(n^2)复杂度是处理长序列的主要瓶颈

  9. 高效变体: Linformer、Performer、Mamba、RetNet等通过不同策略降低复杂度

一句话总结

Transformer通过自注意力机制实现了序列建模的并行化和全局依赖捕捉,BERT、GPT、T5三大模型分别代表了理解、生成、转换三种范式,而Mamba等新型架构正在挑战Transformer的统治地位。

学习建议

  1. 动手实现一遍完整的Transformer,加深理解
  2. 使用Hugging Face Transformers库,体验预训练模型的威力
  3. 关注Mamba等新兴架构的发展动态
  4. 尝试将Transformer应用到自己的项目中
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐