GE 怎么做算子融合
·

两个相邻的算子,如果前一个的输出正好是后一个的输入,可以合成一个算子。融合之后省掉一次显存读写,性能提升显著。GE 在 ATC 编译阶段自动做算子融合,不需要手动干预。
下面用代码示例解析融合规则和触发条件。
一、算子融合的基本原理
为什么融合能提速
# 未融合:两次显存读写
def no_fusion(x, conv_weight, bn_weight, bn_bias):
conv_out = conv(x, conv_weight) # 写 HBM
bn_out = batch_norm(conv_out, bn_weight, bn_bias) # 读 HBM,写 HBM
relu_out = relu(bn_out) # 读 HBM,写 HBM
return relu_out
# 融合后:一次显存读写
def fused(x, conv_weight, bn_weight, bn_bias):
# 整个计算在 UB 里完成,不写回 HBM
return fused_conv_bn_relu(x, conv_weight, bn_weight, bn_bias)
融合前后的显存读写对比
# 计算显存读写量
def calc_memory_traffic(shape, fused=False):
batch, channels, height, width = shape
tensor_size = batch * channels * height * width * 2 # FP16 = 2 bytes
if fused:
# 融合后:只写一次输出
return tensor_size
else:
# 未融合:conv 输出 + bn 输出 + relu 输出
return tensor_size * 3
shape = (1, 64, 56, 56)
print(f"未融合: {calc_memory_traffic(shape, fused=False) / 1024:.1f} KB")
print(f"融合后: {calc_memory_traffic(shape, fused=True) / 1024:.1f} KB")
# 输出:
# 未融合: 705.6 KB
# 融合后: 235.2 KB
二、Conv + BN + ReLU 融合
这是最常见的融合模式,ResNet 等模型里大量出现。
PyTorch 模型定义
import torch
import torch.nn as nn
class ConvBNReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size//2)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
# 创建模型
model = ConvBNReLU(64, 128, 3).npu().eval()
导出 ONNX 并编译
# 导出 ONNX
dummy_input = torch.randn(1, 64, 56, 56).npu()
torch.onnx.export(model, dummy_input, "conv_bn_relu.onnx", opset_version=11)
# ATC 编译(自动触发融合)
import os
os.system("""
atc --model=conv_bn_relu.onnx \
--framework=5 \
--output=conv_bn_relu \
--enable_fusion=true
""")
查看融合结果
# 编译日志中查看融合信息
atc --model=conv_bn_relu.onnx \
--framework=5 \
--output=conv_bn_relu \
--log=info 2>&1 | grep -i fusion
# 输出示例:
# [GE] Fusion: Conv + BatchNorm + ReLU -> ConvBNReLU
# [GE] Fused 3 nodes into 1
BN 参数融合到 Conv
BatchNorm 在推理时可以完全融合到 Conv 权重里:
def fuse_bn_to_conv(conv, bn):
"""把 BN 参数融合到 Conv 权重"""
# BN 公式:y = (x - mean) / sqrt(var + eps) * gamma + beta
# 融合到 Conv:w' = w * gamma / sqrt(var + eps)
# b' = (b - mean) * gamma / sqrt(var + eps) + beta
w = conv.weight.data
b = conv.bias.data if conv.bias is not None else torch.zeros(w.shape[0])
mean = bn.running_mean
var = bn.running_var
gamma = bn.weight
beta = bn.bias
eps = bn.eps
# 计算融合后的权重和偏置
std = torch.sqrt(var + eps)
w_fused = w * gamma.view(-1, 1, 1, 1) / std.view(-1, 1, 1, 1)
b_fused = (b - mean) * gamma / std + beta
return w_fused, b_fused
# 使用示例
conv = nn.Conv2d(64, 128, 3)
bn = nn.BatchNorm2d(128)
bn.eval() # 切换到推理模式
w_fused, b_fused = fuse_bn_to_conv(conv, bn)
conv.weight.data = w_fused
conv.bias = nn.Parameter(b_fused)
# 现在 bn 可以删除,直接用 conv
三、MatMul + Add + ReLU 融合
线性层后接激活函数是常见的模式。
PyTorch 示例
class LinearReLU(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
self.relu = nn.ReLU()
def forward(self, x):
x = self.linear(x) # MatMul + Add
x = self.relu(x)
return x
model = LinearReLU(1024, 512).npu().eval()
导出并编译
dummy_input = torch.randn(1, 1024).npu()
torch.onnx.export(model, dummy_input, "linear_relu.onnx")
# 编译
os.system("""
atc --model=linear_relu.onnx \
--framework=5 \
--output=linear_relu \
--enable_fusion=true
""")
融合规则代码
# GE 内部的融合规则(简化版)
def check_matmul_relu_fusion(node_list):
"""检查是否可以融合 MatMul + ReLU"""
for i, node in enumerate(node_list[:-1]):
next_node = node_list[i + 1]
# 条件1:当前节点是 MatMul 或 Add
if node.op_type not in ["MatMul", "Add"]:
continue
# 条件2:下一个节点是 ReLU
if next_node.op_type != "Relu":
continue
# 条件3:数据流连续
if node.output[0] != next_node.input[0]:
continue
# 可以融合
print(f"Can fuse: {node.name} -> {next_node.name}")
return True
return False
四、FlashAttention 融合
Transformer 模型的核心优化:把 QK^T + Softmax + ×V 合成一个算子。
标准 Attention
def standard_attention(Q, K, V):
"""标准 Attention 实现"""
# Q: [batch, heads, seq, dim]
# K, V: [batch, heads, seq, dim]
d_k = Q.shape[-1]
# 计算 QK^T / sqrt(d_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# Softmax
attn_weights = torch.softmax(scores, dim=-1)
# 乘以 V
output = torch.matmul(attn_weights, V)
return output
FlashAttention 融合触发
import torch_npu
from ops_transformer import flash_attention
def flash_attention_wrapper(Q, K, V):
"""FlashAttention 接口"""
return flash_attention(
query=Q,
key=K,
value=V,
head_num=32,
input_layout="BSND" # Batch, Seq, Num_heads, Dim
)
# 替换标准 Attention
model.attention.forward = flash_attention_wrapper
编译时自动识别
# 编译时 GE 会自动识别 Attention 模式
atc --model=bert.onnx \
--framework=5 \
--output=bert \
--enable_fusion=true \
--log=info 2>&1 | grep -i attention
# 输出示例:
# [GE] Detected Attention pattern: MatMul + Softmax + MatMul
# [GE] Fusion: Apply FlashAttention
# [GE] FlashAttention block size: 64x64
五、ElementWise 算子融合
多个逐元素操作可以融合成一个 kernel。
融合前
def no_fusion_elementwise(x):
x = x + 1 # Add kernel
x = x * 2 # Mul kernel
x = torch.relu(x) # ReLU kernel
return x
融合后
def fused_elementwise(x):
# 三个操作合成一个 kernel
return fused_add_mul_relu(x, 1, 2)
# GE 会自动识别并融合
查看融合效果
import torch
import torch_npu
class ElementwiseModel(nn.Module):
def forward(self, x):
return torch.relu(x * 2 + 1)
model = ElementwiseModel().npu()
input_tensor = torch.randn(1, 1024, 1024).npu()
# 导出计算图
import os
os.environ["GE_GRAPH_SAVE_PATH"] = "./graph"
model(input_tensor)
# 查看节点数量
# 未优化:3 个节点(Add, Mul, ReLU)
# 优化后:1 个节点(FusedElementwise)
六、融合规则配置
GE 支持配置融合规则,可以开启或关闭特定的融合。
关闭融合
atc --model=model.onnx \
--framework=5 \
--output=model \
--enable_fusion=false
部分开启融合
# 通过环境变量控制融合规则
import os
# 关闭特定融合
os.environ["GE_DISABLE_CONV_BN_RELU_FUSION"] = "1"
os.environ["GE_DISABLE_FLASH_ATTENTION_FUSION"] = "1"
# 编译
os.system("atc --model=model.onnx --framework=5 --output=model")
自定义融合规则
# 定义自定义融合规则(高级用法)
fusion_rules = {
"conv_bn_relu": {
"pattern": ["Conv", "BatchNormalization", "Relu"],
"output_op": "ConvBNReLU",
"conditions": {
"bn_mode": "inference"
}
},
"matmul_relu": {
"pattern": ["MatMul", "Relu"],
"output_op": "MatMulReLU",
"conditions": {}
}
}
# 保存配置文件
import json
with open("fusion_rules.json", "w") as f:
json.dump(fusion_rules, f)
七、查看融合效果
统计融合节点数
# 编译日志中统计融合效果
atc --model=resnet50.onnx \
--framework=5 \
--output=resnet50 \
--log=info 2>&1 | grep -E "nodes|Fused"
# 输出示例:
# [GE] Original graph: 178 nodes
# [GE] After optimization: 125 nodes
# [GE] Fused 53 nodes
# [GE] Fusion details:
# [GE] Conv+BN+ReLU: 16 times
# [GE] MatMul+Add: 3 times
# [GE] ElementWise: 8 times
性能对比测试
import torch
import torch_npu
import time
# 加载未融合模型
model_no_fusion = torch.jit.load("model_no_fusion.om", map_location="npu:0")
# 加载融合模型
model_fused = torch.jit.load("model_fused.om", map_location="npu:0")
# 性能测试
def benchmark(model, input_tensor, iterations=100):
# 预热
for _ in range(10):
model(input_tensor)
torch.npu.synchronize()
start = time.time()
for _ in range(iterations):
model(input_tensor)
torch.npu.synchronize()
end = time.time()
return (end - start) / iterations * 1000 # ms
input_tensor = torch.randn(1, 3, 224, 224).npu()
print(f"未融合: {benchmark(model_no_fusion, input_tensor):.2f} ms")
print(f"融合后: {benchmark(model_fused, input_tensor):.2f} ms")
八、融合失败排查
某些情况下融合不会触发:
# 情况1:BN 在训练模式
class Model(nn.Module):
def __init__(self):
self.bn = nn.BatchNorm2d(64)
def forward(self, x):
return self.bn(x) # 训练模式下不会融合
# 解决:切换到 eval 模式
model.eval()
# 情况2:中间结果被其他节点使用
class Model(nn.Module):
def forward(self, x):
conv_out = self.conv(x)
bn_out = self.bn(conv_out)
# conv_out 被其他节点使用,不能融合
side_output = self.other_op(conv_out)
return self.relu(bn_out), side_output
# 情况3:精度不匹配
# 某些融合在 FP16 下精度不达标,GE 会自动跳过
参考资源
- 算子融合规则详解:https://www.hiascend.com/document/detail/zh/CANN/
- GE 优化指南:https://www.hiascend.com/document/detail/zh/CANN/
- FlashAttention 实现:https://atomgit.com/cann/ops-transformer
- 性能调优最佳实践:https://www.hiascend.com/document/detail/zh/CANN/
总结
GE 的算子融合分三类:Conv+BN+ReLU 这种前后依赖的算子合成一个、ElementWise 多个逐元素操作合成一个 kernel、FlashAttention 这种复杂模式识别后优化。融合的触发条件是数据流连续、没有副作用、精度可控。通过 --enable_fusion 控制开关,编译日志里能看到融合详情。如果融合没触发,检查 BN 是否在 eval 模式、中间结果是否被复用、精度是否符合要求。融合后的模型节点数减少 30-50%,推理速度提升 20-40%。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)