在这里插入图片描述
两个相邻的算子,如果前一个的输出正好是后一个的输入,可以合成一个算子。融合之后省掉一次显存读写,性能提升显著。GE 在 ATC 编译阶段自动做算子融合,不需要手动干预。

下面用代码示例解析融合规则和触发条件。


一、算子融合的基本原理

为什么融合能提速

# 未融合:两次显存读写
def no_fusion(x, conv_weight, bn_weight, bn_bias):
    conv_out = conv(x, conv_weight)      # 写 HBM
    bn_out = batch_norm(conv_out, bn_weight, bn_bias)  # 读 HBM,写 HBM
    relu_out = relu(bn_out)              # 读 HBM,写 HBM
    return relu_out

# 融合后:一次显存读写
def fused(x, conv_weight, bn_weight, bn_bias):
    # 整个计算在 UB 里完成,不写回 HBM
    return fused_conv_bn_relu(x, conv_weight, bn_weight, bn_bias)

融合前后的显存读写对比

# 计算显存读写量
def calc_memory_traffic(shape, fused=False):
    batch, channels, height, width = shape
    tensor_size = batch * channels * height * width * 2  # FP16 = 2 bytes
    
    if fused:
        # 融合后:只写一次输出
        return tensor_size
    else:
        # 未融合:conv 输出 + bn 输出 + relu 输出
        return tensor_size * 3

shape = (1, 64, 56, 56)
print(f"未融合: {calc_memory_traffic(shape, fused=False) / 1024:.1f} KB")
print(f"融合后: {calc_memory_traffic(shape, fused=True) / 1024:.1f} KB")

# 输出:
# 未融合: 705.6 KB
# 融合后: 235.2 KB

二、Conv + BN + ReLU 融合

这是最常见的融合模式,ResNet 等模型里大量出现。

PyTorch 模型定义

import torch
import torch.nn as nn

class ConvBNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=kernel_size//2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# 创建模型
model = ConvBNReLU(64, 128, 3).npu().eval)

导出 ONNX 并编译

# 导出 ONNX
dummy_input = torch.randn(1, 64, 56, 56).npu()
torch.onnx.export(model, dummy_input, "conv_bn_relu.onnx", opset_version=11)

# ATC 编译(自动触发融合)
import os
os.system("""
atc --model=conv_bn_relu.onnx \
    --framework=5 \
    --output=conv_bn_relu \
    --enable_fusion=true
""")

查看融合结果

# 编译日志中查看融合信息
atc --model=conv_bn_relu.onnx \
    --framework=5 \
    --output=conv_bn_relu \
    --log=info 2>&1 | grep -i fusion

# 输出示例:
# [GE] Fusion: Conv + BatchNorm + ReLU -> ConvBNReLU
# [GE] Fused 3 nodes into 1

BN 参数融合到 Conv

BatchNorm 在推理时可以完全融合到 Conv 权重里:

def fuse_bn_to_conv(conv, bn):
    """把 BN 参数融合到 Conv 权重"""
    # BN 公式:y = (x - mean) / sqrt(var + eps) * gamma + beta
    # 融合到 Conv:w' = w * gamma / sqrt(var + eps)
    #             b' = (b - mean) * gamma / sqrt(var + eps) + beta
    
    w = conv.weight.data
    b = conv.bias.data if conv.bias is not None else torch.zeros(w.shape[0])
    
    mean = bn.running_mean
    var = bn.running_var
    gamma = bn.weight
    beta = bn.bias
    eps = bn.eps
    
    # 计算融合后的权重和偏置
    std = torch.sqrt(var + eps)
    w_fused = w * gamma.view(-1, 1, 1, 1) / std.view(-1, 1, 1, 1)
    b_fused = (b - mean) * gamma / std + beta
    
    return w_fused, b_fused

# 使用示例
conv = nn.Conv2d(64, 128, 3)
bn = nn.BatchNorm2d(128)
bn.eval)  # 切换到推理模式

w_fused, b_fused = fuse_bn_to_conv(conv, bn)
conv.weight.data = w_fused
conv.bias = nn.Parameter(b_fused)

# 现在 bn 可以删除,直接用 conv

三、MatMul + Add + ReLU 融合

线性层后接激活函数是常见的模式。

PyTorch 示例

class LinearReLU(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.linear(x)  # MatMul + Add
        x = self.relu(x)
        return x

model = LinearReLU(1024, 512).npu().eval)

导出并编译

dummy_input = torch.randn(1, 1024).npu()
torch.onnx.export(model, dummy_input, "linear_relu.onnx")

# 编译
os.system("""
atc --model=linear_relu.onnx \
    --framework=5 \
    --output=linear_relu \
    --enable_fusion=true
""")

融合规则代码

# GE 内部的融合规则(简化版)
def check_matmul_relu_fusion(node_list):
    """检查是否可以融合 MatMul + ReLU"""
    for i, node in enumerate(node_list[:-1]):
        next_node = node_list[i + 1]
        
        # 条件1:当前节点是 MatMul 或 Add
        if node.op_type not in ["MatMul", "Add"]:
            continue
        
        # 条件2:下一个节点是 ReLU
        if next_node.op_type != "Relu":
            continue
        
        # 条件3:数据流连续
        if node.output[0] != next_node.input[0]:
            continue
        
        # 可以融合
        print(f"Can fuse: {node.name} -> {next_node.name}")
        return True
    
    return False

四、FlashAttention 融合

Transformer 模型的核心优化:把 QK^T + Softmax + ×V 合成一个算子。

标准 Attention

def standard_attention(Q, K, V):
    """标准 Attention 实现"""
    # Q: [batch, heads, seq, dim]
    # K, V: [batch, heads, seq, dim]
    
    d_k = Q.shape[-1]
    
    # 计算 QK^T / sqrt(d_k)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Softmax
    attn_weights = torch.softmax(scores, dim=-1)
    
    # 乘以 V
    output = torch.matmul(attn_weights, V)
    
    return output

FlashAttention 融合触发

import torch_npu
from ops_transformer import flash_attention

def flash_attention_wrapper(Q, K, V):
    """FlashAttention 接口"""
    return flash_attention(
        query=Q,
        key=K,
        value=V,
        head_num=32,
        input_layout="BSND"  # Batch, Seq, Num_heads, Dim
    )

# 替换标准 Attention
model.attention.forward = flash_attention_wrapper

编译时自动识别

# 编译时 GE 会自动识别 Attention 模式
atc --model=bert.onnx \
    --framework=5 \
    --output=bert \
    --enable_fusion=true \
    --log=info 2>&1 | grep -i attention

# 输出示例:
# [GE] Detected Attention pattern: MatMul + Softmax + MatMul
# [GE] Fusion: Apply FlashAttention
# [GE] FlashAttention block size: 64x64

五、ElementWise 算子融合

多个逐元素操作可以融合成一个 kernel。

融合前

def no_fusion_elementwise(x):
    x = x + 1       # Add kernel
    x = x * 2       # Mul kernel
    x = torch.relu(x)  # ReLU kernel
    return x

融合后

def fused_elementwise(x):
    # 三个操作合成一个 kernel
    return fused_add_mul_relu(x, 1, 2)

# GE 会自动识别并融合

查看融合效果

import torch
import torch_npu

class ElementwiseModel(nn.Module):
    def forward(self, x):
        return torch.relu(x * 2 + 1)

model = ElementwiseModel().npu()
input_tensor = torch.randn(1, 1024, 1024).npu()

# 导出计算图
import os
os.environ["GE_GRAPH_SAVE_PATH"] = "./graph"
model(input_tensor)

# 查看节点数量
# 未优化:3 个节点(Add, Mul, ReLU)
# 优化后:1 个节点(FusedElementwise)

六、融合规则配置

GE 支持配置融合规则,可以开启或关闭特定的融合。

关闭融合

atc --model=model.onnx \
    --framework=5 \
    --output=model \
    --enable_fusion=false

部分开启融合

# 通过环境变量控制融合规则
import os

# 关闭特定融合
os.environ["GE_DISABLE_CONV_BN_RELU_FUSION"] = "1"
os.environ["GE_DISABLE_FLASH_ATTENTION_FUSION"] = "1"

# 编译
os.system("atc --model=model.onnx --framework=5 --output=model")

自定义融合规则

# 定义自定义融合规则(高级用法)
fusion_rules = {
    "conv_bn_relu": {
        "pattern": ["Conv", "BatchNormalization", "Relu"],
        "output_op": "ConvBNReLU",
        "conditions": {
            "bn_mode": "inference"
        }
    },
    "matmul_relu": {
        "pattern": ["MatMul", "Relu"],
        "output_op": "MatMulReLU",
        "conditions": {}
    }
}

# 保存配置文件
import json
with open("fusion_rules.json", "w") as f:
    json.dump(fusion_rules, f)

七、查看融合效果

统计融合节点数

# 编译日志中统计融合效果
atc --model=resnet50.onnx \
    --framework=5 \
    --output=resnet50 \
    --log=info 2>&1 | grep -E "nodes|Fused"

# 输出示例:
# [GE] Original graph: 178 nodes
# [GE] After optimization: 125 nodes
# [GE] Fused 53 nodes
# [GE] Fusion details:
# [GE]   Conv+BN+ReLU: 16 times
# [GE]   MatMul+Add: 3 times
# [GE]   ElementWise: 8 times

性能对比测试

import torch
import torch_npu
import time

# 加载未融合模型
model_no_fusion = torch.jit.load("model_no_fusion.om", map_location="npu:0")

# 加载融合模型
model_fused = torch.jit.load("model_fused.om", map_location="npu:0")

# 性能测试
def benchmark(model, input_tensor, iterations=100):
    # 预热
    for _ in range(10):
        model(input_tensor)
    
    torch.npu.synchronize()
    start = time.time()
    
    for _ in range(iterations):
        model(input_tensor)
    
    torch.npu.synchronize()
    end = time.time()
    
    return (end - start) / iterations * 1000  # ms

input_tensor = torch.randn(1, 3, 224, 224).npu()
print(f"未融合: {benchmark(model_no_fusion, input_tensor):.2f} ms")
print(f"融合后: {benchmark(model_fused, input_tensor):.2f} ms")

八、融合失败排查

某些情况下融合不会触发:

# 情况1:BN 在训练模式
class Model(nn.Module):
    def __init__(self):
        self.bn = nn.BatchNorm2d(64)
    
    def forward(self, x):
        return self.bn(x)  # 训练模式下不会融合

# 解决:切换到 eval 模式
model.eval)

# 情况2:中间结果被其他节点使用
class Model(nn.Module):
    def forward(self, x):
        conv_out = self.conv(x)
        bn_out = self.bn(conv_out)
        
        # conv_out 被其他节点使用,不能融合
        side_output = self.other_op(conv_out)
        
        return self.relu(bn_out), side_output

# 情况3:精度不匹配
# 某些融合在 FP16 下精度不达标,GE 会自动跳过

参考资源

  • 算子融合规则详解:https://www.hiascend.com/document/detail/zh/CANN/
  • GE 优化指南:https://www.hiascend.com/document/detail/zh/CANN/
  • FlashAttention 实现:https://atomgit.com/cann/ops-transformer
  • 性能调优最佳实践:https://www.hiascend.com/document/detail/zh/CANN/

总结

GE 的算子融合分三类:Conv+BN+ReLU 这种前后依赖的算子合成一个、ElementWise 多个逐元素操作合成一个 kernel、FlashAttention 这种复杂模式识别后优化。融合的触发条件是数据流连续、没有副作用、精度可控。通过 --enable_fusion 控制开关,编译日志里能看到融合详情。如果融合没触发,检查 BN 是否在 eval 模式、中间结果是否被复用、精度是否符合要求。融合后的模型节点数减少 30-50%,推理速度提升 20-40%。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐