8.1 MoE 系列:从 GShard 到 Mixtral,再到 Switch Transformer

Mixture-of-Experts(MoE)是一类“条件计算(conditional computation)”结构:模型不再对每个 token 都执行同样的全量前馈计算,而是通过一个路由器(Router / Gate)为每个 token 选择少量专家(Experts)参与计算。这样可以在总参数量大幅增长的同时,让每个 token 的计算量保持接近不变(或增长很少)。

MoE 的核心由三部分构成:

  • Router / Gate:对每个 token 计算其分配到各个 expert 的概率(或得分),并选择 top-kkk 个 expert。
  • Experts:通常是一组并行的前馈网络(FFN / SwiGLU 等),每个 expert 拥有独立参数。
  • 容量与负载均衡:由于某些 expert 可能被大量 token 选择,必须限制每个 expert 的容量(capacity)并用负载均衡损失避免路由塌缩。

下面分别以 GShard(Top-2)Mixtral(Top-2 SwiGLU Experts)、**Switch Transformer(Top-1)**为主线,详细展开。


1. GShard:Conditional Computation + Automatic Sharding 的 MoE 训练范式

1.1 Router / Gate:从 token 表示到 expert 概率分布

设一个 batch 中共有 SSS 个 token(例如把 batch size 和 seq len 展开后的 token 总数),token 的隐藏维度为 MMM,expert 数量为 EEE

  • token 表示矩阵:X∈RS×MX \in \mathbb{R}^{S \times M}XRS×M
  • Gate(线性层)权重:Wg∈RM×EW_g \in \mathbb{R}^{M \times E}WgRM×E

Gate 输出每个 token 到每个 expert 的logits,再经 softmax 得到概率:

Z=XWg∈RS×E Z = X W_g \in \mathbb{R}^{S \times E} Z=XWgRS×E

P=softmax(Z)∈RS×E P = \mathrm{softmax}(Z) \in \mathbb{R}^{S \times E} P=softmax(Z)RS×E

其中 Ps,eP_{s,e}Ps,e 表示第 sss 个 token 分配给第 eee 个 expert 的概率。

GShard 使用 Top-2 routing:对每个 token 只选择概率最大的两个 expert(记作 e1,e2e_1, e_2e1,e2),并用这两个概率作为混合权重参与输出聚合。


1.2 Top-2 的直观案例:8 个 token、4 个 expert

假设:

  • token 数 S=8S=8S=8
  • expert 数 E=4E=4E=4
  • Top-kkkk=2k=2k=2

对每个 token,Gate 会输出一个长度为 4 的概率向量,例如:

  • token0:[0.55,0.30,0.10,0.05][0.55, 0.30, 0.10, 0.05][0.55,0.30,0.10,0.05]
    Top-2 为 expert0(0.55)与 expert1(0.30)
  • token1:[0.05,0.10,0.80,0.05][0.05, 0.10, 0.80, 0.05][0.05,0.10,0.80,0.05]
    Top-2 为 expert2(0.80)与 expert1(0.10)

每个 token 计算时只会被送入两个 expert,得到两个输出,然后按权重做加权和。


2. Expert 与溢出(Overflow):为什么必须要 capacity + drop tokens

2.1 Expert buffer 与容量(capacity)的来源

如果路由是完美均匀的,那么平均每个 expert 接收的 token 数为:

SE \frac{S}{E} ES

但 Top-2 routing 会让每个 token 进入两个 expert,因此在理想均匀情况下,每个 expert 的期望接收量会乘以 KKK

SE⋅K \frac{S}{E} \cdot K ESK

为了给波动留出空间,引入 capacity factor(容量系数)以及最小容量 min_capacitymin\_capacitymin_capacity,典型形式为:

capacity=max⁡(SE⋅K⋅capacity_factor,  min_capacity) capacity = \max\left(\frac{S}{E} \cdot K \cdot capacity\_factor,\; min\_capacity\right) capacity=max(ESKcapacity_factor,min_capacity)

  • SSS:token 总数
  • EEE:expert 数
  • KKK:每 token 选择的 expert 数(GShard Top-2 则 K=2K=2K=2
  • capacity_factorcapacity\_factorcapacity_factor:>1 时允许更宽松的容量,减少溢出;<1 时更严格但更省通信与计算
  • min_capacitymin\_capacitymin_capacity:防止小 batch 时容量过小
具体案例(与你给的例子一致)
  • S=8S=8S=8E=4E=4E=4K=2K=2K=2
  • 理想均匀时,每个 expert 接收上限可先按 SE⋅K=4\frac{S}{E}\cdot K = 4ESK=4
  • 因此 expert buffer 长度可以设为 4(每个 expert 最多容纳 4 个 token)

2.2 Overflow 时怎么做:两种情况

当某个 token 的某个目标 expert 的 buffer 已满,就发生 overflow。GShard 的处理思路是:尽量让 token 至少被一个 expert 正常处理,否则就“跳过 MoE”。

情况 A:只有一个 expert 溢出(比如 2nd expert 满了)
  • token 的 1st expert 仍然可用
  • 那么直接把该 token 的权重全部给 1st expert,相当于:

若原先 top-2 权重为 (w1,w2)(w_1, w_2)(w1,w2),且 2nd expert 溢出,则变为:

(w1′,w2′)=(1,0) (w_1', w_2') = (1, 0) (w1,w2)=(1,0)

该 token 仍然通过 1st expert 的输出进入后续计算。

情况 B:两个 expert 都溢出(1st、2nd 都满)

此时 token 无法通过任何 expert 计算,GShard 采取“残差透传”:

  • 该 token 在 MoE 子层不做 FFN 变换
  • 直接把输入原样通过残差连接送到下一层(例如下一层 attention)

可以把它理解为:MoE 子层对该 token 的增量为 0。


2.3 Drop tokens 与 Zero padding

Overflow 会导致部分 token 没有被 expert 处理(或只被一个 expert 处理),这一类策略常被统称为 drop tokens(对溢出的 token 做丢弃/退化处理)。

同时,为了实现高效并行(尤其在跨设备的 all-to-all 通信与矩阵计算中),每个 expert 通常会按固定 capacity 建立 buffer:

  • 没被填满的位置用全 0 向量补齐(Zero padding)
  • 这样每个 expert 的输入张量尺寸固定,便于批处理与设备间对齐

3. Random Routing:为什么要给 2nd expert 加噪声

Top-2 的一个常见问题是:如果对 2nd expert 也完全按最大概率选,很容易出现:

  • 大量 token 的 Top-2 组合高度相似
  • 某些 expert 长期作为“热门 2nd”,造成不稳定或偏置
  • 训练早期容易路由塌缩:少数 expert 承担几乎所有 token

因此 GShard 在选择 2nd expert 时引入随机性:
1st expert 固定为最大概率 expert
2nd expert 在 logits 上加噪声后再选最大(并 mask 掉 1st expert)。

一个常见的描述流程:

  1. 得到未加噪 logits:z∈REz \in \mathbb{R}^{E}zRE
  2. 采样噪声向量 ϵ∈RE\epsilon \in \mathbb{R}^{E}ϵRE(实现上可能多次采样、取某种聚合,核心是“扰动 logits”)
  3. 得到扰动 logits:z′=z+ϵz' = z + \epsilonz=z+ϵ
  4. mask 掉 1st expert 的位置
  5. 在剩余 expert 中选择 z′z'z 最大者作为 2nd expert

重要点:最终用于混合权重的仍是原始 softmax 概率(而不是加噪后的概率),并在 top-2 上做归一化:

若选择的两个 expert 在原始分布上的概率为 P0,P1P_0, P_1P0,P1,则权重为:

w0=P0P0+P1,w1=P1P0+P1 w_0 = \frac{P_0}{P_0 + P_1}, \quad w_1 = \frac{P_1}{P_0 + P_1} w0=P0+P1P0,w1=P0+P1P1

最终 token 输出为两个 expert 输出的加权和:

y=w0⋅E0(x)+w1⋅E1(x) y = w_0 \cdot E_0(x) + w_1 \cdot E_1(x) y=w0E0(x)+w1E1(x)

其中 E0(⋅),E1(⋅)E_0(\cdot), E_1(\cdot)E0(),E1() 是被选中的两个专家网络。


4. GShard 的负载均衡损失:Auxiliary Loss 的含义与推导直觉

4.1 为什么需要负载均衡

如果 Gate 学到“总是把所有 token 送到少数几个 expert”,会带来:

  • 这些 expert 的 buffer 频繁溢出 → drop tokens 增多 → 训练信号变差
  • 其他 expert 基本不训练 → 参数浪费
  • 通信与计算分布极不均匀 → 性能差、吞吐不稳定

因此要在训练目标中加入“让路由更均匀”的正则项,即 Auxiliary Loss。


4.2 你给出的 lauxl_{aux}laux:每一项代表什么

GShard 的一种常见形式为:

laux=1E∑e=1EceS⋅me l_{aux} = \frac{1}{E} \sum_{e = 1}^{E} \frac{c_e}{S} \cdot m_e laux=E1e=1ESceme

解释每个量:

  • EEE:expert 数
  • SSS:token 总数
  • cec_ece:第 eee 个 expert buffer 中被放入的 token 数(常以“作为 1st expert 接收的 token 数”统计)
  • mem_eme:这些 token 在该 expert 上的平均权重(avg(weight))

直觉可以理解为:

  • ceS\frac{c_e}{S}Sce 是“这个 expert 接到了多少 token 的比例”
  • mem_eme 是“这些 token 对这个 expert 的平均路由强度”
  • 两者相乘:表示“该 expert 的有效负载强度”
  • 对所有 expert 平均:鼓励每个 expert 的有效负载接近均衡

5. Mixtral of Experts:Top-2 + SwiGLU Experts 的工业级 MoE

Mixtral 8x7B 是典型的稀疏 MoE:每层 FFN 由一组专家 FFN 组成,路由器对每个 token 选择 2 个专家进行计算。其关键卖点是:

  • 总参数量大:例如 8 个专家累加后总参数显著增大
  • 每 token 只激活其中 2 个专家:因此推理成本更接近“激活参数规模”,而不是总参数规模

常见的表达是:

  • total parameters 很大(例如 46.7B)
  • but only uses a subset per token(例如每 token 约 12.9B 参与计算)

这体现了 MoE 的核心优势:参数扩展与计算扩展解耦


5.1 Mixtral 的层内计算结构

一个标准 Transformer 层中:

  1. Attention 子层:自注意力 + 残差
  2. FFN 子层:在 Mixtral 中由 MoE 代替(多个 expert FFN + router)

即:把每层的 FFN 替换成 MoE 层。

Mixtral 中 expert 常采用 SwiGLU 结构,并设置 K=2K=2K=2(Top-2)。


5.2 Mixtral 的 Top-2 聚合公式(你给出的版本)

设 token 表示为 xxx,router 权重为 WgW_gWg,router logits 为 x⋅Wgx \cdot W_gxWg。对 logits 做 top-2,并在 top-2 上做 softmax 得到权重,再乘以对应 expert 的 SwiGLU 输出,最终求和:

y=∑i=0n−1Softmax(Top2(x⋅Wg))i⋅SwiGLUi(x) y = \sum_{i=0}^{n-1} \mathrm{Softmax}\left(\mathrm{Top}2(x \cdot W_g)\right)_i \cdot \mathrm{SwiGLU}_i(x) y=i=0n1Softmax(Top2(xWg))iSwiGLUi(x)

理解这条式子时,可以按以下步骤读:

  1. x⋅Wgx \cdot W_gxWg:得到对所有 expert 的打分
  2. Top2(⋅)\mathrm{Top}2(\cdot)Top2():只保留分数最高的两个 expert(其余置为 −∞-\infty 或 mask 掉)
  3. Softmax(⋅)\mathrm{Softmax}(\cdot)Softmax():只在这两个 expert 上归一化得到权重
  4. SwiGLUi(x)\mathrm{SwiGLU}_i(x)SwiGLUi(x):第 iii 个 expert 对 token 的 FFN 变换
  5. 加权求和:得到 MoE 子层输出

6. MoE 推理侧优化:以 Sliding Window Attention 为核心

MoE 本身节省的是 FFN 的计算,但推理时大模型的主要瓶颈还包括 KV cache 的显存占用与带宽压力,尤其在因果解码(causal decoder)下,注意力对历史 token 的依赖使 cache 随序列长度增长。

6.1 Sliding Window Attention:把“看到全部历史”改为“只看最近 WWW 个 token”

在标准 causal attention 中,每个 token 都可以 attend 到它之前的全部 token,因此 KV cache 的大小与序列长度 LLL 近似线性相关。

Sliding Window 的核心改动是:对位置 ttt 的 token,只允许其 attend 到区间 [t−W+1,t][t-W+1, t][tW+1,t] 的 token:

  • WWW 为窗口大小(window size)
  • cache 的有效长度被限制为 WWW,显存压力与 LLL 脱钩,转为与 WWW 成正比

直觉解释:距离越远的上下文对当前 token 的贡献往往越弱,因此限制注意力范围可以换取显存与吞吐收益。


6.2 “看不全历史”会不会丢信息:用“感受野”理解深层网络

即便单层注意力窗口有限,但模型足够深时,不同层可以逐步传播信息:

  • 第 1 层:只能把信息在局部窗口内聚合
  • 第 2 层:上一层的聚合结果又被下一层在局部窗口内再聚合
  • 多层叠加后,信息传播范围扩大,类似 CNN 的感受野随层数增长

因此 Sliding Window 往往不会像“硬截断上下文”那样简单粗暴,其影响与深度、任务类型、窗口大小有关。


6.3 Rolling Buffer Cache:用循环数组存 KV

当窗口大小固定为 WWW 时,可以用“循环下标”复用 KV cache 的存储槽位。

若 prompt 的第 iii 个 token 的 KV 被写入缓存位置:

i mod W i \bmod W imodW

当序列继续增长时,新的 token 会覆盖最旧的 token 的 cache 槽位,从而保证 cache 总大小始终为 WWW 对应的规模。这就是 Rolling Buffer Cache 的基本思想。


6.4 Chunking:把超长 prompt 切块喂入

超长 prompt 会造成显存尖峰或吞吐下降。Chunking 的做法是:

  • 将 prompt 切成若干块(chunk)
  • 每次只喂 1 个 chunk,更新一次 KV cache
  • 常见设置是让 chunk size 与 window size 对齐:

chunk_size=W chunk\_size = W chunk_size=W

这样每次更新刚好填满一个窗口,工程实现更一致,也更利于吞吐稳定。


7. Switch Transformer:Top-1(k=1)MoE 的极简化与可扩展性

Switch Transformer 基于 T5 的 encoder-decoder 架构,把每层 FFN 替换为 MoE,但它做了一个关键简化:

  • 不再 Top-2,而是 Top-1(k=1k=1k=1
  • 每个 token 只路由到 1 个 expert

这样的 MoE 层常被称为 Switch layer

7.1 为什么 Top-1 会很重要

Top-1 的直接好处:

  • 通信量更小(只发往 1 个 expert)
  • 计算更省(只算 1 个 expert)
  • 实现更简单、更容易扩展到非常多 expert(例如上千个)

Switch 的核心观点是:在保持 FLOPs/token 大致不变的情况下,通过增加 expert 数与总参数量,模型能力可以继续提升。


7.2 Switch 的容量定义:capacity factor 的取舍

Switch 中常见的 expert capacity 设定为:

expert capacity=(tokens per batchnumber of experts)×capacity factor expert\ capacity = \left(\frac{tokens\ per\ batch}{number\ of\ experts}\right) \times capacity\ factor expert capacity=(number of expertstokens per batch)×capacity factor

  • capacity factor 越大:每个 expert 能接收更多 token → overflow 更少,但计算/通信更重
  • capacity factor 越小:更省资源,但更容易 overflow → drop tokens 更多,影响训练

经验上,为了训练超大规模模型,需要把 token 丢弃率控制得足够低,而负载均衡损失能让较小 capacity factor 也能维持较低 overflow。


7.3 Switch 的负载均衡损失:fif_ifi 不可导、PiP_iPi 可导 的组合设计

设共有 NNN 个 expert,一个 batch B\mathcal{B}B 中共有 TTT 个 token。

Switch 的负载均衡损失常写为:

loss=α⋅N⋅∑i=1Nfi⋅Pi loss = \alpha \cdot N \cdot \sum_{i = 1}^{N} f_i \cdot P_i loss=αNi=1NfiPi

其中:

  • fif_ifi:实际被分配到第 iii 个 expert 的 token 比例(不可导)

fi=1T∑x∈B1{arg⁡max⁡p(x)=i} f_i = \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbf{1}\left\{\arg\max p(x) = i\right\} fi=T1xB1{argmaxp(x)=i}

  • PiP_iPi:batch 内 token 分配给 expert iii 的概率总和(可导)

Pi=1T∑x∈Bpi(x) P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x) Pi=T1xBpi(x)

这个设计为什么合理
  • fif_ifi 反映“硬路由”的真实负载,但不可导
  • PiP_iPi 反映“软概率”的期望负载,可导
  • fi⋅Pif_i \cdot P_ifiPi 把“真实负载”与“可导信号”绑定在一起,让梯度能推动概率分布向均衡方向调整
为什么乘一个 NNN

当完全均衡时:

  • fi=1Nf_i = \frac{1}{N}fi=N1
  • Pi=1NP_i = \frac{1}{N}Pi=N1

则:

∑i=1NfiPi=∑i=1N1N⋅1N=1N \sum_{i=1}^{N} f_i P_i = \sum_{i=1}^{N} \frac{1}{N}\cdot\frac{1}{N} = \frac{1}{N} i=1NfiPi=i=1NN1N1=N1

再乘一个 NNN

N⋅∑i=1NfiPi=1 N \cdot \sum_{i=1}^{N} f_i P_i = 1 Ni=1NfiPi=1

这样在“理想均衡”时,loss 会保持一个与 expert 数无关的常数量级,便于在不同规模下复用 α\alphaα 的经验范围。实践中 α\alphaα 常在 10−510^{-5}10510−110^{-1}101 扫描,发现 10−210^{-2}102 往往足以维持负载均衡且不过度干扰收敛。


8. 三者对比总结:你需要抓住的关键差异

8.1 路由粒度与稀疏度

  • GShard:Top-2(K=2K=2K=2),并对 2nd expert 引入 random routing
  • Mixtral:Top-2(K=2K=2K=2),expert 采用 SwiGLU,强调“总参大但每 token 只激活少量参数”
  • Switch:Top-1(K=1K=1K=1),极简路由,利于扩展到超多 expert

8.2 capacity / overflow 的一致性逻辑

三者都必须面对同一个工程现实:

  • 某些 expert 会变“热门”
  • capacity 不够就 overflow
  • overflow 会导致 drop tokens 或退化路径
  • 因此必须通过负载均衡损失与 capacity factor 做权衡
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐