8.1 MoE 系列:从 GShard 到 Mixtral,再到 Switch Transformer
8.1 MoE 系列:从 GShard 到 Mixtral,再到 Switch Transformer
Mixture-of-Experts(MoE)是一类“条件计算(conditional computation)”结构:模型不再对每个 token 都执行同样的全量前馈计算,而是通过一个路由器(Router / Gate)为每个 token 选择少量专家(Experts)参与计算。这样可以在总参数量大幅增长的同时,让每个 token 的计算量保持接近不变(或增长很少)。
MoE 的核心由三部分构成:
- Router / Gate:对每个 token 计算其分配到各个 expert 的概率(或得分),并选择 top-kkk 个 expert。
- Experts:通常是一组并行的前馈网络(FFN / SwiGLU 等),每个 expert 拥有独立参数。
- 容量与负载均衡:由于某些 expert 可能被大量 token 选择,必须限制每个 expert 的容量(capacity)并用负载均衡损失避免路由塌缩。
下面分别以 GShard(Top-2)、Mixtral(Top-2 SwiGLU Experts)、**Switch Transformer(Top-1)**为主线,详细展开。
1. GShard:Conditional Computation + Automatic Sharding 的 MoE 训练范式
1.1 Router / Gate:从 token 表示到 expert 概率分布
设一个 batch 中共有 SSS 个 token(例如把 batch size 和 seq len 展开后的 token 总数),token 的隐藏维度为 MMM,expert 数量为 EEE。
- token 表示矩阵:X∈RS×MX \in \mathbb{R}^{S \times M}X∈RS×M
- Gate(线性层)权重:Wg∈RM×EW_g \in \mathbb{R}^{M \times E}Wg∈RM×E
Gate 输出每个 token 到每个 expert 的logits,再经 softmax 得到概率:
Z=XWg∈RS×E Z = X W_g \in \mathbb{R}^{S \times E} Z=XWg∈RS×E
P=softmax(Z)∈RS×E P = \mathrm{softmax}(Z) \in \mathbb{R}^{S \times E} P=softmax(Z)∈RS×E
其中 Ps,eP_{s,e}Ps,e 表示第 sss 个 token 分配给第 eee 个 expert 的概率。
GShard 使用 Top-2 routing:对每个 token 只选择概率最大的两个 expert(记作 e1,e2e_1, e_2e1,e2),并用这两个概率作为混合权重参与输出聚合。
1.2 Top-2 的直观案例:8 个 token、4 个 expert
假设:
- token 数 S=8S=8S=8
- expert 数 E=4E=4E=4
- Top-kkk 中 k=2k=2k=2
对每个 token,Gate 会输出一个长度为 4 的概率向量,例如:
- token0:[0.55,0.30,0.10,0.05][0.55, 0.30, 0.10, 0.05][0.55,0.30,0.10,0.05]
Top-2 为 expert0(0.55)与 expert1(0.30) - token1:[0.05,0.10,0.80,0.05][0.05, 0.10, 0.80, 0.05][0.05,0.10,0.80,0.05]
Top-2 为 expert2(0.80)与 expert1(0.10)
每个 token 计算时只会被送入两个 expert,得到两个输出,然后按权重做加权和。
2. Expert 与溢出(Overflow):为什么必须要 capacity + drop tokens
2.1 Expert buffer 与容量(capacity)的来源
如果路由是完美均匀的,那么平均每个 expert 接收的 token 数为:
SE \frac{S}{E} ES
但 Top-2 routing 会让每个 token 进入两个 expert,因此在理想均匀情况下,每个 expert 的期望接收量会乘以 KKK:
SE⋅K \frac{S}{E} \cdot K ES⋅K
为了给波动留出空间,引入 capacity factor(容量系数)以及最小容量 min_capacitymin\_capacitymin_capacity,典型形式为:
capacity=max(SE⋅K⋅capacity_factor, min_capacity) capacity = \max\left(\frac{S}{E} \cdot K \cdot capacity\_factor,\; min\_capacity\right) capacity=max(ES⋅K⋅capacity_factor,min_capacity)
- SSS:token 总数
- EEE:expert 数
- KKK:每 token 选择的 expert 数(GShard Top-2 则 K=2K=2K=2)
- capacity_factorcapacity\_factorcapacity_factor:>1 时允许更宽松的容量,减少溢出;<1 时更严格但更省通信与计算
- min_capacitymin\_capacitymin_capacity:防止小 batch 时容量过小
具体案例(与你给的例子一致)
- S=8S=8S=8,E=4E=4E=4,K=2K=2K=2
- 理想均匀时,每个 expert 接收上限可先按 SE⋅K=4\frac{S}{E}\cdot K = 4ES⋅K=4
- 因此 expert buffer 长度可以设为 4(每个 expert 最多容纳 4 个 token)
2.2 Overflow 时怎么做:两种情况
当某个 token 的某个目标 expert 的 buffer 已满,就发生 overflow。GShard 的处理思路是:尽量让 token 至少被一个 expert 正常处理,否则就“跳过 MoE”。
情况 A:只有一个 expert 溢出(比如 2nd expert 满了)
- token 的 1st expert 仍然可用
- 那么直接把该 token 的权重全部给 1st expert,相当于:
若原先 top-2 权重为 (w1,w2)(w_1, w_2)(w1,w2),且 2nd expert 溢出,则变为:
(w1′,w2′)=(1,0) (w_1', w_2') = (1, 0) (w1′,w2′)=(1,0)
该 token 仍然通过 1st expert 的输出进入后续计算。
情况 B:两个 expert 都溢出(1st、2nd 都满)
此时 token 无法通过任何 expert 计算,GShard 采取“残差透传”:
- 该 token 在 MoE 子层不做 FFN 变换
- 直接把输入原样通过残差连接送到下一层(例如下一层 attention)
可以把它理解为:MoE 子层对该 token 的增量为 0。
2.3 Drop tokens 与 Zero padding
Overflow 会导致部分 token 没有被 expert 处理(或只被一个 expert 处理),这一类策略常被统称为 drop tokens(对溢出的 token 做丢弃/退化处理)。
同时,为了实现高效并行(尤其在跨设备的 all-to-all 通信与矩阵计算中),每个 expert 通常会按固定 capacity 建立 buffer:
- 没被填满的位置用全 0 向量补齐(Zero padding)
- 这样每个 expert 的输入张量尺寸固定,便于批处理与设备间对齐
3. Random Routing:为什么要给 2nd expert 加噪声
Top-2 的一个常见问题是:如果对 2nd expert 也完全按最大概率选,很容易出现:
- 大量 token 的 Top-2 组合高度相似
- 某些 expert 长期作为“热门 2nd”,造成不稳定或偏置
- 训练早期容易路由塌缩:少数 expert 承担几乎所有 token
因此 GShard 在选择 2nd expert 时引入随机性:
1st expert 固定为最大概率 expert;
2nd expert 在 logits 上加噪声后再选最大(并 mask 掉 1st expert)。
一个常见的描述流程:
- 得到未加噪 logits:z∈REz \in \mathbb{R}^{E}z∈RE
- 采样噪声向量 ϵ∈RE\epsilon \in \mathbb{R}^{E}ϵ∈RE(实现上可能多次采样、取某种聚合,核心是“扰动 logits”)
- 得到扰动 logits:z′=z+ϵz' = z + \epsilonz′=z+ϵ
- mask 掉 1st expert 的位置
- 在剩余 expert 中选择 z′z'z′ 最大者作为 2nd expert
重要点:最终用于混合权重的仍是原始 softmax 概率(而不是加噪后的概率),并在 top-2 上做归一化:
若选择的两个 expert 在原始分布上的概率为 P0,P1P_0, P_1P0,P1,则权重为:
w0=P0P0+P1,w1=P1P0+P1 w_0 = \frac{P_0}{P_0 + P_1}, \quad w_1 = \frac{P_1}{P_0 + P_1} w0=P0+P1P0,w1=P0+P1P1
最终 token 输出为两个 expert 输出的加权和:
y=w0⋅E0(x)+w1⋅E1(x) y = w_0 \cdot E_0(x) + w_1 \cdot E_1(x) y=w0⋅E0(x)+w1⋅E1(x)
其中 E0(⋅),E1(⋅)E_0(\cdot), E_1(\cdot)E0(⋅),E1(⋅) 是被选中的两个专家网络。
4. GShard 的负载均衡损失:Auxiliary Loss 的含义与推导直觉
4.1 为什么需要负载均衡
如果 Gate 学到“总是把所有 token 送到少数几个 expert”,会带来:
- 这些 expert 的 buffer 频繁溢出 → drop tokens 增多 → 训练信号变差
- 其他 expert 基本不训练 → 参数浪费
- 通信与计算分布极不均匀 → 性能差、吞吐不稳定
因此要在训练目标中加入“让路由更均匀”的正则项,即 Auxiliary Loss。
4.2 你给出的 lauxl_{aux}laux:每一项代表什么
GShard 的一种常见形式为:
laux=1E∑e=1EceS⋅me l_{aux} = \frac{1}{E} \sum_{e = 1}^{E} \frac{c_e}{S} \cdot m_e laux=E1e=1∑ESce⋅me
解释每个量:
- EEE:expert 数
- SSS:token 总数
- cec_ece:第 eee 个 expert buffer 中被放入的 token 数(常以“作为 1st expert 接收的 token 数”统计)
- mem_eme:这些 token 在该 expert 上的平均权重(avg(weight))
直觉可以理解为:
- ceS\frac{c_e}{S}Sce 是“这个 expert 接到了多少 token 的比例”
- mem_eme 是“这些 token 对这个 expert 的平均路由强度”
- 两者相乘:表示“该 expert 的有效负载强度”
- 对所有 expert 平均:鼓励每个 expert 的有效负载接近均衡
5. Mixtral of Experts:Top-2 + SwiGLU Experts 的工业级 MoE
Mixtral 8x7B 是典型的稀疏 MoE:每层 FFN 由一组专家 FFN 组成,路由器对每个 token 选择 2 个专家进行计算。其关键卖点是:
- 总参数量大:例如 8 个专家累加后总参数显著增大
- 每 token 只激活其中 2 个专家:因此推理成本更接近“激活参数规模”,而不是总参数规模
常见的表达是:
- total parameters 很大(例如 46.7B)
- but only uses a subset per token(例如每 token 约 12.9B 参与计算)
这体现了 MoE 的核心优势:参数扩展与计算扩展解耦。
5.1 Mixtral 的层内计算结构
一个标准 Transformer 层中:
- Attention 子层:自注意力 + 残差
- FFN 子层:在 Mixtral 中由 MoE 代替(多个 expert FFN + router)
即:把每层的 FFN 替换成 MoE 层。
Mixtral 中 expert 常采用 SwiGLU 结构,并设置 K=2K=2K=2(Top-2)。
5.2 Mixtral 的 Top-2 聚合公式(你给出的版本)
设 token 表示为 xxx,router 权重为 WgW_gWg,router logits 为 x⋅Wgx \cdot W_gx⋅Wg。对 logits 做 top-2,并在 top-2 上做 softmax 得到权重,再乘以对应 expert 的 SwiGLU 输出,最终求和:
y=∑i=0n−1Softmax(Top2(x⋅Wg))i⋅SwiGLUi(x) y = \sum_{i=0}^{n-1} \mathrm{Softmax}\left(\mathrm{Top}2(x \cdot W_g)\right)_i \cdot \mathrm{SwiGLU}_i(x) y=i=0∑n−1Softmax(Top2(x⋅Wg))i⋅SwiGLUi(x)
理解这条式子时,可以按以下步骤读:
- x⋅Wgx \cdot W_gx⋅Wg:得到对所有 expert 的打分
- Top2(⋅)\mathrm{Top}2(\cdot)Top2(⋅):只保留分数最高的两个 expert(其余置为 −∞-\infty−∞ 或 mask 掉)
- Softmax(⋅)\mathrm{Softmax}(\cdot)Softmax(⋅):只在这两个 expert 上归一化得到权重
- SwiGLUi(x)\mathrm{SwiGLU}_i(x)SwiGLUi(x):第 iii 个 expert 对 token 的 FFN 变换
- 加权求和:得到 MoE 子层输出
6. MoE 推理侧优化:以 Sliding Window Attention 为核心
MoE 本身节省的是 FFN 的计算,但推理时大模型的主要瓶颈还包括 KV cache 的显存占用与带宽压力,尤其在因果解码(causal decoder)下,注意力对历史 token 的依赖使 cache 随序列长度增长。
6.1 Sliding Window Attention:把“看到全部历史”改为“只看最近 WWW 个 token”
在标准 causal attention 中,每个 token 都可以 attend 到它之前的全部 token,因此 KV cache 的大小与序列长度 LLL 近似线性相关。
Sliding Window 的核心改动是:对位置 ttt 的 token,只允许其 attend 到区间 [t−W+1,t][t-W+1, t][t−W+1,t] 的 token:
- WWW 为窗口大小(window size)
- cache 的有效长度被限制为 WWW,显存压力与 LLL 脱钩,转为与 WWW 成正比
直觉解释:距离越远的上下文对当前 token 的贡献往往越弱,因此限制注意力范围可以换取显存与吞吐收益。
6.2 “看不全历史”会不会丢信息:用“感受野”理解深层网络
即便单层注意力窗口有限,但模型足够深时,不同层可以逐步传播信息:
- 第 1 层:只能把信息在局部窗口内聚合
- 第 2 层:上一层的聚合结果又被下一层在局部窗口内再聚合
- 多层叠加后,信息传播范围扩大,类似 CNN 的感受野随层数增长
因此 Sliding Window 往往不会像“硬截断上下文”那样简单粗暴,其影响与深度、任务类型、窗口大小有关。
6.3 Rolling Buffer Cache:用循环数组存 KV
当窗口大小固定为 WWW 时,可以用“循环下标”复用 KV cache 的存储槽位。
若 prompt 的第 iii 个 token 的 KV 被写入缓存位置:
i mod W i \bmod W imodW
当序列继续增长时,新的 token 会覆盖最旧的 token 的 cache 槽位,从而保证 cache 总大小始终为 WWW 对应的规模。这就是 Rolling Buffer Cache 的基本思想。
6.4 Chunking:把超长 prompt 切块喂入
超长 prompt 会造成显存尖峰或吞吐下降。Chunking 的做法是:
- 将 prompt 切成若干块(chunk)
- 每次只喂 1 个 chunk,更新一次 KV cache
- 常见设置是让 chunk size 与 window size 对齐:
chunk_size=W chunk\_size = W chunk_size=W
这样每次更新刚好填满一个窗口,工程实现更一致,也更利于吞吐稳定。
7. Switch Transformer:Top-1(k=1)MoE 的极简化与可扩展性
Switch Transformer 基于 T5 的 encoder-decoder 架构,把每层 FFN 替换为 MoE,但它做了一个关键简化:
- 不再 Top-2,而是 Top-1(k=1k=1k=1)
- 每个 token 只路由到 1 个 expert
这样的 MoE 层常被称为 Switch layer。
7.1 为什么 Top-1 会很重要
Top-1 的直接好处:
- 通信量更小(只发往 1 个 expert)
- 计算更省(只算 1 个 expert)
- 实现更简单、更容易扩展到非常多 expert(例如上千个)
Switch 的核心观点是:在保持 FLOPs/token 大致不变的情况下,通过增加 expert 数与总参数量,模型能力可以继续提升。
7.2 Switch 的容量定义:capacity factor 的取舍
Switch 中常见的 expert capacity 设定为:
expert capacity=(tokens per batchnumber of experts)×capacity factor expert\ capacity = \left(\frac{tokens\ per\ batch}{number\ of\ experts}\right) \times capacity\ factor expert capacity=(number of expertstokens per batch)×capacity factor
- capacity factor 越大:每个 expert 能接收更多 token → overflow 更少,但计算/通信更重
- capacity factor 越小:更省资源,但更容易 overflow → drop tokens 更多,影响训练
经验上,为了训练超大规模模型,需要把 token 丢弃率控制得足够低,而负载均衡损失能让较小 capacity factor 也能维持较低 overflow。
7.3 Switch 的负载均衡损失:fif_ifi 不可导、PiP_iPi 可导 的组合设计
设共有 NNN 个 expert,一个 batch B\mathcal{B}B 中共有 TTT 个 token。
Switch 的负载均衡损失常写为:
loss=α⋅N⋅∑i=1Nfi⋅Pi loss = \alpha \cdot N \cdot \sum_{i = 1}^{N} f_i \cdot P_i loss=α⋅N⋅i=1∑Nfi⋅Pi
其中:
- fif_ifi:实际被分配到第 iii 个 expert 的 token 比例(不可导)
fi=1T∑x∈B1{argmaxp(x)=i} f_i = \frac{1}{T} \sum_{x \in \mathcal{B}} \mathbf{1}\left\{\arg\max p(x) = i\right\} fi=T1x∈B∑1{argmaxp(x)=i}
- PiP_iPi:batch 内 token 分配给 expert iii 的概率总和(可导)
Pi=1T∑x∈Bpi(x) P_i = \frac{1}{T} \sum_{x \in \mathcal{B}} p_i(x) Pi=T1x∈B∑pi(x)
这个设计为什么合理
- fif_ifi 反映“硬路由”的真实负载,但不可导
- PiP_iPi 反映“软概率”的期望负载,可导
- 用 fi⋅Pif_i \cdot P_ifi⋅Pi 把“真实负载”与“可导信号”绑定在一起,让梯度能推动概率分布向均衡方向调整
为什么乘一个 NNN
当完全均衡时:
- fi=1Nf_i = \frac{1}{N}fi=N1
- Pi=1NP_i = \frac{1}{N}Pi=N1
则:
∑i=1NfiPi=∑i=1N1N⋅1N=1N \sum_{i=1}^{N} f_i P_i = \sum_{i=1}^{N} \frac{1}{N}\cdot\frac{1}{N} = \frac{1}{N} i=1∑NfiPi=i=1∑NN1⋅N1=N1
再乘一个 NNN:
N⋅∑i=1NfiPi=1 N \cdot \sum_{i=1}^{N} f_i P_i = 1 N⋅i=1∑NfiPi=1
这样在“理想均衡”时,loss 会保持一个与 expert 数无关的常数量级,便于在不同规模下复用 α\alphaα 的经验范围。实践中 α\alphaα 常在 10−510^{-5}10−5 到 10−110^{-1}10−1 扫描,发现 10−210^{-2}10−2 往往足以维持负载均衡且不过度干扰收敛。
8. 三者对比总结:你需要抓住的关键差异
8.1 路由粒度与稀疏度
- GShard:Top-2(K=2K=2K=2),并对 2nd expert 引入 random routing
- Mixtral:Top-2(K=2K=2K=2),expert 采用 SwiGLU,强调“总参大但每 token 只激活少量参数”
- Switch:Top-1(K=1K=1K=1),极简路由,利于扩展到超多 expert
8.2 capacity / overflow 的一致性逻辑
三者都必须面对同一个工程现实:
- 某些 expert 会变“热门”
- capacity 不够就 overflow
- overflow 会导致 drop tokens 或退化路径
- 因此必须通过负载均衡损失与 capacity factor 做权衡
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)