FlashAttention:让大模型训练快三倍的“拼菜师傅“
和一个做推荐系统的朋友吃饭,他问我:“我训练千问模型,Attention层特别慢,听说FlashAttention能加速,但我不懂CUDA,这玩意儿到底是怎么快的?”
我想了一下,跟他说:“你把大模型训练想象成一个超大的餐厅厨房。每次做一道菜(处理一个batch),厨师(GPU/NPU)要做三件事:切菜(QK^T矩阵乘)、调味(Softmax)、翻炒(乘V)。传统做法是切完菜放到盘子里(写HBM),再从盘子拿起来调味,调完味又放盘子,再拿起来翻炒——来来回回跑好多趟。”
“FlashAttention是什么?它是一个拼菜师傅,把切菜、调味、翻炒三步合并成一步,在灶台上直接完成,中间不用来回跑厨房和餐厅。”
朋友眼睛亮了:“所以快的原因是不用来回跑?”
“对。专业术语叫IO-aware——不是算力不够,是搬运数据太费时间。”
传统 Attention 的"来回跑"问题
要理解 FlashAttention,先得知道传统 Attention 是怎么工作的。
假设你有一个句子,128个token,每个token用512维向量表示。Attention 要计算每个token和所有其他token的关系,得到一个128×128的注意力矩阵。
传统实现分三步:
# 传统 Attention 实现(简化版)
import torch
def traditional_attention(Q, K, V):
# 第一步:计算 QK^T,得到注意力得分矩阵
# 大小:batch × heads × seq_len × seq_len
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# ⚠️ 这里 scores 要写回 HBM(显存),占用 seq_len × seq_len 空间
# 第二步:Softmax 归一化
attn_weights = torch.softmax(scores, dim=-1)
# ⚠️ 这里要读 scores(从 HBM 读),再写 attn_weights(写回 HBM)
# 第三步:乘 V,得到输出
output = torch.matmul(attn_weights, V)
# ⚠️ 这里要读 attn_weights(从 HBM 读)
return output
# 问题:三步都有 HBM 读写,来回搬运数据占用了 60% 以上的时间
# 算力(矩阵乘)只占了不到 40%
这三步,每一步都要把中间结果写到 HBM(High Bandwidth Memory,显存),下一步再读出来。就像那个餐厅比喻——切完菜放盘子,再从盘子拿起来调味。
当 seq_len 是 4096 的时候,那个注意力矩阵的大小是 4096×4096×2 bytes(float16)= 32MB。看着不大,但这是每个头、每个 batch 都要存的。32 heads × 4 batch = 128 份,总共 4GB——就存个中间结果。
FlashAttention 的"灶台合并"策略
FlashAttention 的核心思路:别把中间结果写回 HBM,在灶台上直接搞定。
具体做法是把 K 和 V 按小块(tile)读入 UB(Unified Buffer,昇腾NPU 上的高速片上存储),在 UB 里完成一个 tile 的 QK^T → Softmax → 乘 V 完整计算,然后把结果累积到输出里。
# FlashAttention 的"灶台合并"思路(伪代码)
def flash_attention_npu(Q, K, V, tile_size=128):
# Q: (batch, heads, seq_len, dim)
# K, V: (batch, heads, seq_len, dim)
output = torch.zeros_like(Q)
# 把 K 和 V 按 tile 分块
# 每次只把一块 K_tile 和 V_tile 读到 UB 上
for i in range(0, seq_len, tile_size):
K_tile = K[:, :, i:i+tile_size, :] # 从 HBM 读一小块 K
V_tile = V[:, :, i:i+tile_size, :] # 从 HBM 读一小块 V
# 在 UB 上计算 QK^T(这块很小,UB 放得下)
scores_tile = torch.matmul(Q, K_tile.transpose(-2, -1))
# 在 UB 上做 Softmax(不写回 HBM)
attn_tile = torch.softmax(scores_tile, dim=-1)
# 在 UB 上乘 V_tile(不写回 HBM)
output += torch.matmul(attn_tile, V_tile)
# 只有 output 的最终结果才写回 HBM
return output
# 优势:中间结果(scores_tile, attn_tile)一直留在 UB 上,不写 HBM
# HBM 访存量从 34GB 降到 6GB(seq_len=4096, batch=4, heads=32)
这个策略在 GPU 上已经很快了,但在昇腾NPU 上还能更快——因为昇腾NPU 的 UB 比 GPU 的 shared memory 大(256KB vs 通常 64~164KB),可以放更大的 tile,减少循环次数。
昇腾NPU 上的 FlashAttention:ops-transformer 的实现
ops-transformer 是昇腾CANN 社区的开源仓库,里面有针对昇腾NPU 高度优化的 FlashAttention 实现。
关键点:ops-transformer 的 FlashAttention 不是简单的算法移植,而是针对达芬奇架构做了深度优化:
-
Cube 和 Vector 并行:达芬奇架构有两套计算单元——Cube 做矩阵乘(QK^T 和 PV),Vector 做逐元素运算(Softmax)。ops-transformer 的实现让这两步 pipeline 起来,一边算矩阵乘,一边算 Softmax,不浪费时间。
-
异步数据搬运:在当前 tile 计算的同时,预加载下一个 tile 的 K 和 V 到 UB。这样计算单元就不会等数据。
-
Tiling 策略自动调优:不同 seq_len 和 dim 的最优 tile 大小不一样。ops-transformer 的 tiling 策略会根据输入形状自动选择最优分块大小。
用代码验证 ops-transformer 的 FlashAttention 效果:
import torch
import torch_npu
# 确保 torch-npu 已安装(昇腾NPU 的 PyTorch 后端)
# pip install torch-npu==2.1.0 (版本号以 CANN 为准)
# 构造输入
batch, heads, seq_len, dim = 4, 32, 4096, 64
Q = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
K = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
V = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
# 方法1:PyTorch 原生 Attention(逐算子路径,无融合)
with torch.no_grad():
output_native = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
# 方法2:ops-transformer 的 FlashAttention(融合算子)
# 需要先编译安装 ops-transformer:
# git clone https://atomgit.com/cann/ops-transformer
# cd ops-transformer && mkdir build && cd build
# cmake .. && make -j && make install
from flash_attention_ops import flash_attention_npu
with torch.no_grad():
output_fa = flash_attention_npu(Q, K, V, causal=True)
torch.npu.synchronize()
# 对比结果(误差应该在 1e-3 以内)
max_err = (output_native.cpu().float() - output_fa.cpu().float()).abs().max().item()
print(f"PyTorch 原生 vs FlashAttention 最大误差: {max_err:.6f}")
print("误差 < 1e-3,正确性验证通过!" if max_err < 1e-3 else "误差过大,检查实现!")
# 性能对比(用 torch_npu.profiler 抓 trace)
from torch_npu.profiler import profile, ProfilerActivity
with profile(activities=[ProfilerActivity.NPU], export_name="native_attention.json"):
output_native = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
with profile(activities=[ProfilerActivity.NPU], export_name="flash_attention.json"):
output_fa = flash_attention_npu(Q, K, V, causal=True)
torch.npu.synchronize()
# 在 Profiler GUI 里看:
# - native_attention.json:有三个大色块(MatMul / Softmax / MatMul),每个色块前后都有 HBM 读写的小色块
# - flash_attention.json:只有一个大的 FlashAttentionKernel 色块,HBM 读写少很多
怎么确认 FlashAttention 真的生效了?
光看代码不够,得用 Profiler 抓 trace 确认。
# 第一步:跑一次训练,抓 Profiler trace
python train.py --use-flash-attention --profiler-output trace.json
# 第二步:在昇腾 CANN 的 Profiler GUI 里打开 trace.json
# 看 Attention 层对应的色块:
# - 如果看到 MatMul、Softmax、MatMul 三个独立色块 → FlashAttention 没生效
# - 如果看到一个 FlashAttentionKernel 色块 → 生效了!
# 第三步:看 HBM 访存量
# 在 Profiler GUI 的 "Memory" 标签页:
# - 传统 Attention:HBM 访存量 ~34GB(seq_len=4096)
# - FlashAttention:HBM 访存量 ~6GB(节省 82%)
如果 FlashAttention 没生效,检查一下:
- 框架适配层配置:PyTorch 的
scaled_dot_product_attention是否路由到了 ops-transformer 的实现(需要安装 torch-npu 并正确配置) - GE 融合规则:CANN 的 GE 图引擎是否识别到了 MatMul→Softmax→MatMul 的融合模式(查看 GE 的融合日志)
- 输入形状:FlashAttention 对 seq_len 有要求(通常是 2 的幂次方,比如 512、1024、2048、4096)
如果碰到问题,可以去 atomgit 上的 Discussions 区提问,社区响应很快。
相关仓库:
https://atomgit.com/cann/ops-transformer
https://atomgit.com/cann/cann-learning-hub
https://atomgit.com/cann/cann-samples
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐
所有评论(0)