和一个做推荐系统的朋友吃饭,他问我:“我训练千问模型,Attention层特别慢,听说FlashAttention能加速,但我不懂CUDA,这玩意儿到底是怎么快的?”

我想了一下,跟他说:“你把大模型训练想象成一个超大的餐厅厨房。每次做一道菜(处理一个batch),厨师(GPU/NPU)要做三件事:切菜(QK^T矩阵乘)、调味(Softmax)、翻炒(乘V)。传统做法是切完菜放到盘子里(写HBM),再从盘子拿起来调味,调完味又放盘子,再拿起来翻炒——来来回回跑好多趟。”

“FlashAttention是什么?它是一个拼菜师傅,把切菜、调味、翻炒三步合并成一步,在灶台上直接完成,中间不用来回跑厨房和餐厅。”

朋友眼睛亮了:“所以快的原因是不用来回跑?”

“对。专业术语叫IO-aware——不是算力不够,是搬运数据太费时间。”

传统 Attention 的"来回跑"问题

要理解 FlashAttention,先得知道传统 Attention 是怎么工作的。

假设你有一个句子,128个token,每个token用512维向量表示。Attention 要计算每个token和所有其他token的关系,得到一个128×128的注意力矩阵。

传统实现分三步:

# 传统 Attention 实现(简化版)
import torch

def traditional_attention(Q, K, V):
    # 第一步:计算 QK^T,得到注意力得分矩阵
    # 大小:batch × heads × seq_len × seq_len
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # ⚠️ 这里 scores 要写回 HBM(显存),占用 seq_len × seq_len 空间

    # 第二步:Softmax 归一化
    attn_weights = torch.softmax(scores, dim=-1)
    # ⚠️ 这里要读 scores(从 HBM 读),再写 attn_weights(写回 HBM)

    # 第三步:乘 V,得到输出
    output = torch.matmul(attn_weights, V)
    # ⚠️ 这里要读 attn_weights(从 HBM 读)

    return output

# 问题:三步都有 HBM 读写,来回搬运数据占用了 60% 以上的时间
# 算力(矩阵乘)只占了不到 40%

这三步,每一步都要把中间结果写到 HBM(High Bandwidth Memory,显存),下一步再读出来。就像那个餐厅比喻——切完菜放盘子,再从盘子拿起来调味。

当 seq_len 是 4096 的时候,那个注意力矩阵的大小是 4096×4096×2 bytes(float16)= 32MB。看着不大,但这是每个头、每个 batch 都要存的。32 heads × 4 batch = 128 份,总共 4GB——就存个中间结果。

FlashAttention 的"灶台合并"策略

FlashAttention 的核心思路:别把中间结果写回 HBM,在灶台上直接搞定。

具体做法是把 K 和 V 按小块(tile)读入 UB(Unified Buffer,昇腾NPU 上的高速片上存储),在 UB 里完成一个 tile 的 QK^T → Softmax → 乘 V 完整计算,然后把结果累积到输出里。

# FlashAttention 的"灶台合并"思路(伪代码)
def flash_attention_npu(Q, K, V, tile_size=128):
    # Q: (batch, heads, seq_len, dim)
    # K, V: (batch, heads, seq_len, dim)

    output = torch.zeros_like(Q)

    # 把 K 和 V 按 tile 分块
    # 每次只把一块 K_tile 和 V_tile 读到 UB 上
    for i in range(0, seq_len, tile_size):
        K_tile = K[:, :, i:i+tile_size, :]  # 从 HBM 读一小块 K
        V_tile = V[:, :, i:i+tile_size, :]  # 从 HBM 读一小块 V

        # 在 UB 上计算 QK^T(这块很小,UB 放得下)
        scores_tile = torch.matmul(Q, K_tile.transpose(-2, -1))

        # 在 UB 上做 Softmax(不写回 HBM)
        attn_tile = torch.softmax(scores_tile, dim=-1)

        # 在 UB 上乘 V_tile(不写回 HBM)
        output += torch.matmul(attn_tile, V_tile)
        # 只有 output 的最终结果才写回 HBM

    return output

# 优势:中间结果(scores_tile, attn_tile)一直留在 UB 上,不写 HBM
# HBM 访存量从 34GB 降到 6GB(seq_len=4096, batch=4, heads=32)

这个策略在 GPU 上已经很快了,但在昇腾NPU 上还能更快——因为昇腾NPU 的 UB 比 GPU 的 shared memory 大(256KB vs 通常 64~164KB),可以放更大的 tile,减少循环次数。

昇腾NPU 上的 FlashAttention:ops-transformer 的实现

ops-transformer 是昇腾CANN 社区的开源仓库,里面有针对昇腾NPU 高度优化的 FlashAttention 实现。

关键点:ops-transformer 的 FlashAttention 不是简单的算法移植,而是针对达芬奇架构做了深度优化:

  1. Cube 和 Vector 并行:达芬奇架构有两套计算单元——Cube 做矩阵乘(QK^T 和 PV),Vector 做逐元素运算(Softmax)。ops-transformer 的实现让这两步 pipeline 起来,一边算矩阵乘,一边算 Softmax,不浪费时间。

  2. 异步数据搬运:在当前 tile 计算的同时,预加载下一个 tile 的 K 和 V 到 UB。这样计算单元就不会等数据。

  3. Tiling 策略自动调优:不同 seq_len 和 dim 的最优 tile 大小不一样。ops-transformer 的 tiling 策略会根据输入形状自动选择最优分块大小。

用代码验证 ops-transformer 的 FlashAttention 效果:

import torch
import torch_npu

# 确保 torch-npu 已安装(昇腾NPU 的 PyTorch 后端)
# pip install torch-npu==2.1.0  (版本号以 CANN 为准)

# 构造输入
batch, heads, seq_len, dim = 4, 32, 4096, 64
Q = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
K = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
V = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()

# 方法1:PyTorch 原生 Attention(逐算子路径,无融合)
with torch.no_grad():
    output_native = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()

# 方法2:ops-transformer 的 FlashAttention(融合算子)
# 需要先编译安装 ops-transformer:
# git clone https://atomgit.com/cann/ops-transformer
# cd ops-transformer && mkdir build && cd build
# cmake .. && make -j && make install

from flash_attention_ops import flash_attention_npu
with torch.no_grad():
    output_fa = flash_attention_npu(Q, K, V, causal=True)
torch.npu.synchronize()

# 对比结果(误差应该在 1e-3 以内)
max_err = (output_native.cpu().float() - output_fa.cpu().float()).abs().max().item()
print(f"PyTorch 原生 vs FlashAttention 最大误差: {max_err:.6f}")
print("误差 < 1e-3,正确性验证通过!" if max_err < 1e-3 else "误差过大,检查实现!")

# 性能对比(用 torch_npu.profiler 抓 trace)
from torch_npu.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.NPU], export_name="native_attention.json"):
    output_native = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()

with profile(activities=[ProfilerActivity.NPU], export_name="flash_attention.json"):
    output_fa = flash_attention_npu(Q, K, V, causal=True)
torch.npu.synchronize()

# 在 Profiler GUI 里看:
# - native_attention.json:有三个大色块(MatMul / Softmax / MatMul),每个色块前后都有 HBM 读写的小色块
# - flash_attention.json:只有一个大的 FlashAttentionKernel 色块,HBM 读写少很多

怎么确认 FlashAttention 真的生效了?

光看代码不够,得用 Profiler 抓 trace 确认。

# 第一步:跑一次训练,抓 Profiler trace
python train.py --use-flash-attention --profiler-output trace.json

# 第二步:在昇腾 CANN 的 Profiler GUI 里打开 trace.json
# 看 Attention 层对应的色块:
# - 如果看到 MatMul、Softmax、MatMul 三个独立色块 → FlashAttention 没生效
# - 如果看到一个 FlashAttentionKernel 色块 → 生效了!

# 第三步:看 HBM 访存量
# 在 Profiler GUI 的 "Memory" 标签页:
# - 传统 Attention:HBM 访存量 ~34GB(seq_len=4096)
# - FlashAttention:HBM 访存量 ~6GB(节省 82%)

如果 FlashAttention 没生效,检查一下:

  1. 框架适配层配置:PyTorch 的 scaled_dot_product_attention 是否路由到了 ops-transformer 的实现(需要安装 torch-npu 并正确配置)
  2. GE 融合规则:CANN 的 GE 图引擎是否识别到了 MatMul→Softmax→MatMul 的融合模式(查看 GE 的融合日志)
  3. 输入形状:FlashAttention 对 seq_len 有要求(通常是 2 的幂次方,比如 512、1024、2048、4096)

如果碰到问题,可以去 atomgit 上的 Discussions 区提问,社区响应很快。

相关仓库:

https://atomgit.com/cann/ops-transformer

https://atomgit.com/cann/cann-learning-hub

https://atomgit.com/cann/cann-samples

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐