昇腾NPU上FlashAttention算子住哪?ops-transformer仓库全景拆解
第一次在昇腾NPU上跑大模型推理的人,面对CANN生态的几十个仓库,第一反应通常是:FlashAttention算子在哪个仓库里?改算子要改哪个文件?编译完产物在哪?
这些问题的答案全在ops-transformer里。这个仓库是昇腾CANN大模型算子的主仓库,FlashAttention、RMSNorm、RoPE、SwiGLU——Transformer架构用到的核心算子,实现全在这里。
仓库在CANN生态里的位置
昇腾CANN软件栈从下到上分四层,ops-transformer处在算子层:
Ascend 910 NPU硬件
↑
CANN基础层:驱动、runtime、CCE编译器、hccl集合通信
↑
算子层:ops-transformer(大模型算子) ← 现在在这里
catlass(算子模板库,ops-transformer的依赖)
ops-ascendc(通用CV/NLP算子)
↑
图引擎层:ge(算子编排、图优化、算子融合)
↑
框架适配层:torch_npu(PyTorch→NPU桥接)
↑
应用层:LLM推理/训练服务
几个容易混淆的仓库快速区分:
| 仓库 | 层级 | 职责 |
|---|---|---|
| ops-transformer | 算子层 | 大模型算子实现(FlashAttention等) |
| catlass | 算子模板层 | 分块矩阵乘、softmax等基础模板 |
| ge | 图引擎层 | 算子编排、图优化 |
| torch_npu | 框架适配层 | PyTorch接口桥接 |
一句话:要改FlashAttention的算子逻辑,看ops-transformer;要写新算子,看catlass;要调图融合,看ge。
克隆仓库,看目录结构
git clone https://atomgit.com/cann/ops-transformer.git
cd ops-transformer
tree -L 2 -I 'build|output'
核心目录:
ops-transformer/
├── opkernel/ # 算子内核实现(核心代码在这)
│ ├── flash_attention/
│ │ ├── flash_attention_score.cc # 前向:分块+在线softmax
│ │ ├── flash_attention_score_grad.cc # 反向:重计算
│ │ └── flash_attention_score_tiling.cc # 分块策略
│ ├── rms_norm/
│ ├── rope/
│ └── swiglu/
├── opplugin/ # 算子注册:GE怎么找到FlashAttention
│ └── flash_attention/
│ └── flash_attention_op.cc
├── inc/ # 公共头文件
├── scripts/ # 编译脚本
└── cmake/ # 构建配置
FlashAttention的完整实现分布在三个文件里,各管一段:
tiling.cc:算分块参数,适配UB容量score.cc:前向计算,在线softmax+分块attentionscore_grad.cc:反向传播,重计算逻辑
tiling.cc:分块大小怎么定
昇腾NPU的Unified Buffer大约256KB。FlashAttention的分块计算要求Q块、K块、V块、输出块、softmax中间结果同时放在UB里。tiling.cc就是根据UB容量、序列长度、头维度,算出最优分块大小的。
// ops-transformer/opkernel/flash_attention/flash_attention_score_tiling.cc
// 分块策略简化示意(伪代码)
struct TilingParam {
uint32_t block_m; // Q块在seq维的大小
uint32_t block_n; // K/V块在seq维的大小
uint32_t block_k; // head_dim
};
TilingParam CalcTiling(uint32_t seq_len, uint32_t head_dim) {
const uint32_t ub_size = 256 * 1024; // 256KB
const uint32_t dtype_bytes = 2; // FP16
// UB里要同时放:Q块 + K块 + V块 + 输出块 + softmax中间结果
// 5个buffer同时占UB
uint32_t total_elements = ub_size / (5 * dtype_bytes);
// 反推seq维的块大小,向下对齐到16(昇腾NPU对齐要求)
uint32_t block_seq = total_elements / head_dim;
block_seq = (block_seq / 16) * 16;
block_seq = std::min(block_seq, seq_len);
return {block_seq, block_seq, head_dim};
}
分块大小不是越大越好。太大了UB装不下,太小了循环次数多、算子调度开销大。tiling.cc就是在UB容量和循环开销之间找平衡点。
实际代码里还有因果mask的块级跳过优化——如果整块K/V都在当前Q的因果范围之外,直接跳过不计算,省掉整块的计算量。
score.cc:前向计算核心
score.cc是FlashAttention算子的主循环,实现分块计算+在线softmax。昇腾NPU上的实现有几个硬件相关的特殊处理:
// flash_attention_score.cc 主循环(伪代码)
for (uint32_t qi = 0; qi < num_q_blocks; qi++) {
// 加载Q块到UB
LoadCube(q_base, qi * block_m, cur_q);
// 初始化在线softmax累加器
float row_max = -1e9f;
float row_sum = 0.0f;
float* acc_out = ub_out_buffer;
for (uint32_t ki = 0; ki < num_kv_blocks; ki++) {
// 加载K/V块(与计算流水化)
LoadCube(k_base, ki * block_n, cur_k);
LoadCube(v_base, ki * block_n, cur_v);
// Q × K^T,达芬奇Cube单元做矩阵乘
Gemm(cur_q, cur_k.T, local_scores, block_m, block_n, head_dim);
Scale(local_scores, 1.0f / sqrtf(head_dim));
// causal mask:块级跳过,不算下三角矩阵
if (IsBlockOutsideCausal(qi, ki, block_m, block_n)) {
continue;
}
// 在线softmax更新
float local_max = ReduceMax(local_scores, block_m * block_n);
float new_max = fmaxf(row_max, local_max);
// 关键:缩放之前的累加结果
// 数学上等价于把所有分数放到同一个exp尺度
float correction = expf(row_max - new_max);
row_sum *= correction;
Scale(acc_out, correction, block_m * head_dim);
// 加上当前块的贡献
ApplyExp(local_scores, -new_max, block_m * block_n);
float local_sum = ReduceSum(local_scores, block_m * block_n);
row_sum += local_sum;
Accumulate(acc_out, local_scores, cur_v, block_m, block_n, head_dim);
row_max = new_max;
}
// 归一化输出
Scale(acc_out, 1.0f / row_sum, block_m * head_dim);
StoreCube(out_base, qi * block_m, acc_out);
}
两个昇腾特有的优化值得注意:
causal mask块级跳过:不是先算出完整下三角矩阵再乘mask,而是判断整块是否在mask范围外,是就直接continue。FlashAttention的分块计算天然适合这种块级剪枝。
搬运和计算流水化:加载第(i+1)块K/V的同时,Cube单元在计算第i块的结果。达芬奇架构的DMA引擎和计算单元独立,overlap得好可以掩盖搬运延迟。
score_grad.cc:反向传播的重计算
FlashAttention前向时不存scores和attn这两个N×N中间矩阵,反向传播时需要重新算一遍——这就是重计算(recomputation)。
// flash_attention_score_grad.cc 反向逻辑(伪代码)
void FlashAttentionBackward(
Tensor grad_out, // 输出梯度(上游传下来的)
Tensor q, k, v, // 前向输入(需要重计算)
Tensor out, // 前向输出(必须存,反向要用)
Tensor grad_q, grad_k, grad_v // 要算的梯度
) {
// 第1步:算dP = grad_out × V^T
for each q_block:
dP = Gemm(grad_out_block, v_block.T, ...);
// softmax反向:dS = P * (dP - sum(dP * P))
dS = SoftmaxGrad(dP, out_block);
// 累积到grad_q
grad_q_block += Gemm(dS, K, ...);
// 第2步:重计算前向的Q×K^T,再反向到grad_k和grad_v
for each block:
scores = Gemm(Q_block, K_block.T) * scale;
// 链式法则
grad_k += Gemm(Q.T, dS);
grad_v += Gemm(dS.T, V);
}
重计算的代价是多算一遍前向(时间换空间),但换来了显存从O(N²)降到O(N)。昇腾NPU上算力充沛、显存带宽是瓶颈,这个trade-off很划算。
算子注册:GE怎么找到FlashAttention
opplugin/flash_attention_op.cc负责把FlashAttention注册到GE图引擎,让torch_npu.npu_flash_attention()能调到它:
// opplugin/flash_attention/flash_attention_op.cc
// 算子注册(伪代码)
IMPLEMT_INFERFUNC(FlashAttentionScore, FlashAttentionScoreInfer) {
// 输出shape推导:跟输入Q一致
auto q_shape = op.GetInputDesc("q").GetShape().GetDims();
auto out_desc = op.GetOutputDesc("output");
out_desc.SetShape(ge::GeShape(q_shape));
op.UpdateOutputDesc("output", out_desc);
return GRAPH_SUCCESS;
}
REGISTER_OP("FlashAttentionScore")
.Input("q: float16")
.Input("k: float16")
.Input("v: float16")
.Output("output: float16")
.Attr("head_num: int")
.Attr("scale: float")
.InferShapeAndTypeFunc(FlashAttentionScoreInfer)
.FrameworkType("ONNX");
注册完后,PyTorch端调用torch_npu.npu_flash_attention(q, k, v, ...)时,GE图引擎根据算子名"FlashAttentionScore"找到这个注册,把计算委托给opkernel/flash_attention/下的实现。
跟catlass的依赖关系
catlass是昇腾的算子模板库,提供分块矩阵乘、reduce、softmax等基础操作的模板化实现。ops-transformer的FlashAttention依赖catlass的模板:
catlass(算子模板:GEMM、softmax、reduce)
↑ include依赖
ops-transformer(大模型算子:FlashAttention等)
↑ 算子注册
ge(图引擎)
↑ 框架适配
torch_npu(PyTorch接口)
开发新算子时:catlass提供"积木",ops-transformer负责"搭房子"。FlashAttention就是用catlass的分块矩阵乘和reduce模板,搭出了完整的在线softmax+分块attention逻辑。
编译和验证流程
改了FlashAttention的代码,需要重新编译:
cd ops-transformer
mkdir build && cd build
# CANN环境必须先source
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# cmake配置
cmake .. -DCANN_INSTALL_PATH=/usr/local/Ascend/ascend-toolkit \
-DCMAKE_BUILD_TYPE=Release
# 只编译FlashAttention,比全量编译快很多
make flash_attention -j8
# 编译产物
ls output/opkernel/libflash_attention*.so
编译完替换torch_npu对应的so,或者指定自定义算子路径:
bash复制
export ASCEND_CUSTOM_OP_PATH=/path/to/ops-transformer/output
验证改动是否引入回归,用数值验证脚本对比FlashAttention输出和标准attention的差异。
克隆ops-transformer仓库,按opkernel/flash_attention/flash_attention_score.cc→flash_attention_score_tiling.cc→opplugin/flash_attention_op.cc的顺序读代码。重点关注tiling的分块策略和score.cc的在线softmax实现,理解了这两块就抓住了FlashAttention昇腾实现的主线。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)