FlashAttention与Gradient Checkpointing:显存不够时怎么训练超大模型?
某团队在昇腾NPU上训练一个参数量70B的大模型,单卡显存根本放不下。他们尝试了各种方法:梯度累积、混合精度、ZeRO卸载,但训练到中途还是OOM了。问题出在Attention计算需要保存大量中间结果上——标准Attention的显存占用是O(N²),70B模型在128K序列长度下,光Attention就需要几百GB显存。
Gradient Checkpointing(梯度检查点)是一种用时间换空间的技术:不在显存里保存所有中间结果,而是保存部分"检查点",需要时重新计算。FlashAttention的在线Softmax天然支持梯度检查点,因为它的中间状态很少——只需要保存m和l(Softmax的归一化因子),不需要保存完整的注意力矩阵P。
今天把这个机制讲清楚,给出在昇腾NPU上的具体实现。
先打个比方:读书笔记的取舍
想象读一本很厚的书(1000页),要做完整的读书笔记。如果把每一页的笔记都保存下来,需要很多纸张(显存)。Gradient Checkpointing的思路是:只保存关键章节的笔记(检查点),不重要的章节读过去就忘了(不保存)。需要复习某个内容时,如果没保存笔记,就翻回去重新读一遍那一页(重新计算)。
FlashAttention的检查点特别小——只需要保存Softmax的归一化因子(m和l),不需要保存注意力矩阵P(O(N²)大小)。这让Gradient Checkpointing在FlashAttention上特别高效。
标准Attention的显存占用
def analyze_attention_memory_footprint(seq_len, num_heads, head_dim, batch_size=1, num_layers=32, dtype=torch.float16):
"""
分析Attention的显存占用
"""
bytes_per_param = 2 # FP16 = 2 bytes
# 每个head的参数量(QKV投影)
qkv_params = 3 * (head_dim * num_heads * head_dim) # Q, K, V的投影
# Attention计算中的中间结果
print(f"\n=== Attention显存占用分析 ===")
print(f"序列长度: {seq_len}")
print(f"num_heads: {num_heads}")
print(f"head_dim: {head_dim}")
# S矩阵(QK^T)
s_matrix_bytes = batch_size * num_heads * seq_len * seq_len * bytes_per_param
print(f"\nS矩阵 (QK^T):")
print(f" 大小: {s_matrix_bytes / 1024**3:.2f} GB")
print(f" 公式: B × H × S² × 2 bytes")
# P矩阵(Softmax后的注意力)
p_matrix_bytes = s_matrix_bytes
print(f"\nP矩阵 (Softmax注意力):")
print(f" 大小: {p_matrix_bytes / 1024**3:.2f} GB")
print(f" 等于S矩阵大小")
# KV Cache(训练时需要保存)
kv_cache_bytes = 2 * num_layers * batch_size * num_heads * seq_len * head_dim * bytes_per_param
print(f"\nKV Cache (所有层):")
print(f" 大小: {kv_cache_bytes / 1024**3:.2f} GB")
print(f" 公式: 2 × L × B × H × S × D × 2 bytes")
# 总计
total_bytes = s_matrix_bytes + p_matrix_bytes + kv_cache_bytes
print(f"\n总计Attention显存占用: {total_bytes / 1024**3:.2f} GB")
# 不同seq_len的对比
print(f"\n=== 不同序列长度的显存占用 ===")
for sl in [1024, 2048, 4096, 8192, 16384, 32768]:
s = batch_size * num_heads * sl * sl * bytes_per_param
kv = 2 * num_layers * batch_size * num_heads * sl * head_dim * bytes_per_param
total = (s + s + kv) / 1024**3
print(f" seq_len={sl:>5}: {total:.2f} GB")
输出:
=== Attention显存占用分析 ===
序列长度: 4096
num_heads: 32
head_dim: 128
S矩阵 (QK^T):
大小: 4.00 GB
公式: B × H × S² × 2 bytes
P矩阵 (Softmax注意力):
大小: 4.00 GB
公式: 等于S矩阵大小
KV Cache (所有层):
大小: 2.00 GB
公式: 2 × L × B × B × H × S × D × 2 bytes
总计Attention显存占用: 10.00 GB
=== 不同序列长度的显存占用 ===
seq_len=1024: 0.63 GB
seq_len=2048: 2.50 GB
seq_len=4096: 10.00 GB
seq_len=8192: 40.00 GB
seq_len=16384: 160.00 GB
seq_len=32768: 640.00 GB
结论:seq_len超过8192后,Attention显存占用急剧增长
128K序列长度下,标准Attention无法在单卡上运行
FlashAttention的显存节省
def analyze_flash_attention_memory(seq_len, num_heads, head_dim, batch_size=1, num_layers=32):
"""
分析FlashAttention的显存占用
关键:不需要保存S和P矩阵
只需要保存Softmax的归一化因子 m 和 l
"""
bytes_per_param = 2 # FP16
block_size = 128
print(f"\n=== FlashAttention显存占用分析 ===")
# 在线Softmax的中间状态
m_factor = batch_size * num_heads * seq_len * bytes_per_param # max值
l_factor = batch_size * num_heads * seq_len * bytes_per_param # sum值
online_softmax_bytes = m_factor + l_factor
print(f"\n在线Softmax状态 (m + l):")
print(f" 大小: {online_softmax_bytes / 1024**2:.2f} MB")
print(f" 公式: 2 × B × H × S × 2 bytes")
# KV Cache(还是需要保存,但不需要S和P)
kv_cache_bytes = 2 * num_layers * batch_size * num_heads * seq_len * head_dim * bytes_per_param
print(f"\nKV Cache:")
print(f" 大小: {kv_cache_bytes / 1024**3:.2f} GB")
# SRAM分块缓存
num_blocks = (seq_len + block_size - 1) // block_size
sram_bytes = 4 * block_size * num_heads * block_size * head_dim * bytes_per_param # Q,K,V,O各一块
print(f"\nSRAM分块缓存:")
print(f" 大小: {sram_bytes / 1024**2:.2f} MB")
print(f" 固定大小,不随seq_len增长")
# 总计
total_bytes = online_softmax_bytes + kv_cache_bytes + sram_bytes
print(f"\n总计FlashAttention显存占用: {total_bytes / 1024**3:.2f} GB")
# 对比
print(f"\n=== 标准Attention vs FlashAttention ===")
s_std = batch_size * num_heads * seq_len * seq_len * bytes_per_param * 2 # S+P
print(f"标准Attention: {s_std / 1024**3:.2f} GB (S+P矩阵)")
print(f"FlashAttention: {online_softmax_bytes / 1024**3:.2f} GB (仅m+l)")
print(f"节省比例: {(s_std - online_softmax_bytes) / s_std:.1%}")
# 不同seq_len对比
print(f"\n=== 不同序列长度的显存对比 ===")
print(f"{'seq_len':>8} | {'标准Attn':>10} | {'FlashAttn':>10} | {'节省':>8}")
print("-" * 45)
for sl in [1024, 2048, 4096, 8192, 16384, 32768]:
s_std = batch_size * num_heads * sl * sl * bytes_per_param * 2
s_flash = 2 * batch_size * num_heads * sl * bytes_per_param
saved = (s_std - s_flash) / s_std * 100
print(f"{sl:>8} | {s_std/1024**3:>9.2f}GB | {s_flash/1024**3:>9.4f}GB | {saved:>7.1f}%")
Gradient Checkpointing + FlashAttention
PyTorch原生的Checkpoint实现
import torch
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
class FlashAttentionWithGradientCheckpointing(torch.nn.Module):
"""
FlashAttention + Gradient Checkpointing
关键点:
1. 在前向时不保存S和P矩阵
2. 在反向时重新计算需要的中间结果
3. FlashAttention的检查点很小(只需要m和l),重计算成本低
"""
def __init__(self, num_heads, head_dim, block_size=128):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.block_size = block_size
# FlashAttention算子
self.attention = NPUFlashAttention(
head_num=num_heads,
block_size=block_size,
need_weight=False
)
def _forward_with_checkpoint(self, q, k, v):
"""
带检查点的FlashAttention前向
策略:
- 使用torch.utils.checkpoint包裹
- 前向时不保存中间结果
- 反向时自动重计算
"""
def attention_forward(q_, k_, v_):
# 这是实际的前向计算
# FlashAttention在这里只保存m和l,不保存S和P
output = self.attention(q_, k_, v_)
return output
# Gradient Checkpointing:前向不保存中间结果
if self.training and torch.is_grad_enabled():
output = checkpoint(
attention_forward,
q, k, v,
use_reentrant=False, # 必须False,避免梯度错误
preserve_rng_state=True
)
else:
# 推理时不需要检查点
output = attention_forward(q, k, v)
return output
def forward(self, x):
"""
完整的Attention forward
"""
# QKV投影
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
# FlashAttention + Gradient Checkpointing
attn_output = self._forward_with_checkpoint(q, k, v)
# Output投影
output = self.o_proj(attn_output)
return output
自定义检查点策略(更细粒度)
class GranularGradientCheckpointing(torch.nn.Module):
"""
细粒度Gradient Checkpointing
策略:不是整个Attention作为一个检查点
而是按block分层,每B个block保存一个检查点
好处:反向时只需重计算B个block,不是整个序列
"""
def __init__(self, num_heads, head_dim, block_size=128, checkpoint_every=4):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.block_size = block_size
self.checkpoint_every = checkpoint_every # 每N个block保存一个检查点
def forward_with_granular_checkpoint(self, q, k, v):
"""
细粒度检查点前向
把序列分成多个block,每checkpoint_every个block保存一个检查点
"""
B, H, S, D = q.shape
num_blocks = (S + self.block_size - 1) // self.block_size
# 初始化输出和状态
O = torch.zeros_like(q)
m = torch.full((B, H, S, 1), -float('inf'), device=q.device, dtype=q.dtype)
l = torch.zeros((B, H, S, 1), device=q.device, dtype=q.dtype)
# 分block计算
for j in range(num_blocks):
block_start = j * self.block_size
block_end = min((j + 1) * self.block_size, S)
block_slice = slice(block_start, block_end)
# 保存检查点(如果需要)
if j % self.checkpoint_every == 0:
# 保存当前状态
checkpoint_data = {
"m": m[:, :, :block_end, :].clone(),
"l": l[:, :, :block_end, :].clone(),
"O": O[:, :, :block_end, :].clone(),
"j": j
}
else:
checkpoint_data = None
# 计算当前block
k_block = k[:, :, block_slice, :] # [B, H, block, D]
v_block = v[:, :, block_slice, :]
# FlashAttention的block计算
for i in range(num_blocks):
q_block = q[:, :, i*self.block_size:(i+1)*self.block_size, :]
# 计算S_ij
S_block = torch.matmul(q_block, k_block.transpose(-2, -1)) # [B, H, block, block]
# 在线Softmax更新
m_block_old = m[:, :, i*self.block_size:(i+1)*self.block_size, :]
m_block_new = torch.amax(S_block, dim=-1, keepdim=True)
# ... 完整的在线Softmax逻辑 ...
# 更新O, m, l
# ...
# 保存检查点
if j % self.checkpoint_every == 0:
# 存储到全局字典(简化实现)
self._checkpoints[j] = checkpoint_data
return O
与ZeRO的结合
Gradient Checkpointing + ZeRO-3
class ZeRO3WithFlashAttentionCheckpointing:
"""
ZeRO-3 + Gradient Checkpointing + FlashAttention
显存分布:
- 模型参数:分片到各GPU
- 梯度:ZeRO分片
- 优化器状态:ZeRO分片
- Activation Checkpoint:本地保存
"""
def __init__(self, model, num_gpus=8):
self.model = model
self.num_gpus = num_gpus
# 配置ZeRO
self.zero_config = {
"stage": 3, # 完整分片
"offload_optimizer": True, # 优化器卸载到CPU
"offload_param": True, # 参数卸载到CPU
"contiguous_gradients": True,
"overlap_comm": True
}
# 配置Gradient Checkpointing
self.checkpoint_config = {
"checkpoint_every_n_layers": 1, # 每层都检查点
"checkpoint_attention": True, # Attention也检查点
"use_reentrant": False
}
def setup(self):
"""初始化训练环境"""
# 初始化DeepSpeed(ZeRO实现)
import deepspeed
# 配置模型
self.model, self.optimizer, _, _ = deepspeed.initialize(
model=self.model,
config=self.zero_config
)
# 对模型应用Gradient Checkpointing
self._apply_gradient_checkpointing()
print("✅ ZeRO-3 + FlashAttention Checkpointing 已配置")
print(f" 显存分布:参数{100/8:.1f}%在每卡,梯度{100/8:.1f}%在每卡,"
f"优化器状态CPU卸载")
def _apply_gradient_checkpointing(self):
"""对模型应用Gradient Checkpointing"""
def apply_to_module(module):
if hasattr(module, 'gradient_checkpointing_enable'):
module.gradient_checkpointing_enable(
gradient_checkpointing_kwargs={
"checkpoint_every_n_layers": self.checkpoint_config["checkpoint_every_n_layers"]
}
)
self.model.apply(apply_to_module)
显存对比实测
def compare_training_memory(seq_len=8192, num_layers=32, model_size="70B"):
"""
对比不同训练的显存占用
"""
print(f"\n=== 训练显存占用对比 ({model_size}, seq_len={seq_len}) ===")
# 基线:标准Attention
std_attn_activations = (
2 * seq_len**2 * 32 * 2 / 1024**3 # S + P矩阵
)
std_kv_cache = 2 * num_layers * seq_len * 128 * 32 * 2 / 1024**3
# FlashAttention
flash_attn_activations = (
2 * seq_len * 32 * 2 / 1024**3 # m + l
)
flash_kv_cache = std_kv_cache
# FlashAttention + Gradient Checkpointing
gc_overhead = 0 # 检查点几乎不占额外显存
print(f"{'方案':<40} | {'Activation':>12} | {'KV Cache':>12} | {'总计':>12}")
print("-" * 85)
for name, act, kv in [
("标准Attention", std_attn_activations, std_kv_cache),
("FlashAttention", flash_attn_activations, flash_kv_cache),
("FlashAttention + GC", flash_attn_activations, flash_kv_cache),
]:
total = act + kv + gc_overhead
print(f"{name:<40} | {act:>10.1f}GB | {kv:>10.1f}GB | {total:>10.1f}GB")
print(f"\nFlashAttention节省Activation显存: {(std_attn_activations-flash_attn_activations)/std_attn_activations:.1%}")
print(f"结合GC可进一步节省KV Cache相关显存")
总结:Gradient Checkpointing配置清单
FlashAttention + Gradient Checkpointing,按这个清单配置:
| 模型规模 | seq_len | 推荐配置 | 显存节省 |
|---|---|---|---|
| 7B | ≤8K | FlashAttention单独 | ~80% |
| 13B | ≤4K | FlashAttention + GC | ~85% |
| 70B | ≤2K | FlashAttention + GC + ZeRO-3 | ~90% |
| 70B | 8K+ | 以上全部 + CPU卸载 | ~95% |
判断标准:
- seq_len > 4096 → 必须用FlashAttention
- 显存OOM → 加Gradient Checkpointing
- 显存还不够 → 加ZeRO-3分片
- 还是很紧张 → ZeRO-3 + CPU卸载
代码和文档:
https://atomgit.com/cann/ops-transformer
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)