推理服务突然报 CUDA out of memory,查日志发现是一条超长 system prompt 触发了 KV Cache 线性膨胀,把 A100 的 80GB HBM 直接吃满。问题不在显卡不够,而是我们没真正理解 Transformer 里哪些中间结果该存、哪些可以丢弃、哪些能换一种方式算。这篇文章就从这次事故的复盘开始,把 Transformer 架构里真正影响训练和推理的工程点拆开看一遍,重点放在决策依据、取舍逻辑和可执行的做法上。

线上 OOM 背后的真凶:注意力矩阵与 KV Cache

很多人以为推理时的显存大头是模型权重,其实在长序列场景下,KV Cache 是中低端显卡的隐形杀手。标准自注意力每层都要为每个 token 算出 Key 和 Value,并在解码时把所有历史 token 的 K、V 缓存下来。对于一个 L 层、h 个注意力头、每头维度 d 的 Decoder-only 模型,一个序列的 KV Cache 显存占用 = 2 × L × seq_len × h × d × 数据类型字节数。当 seq_len 从 4K 涨到 32K,这坨缓存就能把一张 3090 爆干净。

更隐蔽的是,在标准实现里,生成第 t 个 token 时还要计算整个 N×N 的注意力分数矩阵,中间产物巨大,但最终只用到最后一行。如果能理解这种计算和存储的浪费,就能顺藤摸瓜找到 FlashAttention、GQA、KV Cache 驱逐这些优化点。所以 OOM 不是灾难,是逼着我们认清架构的一次机会。

自注意力不是魔法:Q、K、V 在工程中的真实形态

Transformer 的核心就是让每个 token 学会“查字典”——Query(查询)去匹配 Key(键),加权聚合 Value(值)。这一步在数学上简洁,落到 GPU 上却全是坑。Q 和 K 的矩阵乘得到注意力分数 S=QK^T,在大维度下会出现梯度消失,所以必须除以 sqrt(d_k) 缩放,再做 softmax 才稳定。工程里我们看到的不是公式本身,而是torch.nn.functional.scaled_dot_product_attention 里的一系列参数:attn_mask、dropout_p、is_causal、scale 等。

踩过一个坑:某团队在微调时为了节约显存关闭了 dropout,结果预训练时 dropout 残差路径没对齐,导致推理质量骤降。自注意力的每个算子都不是孤岛,改一个参数必须前后链路对齐。另外,很多人的误区是把注意力输出直接当最终表示,忘了 Transformer 每个 block 里还有残差连接和 LayerNorm 在起作用。

位置编码怎么选?RoPE 成为 Decoder-only 默认的缘由

Transformer 本身没有序列顺序感,位置编码就是赋予 token 位置信息的机制。早年是正弦位置编码,后来出现了可学习的绝对位置编码,再到 AliBi,最后几乎所有的 Decoder-only 模型都选了旋转位置编码 RoPE(Rotary Position Embedding)。原因很明确:RoPE 通过复数旋转把位置信息注入 Q 和 K 的内积中,天然带有相对位置性质,且对长上下文扩展友好。Llama 系列、ChatGLM 系列都用它。

一个关键工程决策:RoPE 的基频 base 大小直接影响长程衰减。默认 base=10000 在 4K 上下文时够用,但要想外推到 32K 或更高,必须调整基频或配合 NTK-aware interpolation 缩放。硬扩窗口而不改基频,困惑度会飘得很离谱。所以我们做 long-context 微调时,必改的配置项就是 RoPE base 和 scale。

多头注意力不是越多越好:MHA、MQA 与 GQA 的显存账

Multi-Head Attention 本意是让不同注意力头关注不同子空间,但它也让 KV Cache 翻倍——每个头都有独立的 K、V。MHA 的参数量和计算量到千亿级别时,KV Cache 很可能吞掉 80% 以上的推理显存。于是 Google 先提出 Multi-Query Attention,让所有头共享一份 K、V,Cache 直接缩小 head_num 倍,代价是学习能力打折。

折中方案是 Grouped-Query Attention,将查询头分成 G 组,每组共享一组 K、V。Llama 2/3 70B 选择 G=8,不是拍脑袋,而是对应单机 8 卡推理时,每张卡刚好负责一组 G 的 K、V,最大化减少卡间通信。这告诉我们:架构设计必须对着部署拓扑来。现在选型时我会明确问一句:这个模型的 G 值能在现有 GPU 拓扑下实现纯数据并行吗?

残差与 LayerNorm:Pre-Norm 如何让千亿模型稳定训练

原始 Transformer 用 Post-LayerNorm,即先接子层(注意力或 FFN)再接残差最后做层归一化。这种顺序在训练极深网络时梯度容易爆炸或消失。Pre-LayerNorm 改为先做归一化再进子层然后残差相加,训练曲线平滑得多,千亿参数模型都靠它吃饭。现在所有主流 Decoder-only 模型清一色采用 Pre-Norm。

还有一个坑:LayerNorm 的计算对混合精度敏感。某些框架在 FP16 下用 Naive 实现会导致方差偏移,出现 NaN。解决方案是切到 PyTorch 的 fused LayerNorm 或者强制 FP32 计算归一化统计量。线上部署时我卡过一个诡异问题:推理精度设为 FP16,长序列末尾 token 的 logits 全是 inf,最后定位到 LayerNorm 的 epsilon 配置不兼容——两个微小工程参数的连锁反应。

Decoder-only 架构为何一统天下?因果遮罩的双面性

当下所有大模型几乎都是 Decoder-only,BLOOM 那种 Encoder-Decoder 结构反而少见。核心原因是因果遮罩(causal mask)既限制了每个 token 只能看到过去的 token,又恰好让自回归生成变得天然增量:生成第 t 个 token 时,历史 token 的 K、V 能复用,只用计算当前 token 的 Q、K、V,然后结合 Cache 完成注意力输出。没有这个下三角遮罩,Attention 会依赖未来的 token,增量生成就无从谈起。

另一个好处是预训练数据利用率:Decoder-only 直接训练因果语言模型,无需像 BERT 那样构造 MLM 任务,每个样本的每个 token 都参与损失计算,同等算力下信号更多。不过工程上要注意,因果遮罩并不是总能在前向函数里显现——PyTorch 的 scaled_dot_product_attention 在 is_causal=True 时会在内部自动生成 mask,外部看不到,调试时容易误以为没生效。

KV Cache 的生死线:预填充与解码阶段分离的工程实现

带 Cache 的推理必须拆成预填充(prefill)和解码两个阶段。预填充一次性处理整个 prompt,生成所有 token 的 K、V 并存入 Cache,然后输出最后一个 token 的 logits 作为第一个生成 token 的概率分布。解码阶段每次只输入当前生成的单个 token,复用 Cache 增量计算,大幅避免重复计算。

下面这段伪代码展示了这一分离逻辑,注意第二次调用时 past_key_values 的串联方式。

# 预填充:一次性编码 prompt
outputs = model(
    input_ids=prompt_ids, 
    use_cache=True, 
    past_key_values=None
)
logits = outputs.logits[0, -1, :]
next_token = torch.argmax(logits, dim=-1)
generated = [next_token.item()]
past_kv = outputs.past_key_values

# 解码:逐个 token 自回归
for _ in range(max_new_tokens - 1):
    outputs = model(
        input_ids=next_token.unsqueeze(0), 
        use_cache=True, 
        past_key_values=past_kv
    )
    logits = outputs.logits[0, -1, :]
    next_token = torch.argmax(logits, dim=-1)
    generated.append(next_token.item())
    past_kv = outputs.past_key_values
    if next_token.item() == eos_token_id: break

工程上必须注意:use_cache 默认 True,但如果在微调时关闭了,推理时要显式打开;同时要确保每一层的 past_key_values 是元组结构,动态扩充时避免产生碎片。vLLM 的 PagedAttention 用类似操作系统分页的方式管理这部分内存,显著减少了显存浪费。

FlashAttention 的分块策略:让 N×N 注意力矩阵从未存在

传统注意力会实例化一个 N×N 的分数矩阵,即便用了因果遮罩,HBM 读写始终是 O(N²) 级。FlashAttention 的突破在于分块:把 Q、K、V 切成长度为 B_q 和 B_k 的瓦片,全部在 SRAM 中计算局部注意力分数、使用在线 softmax 累计部分结果,计算完一个瓦片立即丢弃,最终输出写回 HBM。这样完整的 N×N 矩阵从未在任何地方存在过。

看一眼数据流向就清楚为什么能省显存:标准注意力数据流量 O(N²),FlashAttention 降到 O(N²/M),M 近似为块大小。K、V 的重复读取虽然增加,但顺序访问对 L2 缓存友好。不过要理解一个常见误解:FlashAttention 不能削减 HBM 中的 KV Cache 存储——Cache 仍然必须持所有 token 的 K、V,它不是计算中间产物。只是在计算注意力时,Cache 的瓦片被高效流式加载。

在 HuggingFace Transformers 里启用它只需一行配置:

model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-2-7b-hf',
    attn_implementation='flash_attention_2',
    torch_dtype=torch.float16
)

切换后务必检查 torch.backends.cuda.enable_flash_sdp 是否开启,以及当前 GPU 的算力是否满足要求。遇到 is_causal 未定义报错,多半是库版本不兼容。

长上下文压缩:从 LaCache 阶梯剪枝到 Gemini 1.5 的 L2 记忆

KV Cache 膨胀到百万 token 级别时,即使有 FlashAttention,显存也扛不住。必须对 Cache 进行压缩或驱逐。LaCache 的思路是用阶梯形剪枝:当 Cache 达到容量上限,对已经压缩过的历史 Cache 再次应用更大压缩比,而新 token 保留更多。这样越早的上下文被压缩得越狠,近期信息保留细节,并通过一种无注意力计算的驱逐策略避免依赖 S 矩阵,兼容 FlashAttention。

Google Gemini 1.5 走得更远,把记忆分成“工作记忆”和“长期记忆”。近期 1 万 token 保留高精度 KV,更早的 token 通过语义聚类压缩成记忆块,引入可训练的“记忆路由器”决定是否需要召回历史块。这种设计把存储与计算完全解耦,冷数据可以暂存到 CPU/NVMe,热数据留在 HBM 并用 FlashAttention-2 加速。

部署长上下文服务时,我给团队定过一条硬规矩:不做任何 Cache 压缩的模型不要直接开放 >16K 的窗口,哪怕官方说支持。因为不做驱逐策略的推理服务,等于把 OOM 的主动权交给了用户 prompt。

推理部署的取舍:PagedAttention 与显存管理策略

vLLM 的 PagedAttention 把 KV Cache 按固定大小的 Block 管理,不再需要预先分配一大块连续显存。这种虚拟化方式不仅减少碎片,还能把多个序列共享相同 prefix 的 Cache Block(例如 system prompt)复用,极大提高吞吐。我们内部压测显示,同硬件上 vLLM 的并发吞吐是普通 HF generate 的 10 倍以上。

部署中需要关注几个参数:gpu_memory_utilization 不宜拉满,留出 margin 给 CUDA context 和波动;max_num_seqs 控制同时处理的请求数,要根据平均序列长度动态调整;Cache Block 大小(通常 16 或 32)影响碎片率与交换效率。另一个容易忽视的点是,服务重启后 KV Cache 丢失,长对话状态无法恢复,必须把 Cache 序列化或借助持久化存储,才能支持断点续聊。

工程上还要考虑幂等和重试:客户端断开时,服务端应该检测并释放相应的 Cache Block,避免内存泄漏;同时推理接口最好支持 idempotency key,防止网络抖动导致重复生成浪费算力。

从模型设计到上线:三个关键的决策检查点

经历过数次从模型选型到上线翻车后,我整理了三个必须在立项初期就明确的检查点:

  • 上下文长度 vs 显存预算:用公式估算 KV Cache 上限,决定是否需要 GQA、是否必须上 FlashAttention、是否需要 Cache 压缩策略。
  • 部署拓扑 vs 注意力头分组:如果模型采用 GQA,G 值应当与单机卡数或张量并行度对齐,避免跨机通信开销吃掉加速收益。
  • 混合精度与归一化的一致性:训练时用的 LayerNorm/RMSNorm 精度、dropout 设置必须与推理完全一致,否则会出现静默的质量下降。

这三个点没有理顺,后面补锅的代价巨大。很多时候架构师纠结算法指标,却忽略了基础设施的约束。Transformer 虽然原理简单,但从预训练到推理,每一层都有显存、计算、通信的纠缠,能提前画清边界才是工程能力的分水岭。

参考资料

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐