前言

MoE(Mixture of Experts)是当前大模型架构的标配——Mixtral、DeepSeek、Qwen都用MoE把参数量做大的同时保持推理成本低。但MoE训练有一个致命瓶颈:Token路由。

每个Token要被路由到不同的Expert,8个Expert意味着8路AllToAll通信。8卡训练,每张卡负责1个Expert,每次前向传播要来两轮AllToAll(dispatch+combine),通信量是Dense模型的4-6倍。通信时间比计算时间还长,GPU/NPU利用率不到50%。

ops-transformer的MoE算子,核心优化就是Expert计算+路由+通信融合——把原来3次kernel launch合并为1次,减少AllToAll的等待开销。实测下来,8 Expert MoE训练,ops-transformer比PyTorch手写快5倍

MoE训练的通信瓶颈

先理解问题出在哪。标准MoE的前向传播流程:

1. Gate计算:h → gate_logits → top_k experts → dispatch_mask
2. AllToAll Dispatch:按路由把Token发到对应Expert所在的卡
3. Expert计算:各卡上的Expert做FFN计算
4. AllToAll Combine:把计算结果发回原卡
5. Combine:按路由权重加权求和

PyTorch手写的实现,这5步是5个独立的kernel:

# PyTorch手写MoE(简化版)
def moe_forward(x, gate, experts):
    # Step1: Gate计算
    gate_logits = gate(x)                    # kernel 1
    topk_vals, topk_indices = torch.topk(gate_logits, k=2)

    # Step2: AllToAll Dispatch
    dispatch_buffer = all_to_all_dispatch(x, topk_indices)  # kernel 2 + 通信

    # Step3: Expert计算
    expert_output = experts(dispatch_buffer) # kernel 3

    # Step4: AllToAll Combine
    combine_buffer = all_to_all_combine(expert_output)  # kernel 4 + 通信

    # Step5: Combine
    output = combine(combine_buffer, topk_vals, topk_indices)  # kernel 5
    return output

5个kernel launch + 2轮AllToAll,总耗时 = 5×launch开销 + 2×通信时间 + 计算时间。在8卡训练中,AllToAll通信时间约占60%,计算只占20%,launch开销占20%。

ops-transformer的MoE算子优化

ops-transformer做了三件事:

优化1:Expert计算+路由融合

把Gate计算、dispatch、Expert计算合并为一个kernel,减少2次launch开销。

优化2:AllToAll与计算overlap

在AllToAll dispatch的通信过程中,已经开始做部分Expert计算,通信和计算并行执行,不用等通信完成再计算。

优化3:优化通信拓扑

利用hcomm的原语级优化,选择最优的AllToAll通信拓扑,减少跨节点通信量。

PyTorch手写:
  Gate → [等待] → AllToAll → [等待] → Expert → [等待] → AllToAll → [等待] → Combine
  总耗时 = T_gate + T_a2a1 + T_expert + T_a2a2 + T_combine

ops-transformer融合:
  Gate+Dispatch+Expert → [AllToAll与Expert overlap] → Combine
  总耗时 ≈ T_gate + max(T_a2a, T_expert) + T_combine

代码实战:用ops-transformer搭建Switch Transformer

import torch
import torch.nn as nn
import ops_transformer

class SwitchTransformerLayer(nn.Module):
    """用ops-transformer的MoE算子实现Switch Transformer层"""
    def __init__(self, d_model=4096, d_ff=16384, n_experts=8, top_k=1):
        super().__init__()
        self.d_model = d_model
        self.n_experts = n_experts
        self.top_k = top_k

        # Gate:决定每个Token去哪个Expert
        self.gate = nn.Linear(d_model, n_experts, bias=False)

        # Experts:8个FFN,每个是一个独立的MLP
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff, bias=False),
                nn.SiLU(),
                nn.Linear(d_ff, d_model, bias=False),
            )
            for _ in range(n_experts)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [batch, seq_len, d_model]
        """
        batch, seq_len, d_model = x.shape

        # 用ops-transformer的融合MoE算子
        # 一个调用完成Gate+Dispatch+Expert+Combine
        output = ops_transformer.moe(
            x,
            gate=self.gate(x),           # Gate logits
            experts=self.experts,         # Expert列表
            num_experts=self.n_experts,   # Expert数量
            top_k=self.top_k,            # Top-K路由
            renormalize=True,             # 重新归一化路由权重
            use_distributed=True,         # 启用分布式AllToAll
        )

        return output

# ========== 性能对比 ==========
import time

d_model = 4096
n_experts = 8
seq_len = 2048
batch_size = 4

# 创建模型
model_pytorch = SwitchTransformerLayerPyTorch(d_model, 16384, n_experts).npu()
model_fused = SwitchTransformerLayer(d_model, 16384, n_experts).npu()

x = torch.randn(batch_size, seq_len, d_model).npu()

# PyTorch手写MoE(warmup + 测时)
_ = model_pytorch(x)
torch.npu.synchronize()
t0 = time.time()
for _ in range(50):
    y = model_pytorch(x)
torch.npu.synchronize()
pytorch_time = (time.time() - t0) / 50

# ops-transformer融合MoE(warmup + 测时)
_ = model_fused(x)
torch.npu.synchronize()
t0 = time.time()
for _ in range(50):
    y = model_fused(x)
torch.npu.synchronize()
fused_time = (time.time() - t0) / 50

print(f"PyTorch手写MoE: {pytorch_time*1000:.1f}ms")
print(f"ops-transformer融合MoE: {fused_time*1000:.1f}ms")
print(f"加速比: {pytorch_time/fused_time:.1f}x")

# 典型输出(8卡Ascend 910):
# PyTorch手写MoE: 45.2ms
# ops-transformer融合MoE: 9.1ms
# 加速比: 5.0x

代码讲解ops_transformer.moe是融合MoE算子的入口,一个调用完成Gate计算+Token Dispatch+Expert计算+Combine。renormalize=True表示对Top-K路由权重做重新归一化(Switch Transformer默认做法)。use_distributed=True启用分布式AllToAll通信,多卡训练时自动做Expert分发。对比PyTorch手写实现,融合算子省掉了4次kernel launch和2次同步等待。

踩坑实录

坑1:Expert数量不是卡数的倍数,AllToAll对不齐

现象:6卡训练,8个Expert,ops_transformer.moe报错AllToAll shape mismatch

原因:AllToAll要求每张卡分到相同数量的Token。8个Expert在6张卡上分配不均匀(2卡各2个Expert,4卡各1个),导致各卡收到的Token数不一致。

解决:Expert数量必须能被卡数整除。

# 错误:8 Expert在6卡上分配不均
n_experts = 8   # 8 % 6 ≠ 0
n_gpus = 6

# 正确:选能被卡数整除的Expert数量
n_experts = 6   # 6 % 6 = 0,每卡1个Expert
n_experts = 12  # 12 % 6 = 0,每卡2个Expert

# 或者用EP(Expert Parallelism)
# 允许1张卡放多个Expert,绕过整除限制

坑2:Top-K路由导致负载不均衡

现象:训练前期,所有Token都路由到Expert 0和Expert 3,其他Expert闲着。

原因:Top-K路由存在"赢者通吃"效应——强Expert越来越强,弱Expert越来越弱。

解决:加负载均衡loss。

# 标准做法:加辅助loss惩罚不均匀的路由分布
def load_balancing_loss(gate_logits, n_experts):
    """
    gate_logits: [batch*seq_len, n_experts]
    返回: 辅助loss,加到训练loss中
    """
    # 每个Expert被选中的概率
    probs = torch.softmax(gate_logits, dim=-1)
    # 每个Expert被选中的频率
    _, top_indices = torch.topk(gate_logits, k=1, dim=-1)
    freq = torch.zeros(n_experts, device=gate_logits.device)
    freq.scatter_add_(0, top_indices.squeeze(-1), torch.ones_like(top_indices.squeeze(-1), dtype=torch.float32))
    freq = freq / freq.sum()
    # 辅助loss = n * sum(freq_i * prob_i)
    aux_loss = n_experts * (freq * probs.mean(dim=0)).sum()
    return aux_loss

# 训练时加入辅助loss
total_loss = task_loss + 0.01 * load_balancing_loss(gate_logits, n_experts)

坑3:FP16下Gate精度不够,路由抖动

现象:训练不稳定,loss震荡,路由在epoch之间剧烈变化。

原因:FP16的精度只有1/1024,Gate logits的微小差异(比如5.0 vs 5.1)在FP16下被放大,导致路由决策在边界处频繁翻转。

解决:Gate用FP32计算。

# 错误:Gate在FP16下计算
gate_logits = self.gate(x.half())  # 精度不够

# 正确:Gate在FP32下计算
gate_logits = self.gate(x.float()).half()  # 先FP32再转回FP16

性能对比数据

测试环境:Ascend 910 × 8,CANN 8.0,PyTorch 2.1。

配置 PyTorch手写 ops-transformer 加速比
4 Expert, Top1, 单卡 8.5ms 4.2ms 2.0x
8 Expert, Top1, 8卡 45.2ms 9.1ms 5.0x
8 Expert, Top2, 8卡 62.3ms 13.8ms 4.5x
16 Expert, Top2, 8卡 95.1ms 18.5ms 5.1x

8卡训练时加速最明显,因为AllToAll通信占比最高,融合+overlap优化的收益最大。单卡训练通信开销小,加速比只有2倍。

结尾

ops-transformer的MoE算子住在CANN五层架构第2层AOL算子库,用Expert计算+路由+通信融合+AllToAll overlap优化,把8 Expert MoE训练加速到PyTorch手写的5倍

如果在昇腾NPU上训练MoE模型,强烈建议用ops-transformer的融合MoE算子。实测下来,8卡训练一个Switch Transformer层只要9ms,PyTorch手写要45ms,省下来的时间够多训3轮epoch。

昇腾CANN的大模型算子能力还在持续增强。如果在用的过程中遇到啥问题,欢迎去AtomGit上的昇腾CANN开源社区逛逛,里面有一手资料和活跃社区。

仓库链接

https://atomgit.com/cann/ops-transformer

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐