深入大模型架构
学习资料https://datawhalechina.github.io/base-llm/#/
一、手搓一个大模型
1、Llama2 架构总览(核心差异对比经典 Transformer)
架构类型:纯 Decoder-Only(自回归生成模型)
关键改进点(相对于原始 Transformer Decoder):
- 预归一化(Pre-Normalization):在每个子层(Attention / FFN)之前做 RMSNorm,而不是之后。
- 归一化替换:LayerNorm → RMSNorm(更高效)。
- 位置编码:绝对位置编码(加法)→ RoPE(旋转位置编码)(相对位置,长度外推能力强)。
- 注意力机制:Multi-Head Attention(MHA)→ Grouped-Query Attention(GQA)(推理加速 + 显存节省)。
- 前馈网络:ReLU FFN → SwiGLU(性能更优的门控激活)。
- 残差连接:保持不变,但配合 Pre-Norm 使用更稳定。
整体数据流: 输入 token_ids → Embedding → N × TransformerBlock(Pre-Norm + Attention + Pre-Norm + SwiGLU + 残差)→ Final RMSNorm → Linear(输出 logits)

2、核心组件详解
1. RMSNorm(预归一化)
- 作用:稳定深层网络训练,比 LayerNorm 计算更简单、更快。
- 公式:y = x * (1 / sqrt(mean(x²) + ε)) * γ(去掉均值中心化,只做尺度缩放)。
- 实现要点:
- 张量
x形状为[batch_size, seq_len, dim]。 - 只保留可学习参数 weight(γ)。
- torch.rsqrt() + mean(dim=-1, keepdim=True) 是核心。
- 张量
- 优势:计算效率高,数值稳定。
# code/C6/llama2/src/norm.py
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim)) # 对应公式中的 gamma
def _norm(self, x: torch.Tensor) -> torch.Tensor:
# 核心计算:x * (x^2的均值 + eps)的平方根的倒数
# self.eps 则是一个为了防止除以零而添加的小常数,保证了数值稳定性
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self._norm(x.float()).type_as(x)
return out * self.weight
2. RoPE(旋转位置编码)—— 最优雅的位置编码

- 核心思想:把每对维度视为复数,通过复数乘法对 Query 和 Key 进行旋转。位置信息只影响 Q 和 K 的几何方向关系,从而影响注意力分数。
- 优势:
- 相对位置编码(注意力只依赖相对距离,具有平移不变性)。
- 优秀长度外推能力(对超出训练长度的序列仍能较好地处理位置关系)。
- 计算高效(不改变向量模长,避免了额外的归一化操作)。
- 实现关键函数:
- precompute_freqs_cis():预计算旋转角度(复数形式)。
- apply_rotary_emb():将 Q 和 K 转为复数后相乘,再转回实数。
- reshape_for_broadcast():处理广播。
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
# 1. 计算频率:1 / (theta^(2i/dim))
# torch.arange(0, dim, 2) / dim 对应公式中的 2i/dim:i 实际遍历的是偶数维索引(长度为 dim/2)。
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 2. 生成位置序列 t = [0, 1, ..., end-1]
t = torch.arange(end, device=freqs.device)
# 3. 计算相位:t 和 freqs 的外积
freqs = torch.outer(t, freqs).float()
# 4. 转换为复数形式 (cos(theta) + i*sin(theta))
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
ndim = x.ndim
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# 将 Q/K 向量视为复数
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
# 准备广播
freqs_q = reshape_for_broadcast(freqs_cis, xq_) # 针对 Q 的广播视图
# 复数乘法即为旋转
xq_out = torch.view_as_real(xq_ * freqs_q).flatten(3)
# K 向量可能与 Q 向量有不同的头数(GQA),所以需单独生成广播视图
freqs_k = reshape_for_broadcast(freqs_cis, xk_)
xk_out = torch.view_as_real(xk_ * freqs_k).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xq)
- 工程技巧:初始化时预计算超过 max_seq_len 的长度(通常 ×2),支持推理时长序列。
3. Grouped-Query Attention(GQA)
- 解决的问题:标准 MHA 中 KV 缓存和计算量随头数线性增长,推理成本高。
- 核心机制:多个 Query 头 共享 同一组 Key/Value 头。
- n_heads:Query 头数
- n_kv_heads:KV 头数(通常 n_kv_heads < n_heads)
- n_rep = n_heads // n_kv_heads:重复倍数
- 关键实现:
- wq、wk、wv 的输出维度不同(体现分组)。
- repeat_kv() 函数:把 KV 头复制 n_rep 次,使形状与 Q 对齐。
- 收益:显著降低 KV Cache 显存占用和推理计算量(Llama2-70B 特别明显)。
# code/C6/llama2/src/rope.py
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
batch_size, seq_len, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
.reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
)
class GroupedQueryAttention(nn.Module):
def __init__(self, dim: int, n_heads: int, n_kv_heads: int | None = None, ...):
...
self.n_local_heads = n_heads
self.n_local_kv_heads = n_kv_heads
self.n_rep = self.n_local_heads // self.n_local_kv_heads # Q头与KV头的重复比
...
self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
...
def forward(self, x, start_pos, freqs_cis, mask):
xq = self.wq(x).view(batch_size, seq_len, self.n_local_heads, self.head_dim)
xk = self.wk(x).view(batch_size, seq_len, self.n_local_kv_heads, self.head_dim)
xv = self.wv(x).view(batch_size, seq_len, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
# ... KV Cache 逻辑 ...
keys = repeat_kv(keys, self.n_rep) # <-- 关键步骤
values = repeat_kv(values, self.n_rep) # <-- 关键步骤
scores = torch.matmul(xq.transpose(1, 2), keys.transpose(1, 2).transpose(2, 3)) / ...
...
为什么用同一个 K、V 就能达到和多头注意力类似的效果?
这是 GQA 最核心的“智慧”所在,有以下几点原因:
(1)注意力头之间存在天然冗余(Redundancy)
- 在训练好的大模型中,不同注意力头虽然学习不同“子空间”(subspace),但很多头关注的模式其实高度相似。
- 同一组内的 Query 头会自适应地学习去关注相似的模式(因为它们共享相同的 Key/Value 信息源)。
- 实验观察:训练后查看注意力图(attention maps),同一组内的 Query 头往往关注非常相似的 token 组合。
(2)不同组提供足够的多样性
- GQA 不是全部共享(MQA),而是分成多个组(通常 4~8 组)。
- 不同组拥有独立的 KV 投影,能捕捉不同方面的信息(语法、语义、长距离依赖等)。
- 这保留了 MHA “多头并行关注不同表示子空间”的核心优势,只是把头数从 H 降低到 G。
(3)Query 头的灵活性补偿了 KV 的共享
- 每个 Query 头仍有独立的 Wq 投影矩阵。
- Query 可以“主动适应”共享的 KV,学习不同的注意力权重分布。
- 即使 K/V 相同,不同 Query 也可以通过 softmax 后的加权求和,产生差异化的输出。
(4)训练过程让模型自动适应
- GQA 通常在从 MHA 转换(uptraining)时,通过 mean-pooling 原有头来初始化共享 KV,然后继续少量训练。
- 模型会快速学会如何利用共享的 KV,同时保持整体表现。
4. SwiGLU 前馈网络
- 结构:三个线性层 + 门控机制。
- xW → Swish 激活(
swish(x) = x * sigmoid(x)) - xV → 作为门(Gate),与前一步的结果进行逐元素相乘
- 通过第三个变换
W2输出
- xW → Swish 激活(
- SwiGLU(x,W,V,W2)=(swish(xW)⊗xV)W2
- 优势:门控机制让网络能动态控制信息流,性能优于传统 ReLU FFN。
- hidden_dim 处理:使用 multiple_of 对齐,提高硬件利用率。
# code/C6/llama2/src/ffn.py
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ...):
super().__init__()
# hidden_dim 计算,并用 multiple_of 对齐以提高硬件效率
hidden_dim = int(2 * hidden_dim / 3)
...
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = nn.Linear(dim, hidden_dim, bias=False) # 对应 W
self.w2 = nn.Linear(hidden_dim, dim, bias=False) # 对应 W2
self.w3 = nn.Linear(dim, hidden_dim, bias=False) # 对应 V
def forward(self, x: torch.Tensor) -> torch.Tensor:
# F.silu(self.w1(x)) 实现了 swish(xW)
# * self.w3(x) 实现了门控机制
return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
3、模型组装结构
TransformerBlock(单层):
h = x + Attention(RMSNorm(x))
out = h + SwiGLU(RMSNorm(h))
LlamaTransformer(整体):
- tok_embeddings:词嵌入
- layers:nn.ModuleList 堆叠 N 个 TransformerBlock
- norm:最终 RMSNorm
- output:线性层投射到 vocab_size(权重不共享)
- freqs_cis:注册为 buffer,预计算 RoPE
前向传播关键点:
- 根据 start_pos 切片 freqs_cis(支持 KV Cache 推理)
- 构建因果掩码(Causal Mask),并处理 start_pos 偏移
- 逐层调用 TransformerBlock
# code/C6/llama2/src/transformer.py
class LlamaTransformer(nn.Module):
def __init__(self, vocab_size: int, ...):
...
self.tok_embeddings = nn.Embedding(vocab_size, dim)
self.layers = nn.ModuleList([TransformerBlock(...) for i in range(n_layers)])
self.norm = RMSNorm(dim, eps=norm_eps)
self.output = nn.Linear(dim, vocab_size, bias=False)
self.register_buffer("freqs_cis", precompute_freqs_cis(...))
def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
h = self.tok_embeddings(tokens)
# 1. 准备 RoPE 旋转矩阵
freqs_cis = self.freqs_cis[start_pos : start_pos + seq_len]
# 2. 准备因果掩码 (Causal Mask)
mask = None
if seq_len > 1:
mask = torch.full((seq_len, seq_len), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
# 考虑 KV Cache 的偏移
mask = torch.hstack([torch.zeros((seq_len, start_pos), ...), mask]).type_as(h)
# 3. 循环通过所有 TransformerBlock
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
logits = self.output(h).float()
return logits
4、学习收获与核心知识点一览
最重要概念:
- Pre-Norm + RMSNorm:现代大模型训练稳定的关键。
- RoPE:目前最受欢迎的位置编码方案(Llama、Mistral、Qwen 等都在用)。
- GQA:大模型推理加速的核心技术(从 MHA → GQA/MQA 的演进)。
- SwiGLU:比 ReLU 更好的 FFN 设计。
工程思维:
- 组件解耦(norm、rope、attention、ffn 独立文件)。
- 预计算(freqs_cis)。
- 形状变换与广播技巧(view、reshape、repeat_kv)。
- 支持推理优化(KV Cache + start_pos)。
与经典 Transformer 的对比总结:
| 组件 | 经典 Transformer | Llama2 | 改进效果 |
|---|---|---|---|
| 归一化 | Post-LayerNorm | Pre-RMSNorm | 训练更稳定 |
| 位置编码 | 加法(绝对) | RoPE(相对) | 长度外推更好 |
| 注意力 | MHA | GQA | 推理更快、更省显存 |
| 前馈网络 | ReLU FFN | SwiGLU | 性能提升 |
二、MoE 架构解析
1、MoE 的核心概念与优势
稠密模型 vs MoE 模型
- 稠密模型(Llama2、GPT-3 等):每个 token 激活所有参数,计算量随参数量线性增长。
- MoE 模型:引入稀疏激活(Sparse Activation)机制,每个 token 只激活少量专家(通常 Top-2),实现参数量巨大,但活跃参数(Active Parameters)很少。
- 优势:知识容量大 + 推理成本低(推理速度接近小模型,知识容量接近大模型)。
核心机制:
- 专家网络(Experts):多个独立的 FFN(通常是 SwiGLU)。每个专家不再需要处理全局任务,只需专注于输入空间中的一个局部区域或一类特定的子任务。
- 门控网络(Router / Gate):为每个 token 动态计算路由分数,决定激活哪些专家。它接收与专家相同的输入 x,并输出一组混合比例(Mixing Proportions)pi,即选择每个专家的概率。
- 稀疏路由:每个 token 只选择 Top-k 个专家(常用 k=2),其余专家不参与计算。
2、MoE 的历史发展与关键里程碑(核心知识点)
1. 起源(1991):Adaptive Mixture of Local Experts
- 提出 分治思想(Divide and Conquer),解决多任务学习的强干扰效应。
- 引入 门控网络 + 竞争机制(使用负对数似然损失,而非简单均方误差)。
- 核心:鼓励“赢家通吃”,实现专家专业化,避免所有专家相互干扰。
2. 2013:Deep Mixture of Experts (DMoE)
- 将 MoE 模块化 并嵌入深度网络的多层。
- 层级化门控:实现指数级专家组合路径。
- 自动学习不同层级的特征解耦(位置专家、类别专家等)。
3. 2017:Sparsely-Gated MoE(Google Brain)
- 首次将 MoE 扩展到超大规模(137B 参数)。
- 提出 条件计算(Conditional Computation):大幅增加参数量,但计算量几乎不变。
- 关键技术:
- 带噪声的 Top-k 稀疏门控(只激活 k 个专家)。
- 辅助损失(Auxiliary Loss):Importance Loss + Load Loss,解决专家崩塌(部分专家饿死)问题。
4. 大模型时代关键工作
| 模型/工作 | 年份 | 参数规模 | 关键创新 | 路由策略 | 架构特点 |
|---|---|---|---|---|---|
| GShard | 2020 | 600B | 分布式 MoE 并行框架 | Top-2 | 隔层替换 FFN,专家分片 |
| Switch Transformer | 2021 | 1.6T | Top-1 Routing + Token Dropping | Top-1 | 极致简化,专家容量机制 |
| GLaM | 2021 | 1.2T | Decoder-only MoE | Top-2 | 活跃参数仅 8% |
| Mistral 8x7B | 2023 | 47B(活跃13B) | 开源实用 MoE | Top-2 | 8 专家,性能超 Llama2-70B |
| DeepSeekMoE | 2024 | - | 细粒度专家 + 共享专家 | Top-K + Shared | 显著提升参数效率 |
GShard:隔层替换 FFN,专家分片 + All-to-All 通信。
- 保留 Attention:Transformer 的 Self-Attention 层保持不变,因为其参数量相对较小且计算关键。
- 替换 FFN:将 Transformer Block 中的前馈神经网络替换为 MoE 层。
- 隔层设置:通常采用“隔层替换”的策略(例如第 1、3、5 层使用 MoE,第 2、4、6 层保留标准 FFN),在增加容量和保持稳定性之间取得平衡。
对于 MoE 层的计算,GShard 明确了输入 Token x 的输出 y 是由门控网络 G 选择的专家输出的加权和:

其中:
- pi(x) 是门控网络(Router)计算出的第 i 个专家的权重(通常是 Softmax 后的 Top-k 概率,其余为 0)。
- Ei(x) 是第 i 个专家网络(Expert FFN)对输入 x 的处理结果。
GShard 创造性地结合了数据并行(Data Parallelism)与模型并行(Model Parallelism),解决了超大模型的存储与通信难题。
- 非 MoE 层(如 Attention):采用复制(Replicated)策略。所有设备持有相同的副本,进行标准的数据并行训练。
- MoE 层:采用分片(Sharded)策略。专家网络被切分并分布在不同设备上(例如 2048 个专家分布在 2048 个 TPU 核上)。
当一个 Token 需要被路由到不在当前设备的专家时,系统会通过高效的 All-to-All 通信原语,将该 Token 发送到目标设备。计算完成后,再将结果传回。

Switch Transformer:Top-1 Routing + Capacity Factor + Token Dropping + 辅助损失。
它用稀疏的 Switch FFN 层(浅蓝色区域)替换了标准 Transformer 中的稠密 FFN 层。在该层中,对于输入序列中的每个 Token(例如图中的 "More" 和 "Parameters"),路由器(Router)会计算其路由概率,并将其分发给唯一的一个专家(实线箭头)进行处理。
这种 Top-1 Routing(单专家路由)机制是 Switch Transformer 与传统 MoE(通常路由给 Top-k 个专家,k > 1)最大的区别。尽管看似激进,但它带来了显著的优势:
- 减少路由计算:路由决策更简单。
- 降低通信成本:每个 Token 只需发送到一个目的地。
- 减小专家批量:每个专家需要处理的 Token 数量(Expert Capacity)至少减半。
虽然直觉上 k=1 可能限制了专家的协作,但实验证明这种简化不仅保持了模型质量,还显著提高了计算效率。

Switch Transformer 必须解决动态路由带来的负载不均问题。由于硬件通常要求静态的 Tensor 形状,模型必须预设每个专家能处理的最大 Token 数量,即专家容量(Expert Capacity): Capacity=(TotalTokensNumExperts)×CapacityFactor
- Capacity Factor(容量因子):通常设置为大于 1.0(如 1.0 或 1.25),这一机制的作用如图 6-7 所示。图中每个方块代表专家的处理槽位,Capacity Factor > 1.0 为专家提供了额外的缓冲空间(图中白色空槽位),以应对 Token 分配不均的情况。
- Token Dropping(丢弃机制):当路由到某个专家的 Token 数量超过其容量上限(即公式计算出的 CapacityCapacity)时(图中红色虚线所示的溢出部分),就会触发丢弃机制。这些多余的 Token 将不会被该专家处理,而是直接通过残差连接传递到下一层。这虽然保证了并行计算的静态形状要求,但也可能导致信息损失,所以合理的容量设置至关重要。

引入了一个辅助损失函数来尽量减少 Token 的丢弃,鼓励 Token 均匀分布到所有专家:
Loss_aux=α⋅N⋅∑i=1~N fi⋅Pi
- fi 是实际分发给专家 i 的 Token 比例(实际上有多少人去了专家 i 那里)。
- Pi 是预期路由给专家 i 的概率总和(门控网络觉得专家 i 应该接收多少人)。
GLaM:证明 MoE 在 Decoder-only 架构上同样高效,推理 FLOPs 仅为 GPT-3 的一半。
GLaM 展示了如何将 MoE 层有效地应用于 Decoder-only 的语言模型中。如图 6-8 它采用了隔层替换策略,即在标准的 Transformer 堆叠中,每隔一个层(upper block)将其中的 FFN 替换为 MoE 层(bottom block)。
在 MoE 层中,Gating 模块会根据输入 Token(例如 "roses")的特性,从 64 个专家中动态选择出最相关的 2 个专家(蓝色网格所示)。随后,这两个专家的输出经过加权平均后,传递给下一层的 Transformer 模块。这种机制确保了模型在拥有巨大参数量的同时,每次推理仅需激活极少部分的参数。
- 隔层稀疏:类似于 GShard,GLaM 采用隔层替换策略,将每隔一个 Transformer 层中的前馈网络(FFN)替换为 MoE 层。
- Top-2 路由:每个 MoE 层包含 64 个专家,对于每个输入 Token,门控网络会选择权重最高的 2 个专家进行处理。
- 活跃参数:尽管总参数量高达 1.2T,但对于每个 Token,仅激活 966 亿(96.6B) 参数(约占总量的 8%)。这意味着在推理时,GLaM 的计算量(FLOPs)仅为 GPT-3(175B 全激活)的约一半。

Mistral 8x7B:开源标杆,47B 总参数但仅激活 13B,性能超越 Llama2-70B。
-
架构参数:它拥有约 470 亿(47B) 的总参数量(Sparse Parameters),但对于每个 Token,仅激活 130 亿(13B) 参数(Active Parameters)。这使得它在推理时拥有 13B 模型的计算速度,却能发挥出 47B 模型的知识容量。需要注意的是,虽然计算量较小,但由于所有专家参数都需要加载到内存中,其显存占用(VRAM Usage)依然是 47B 模型级别的。
-
路由机制:每一层包含 8 个专家(Experts),采用标准的 Top-2 Routing 策略。如图 6-10 所示,每个输入 Token 会被 Router 网络分配给 8 个专家中的 2 个,这两个专家的输出经过加权求和后作为该层的最终输出。这种机制巧妙地在增加模型容量(更多专家)的同时,保持了极低的推理成本(只激活 2 个)。

专家并没有按预想的那样根据“学科领域”(如生物、数学、哲学)进行分工。
他们统计了不同领域数据(如 arXiv, PubMed, Wikipedia 等)在不同层(Layer 0, 15, 31)的专家分配比例。同一行(即同一个专家)在不同列(不同数据集)上的颜色深浅非常接近。这说明,无论输入文本属于哪个领域,Router 选择各专家的概率分布几乎是一样的。专家似乎更多地是根据语法和Token 结构(如缩进、介词)来分工,而非人类定义的知识领域。
DeepSeekMoE:细粒度专家分割 + 共享专家隔离(Shared Expert 始终激活),减少知识冗余,提升专业化程度。
DeepSeek 在 MoE 架构上进行了更深度的创新,提出了 DeepSeekMoE 10 架构,目标是解决传统 Top-k 路由中的“知识冗余”和“专业化不足”问题。

-
细粒度专家分割(Fine-Grained Expert Segmentation): DeepSeek 将一个标准的大专家拆分为多个更小的专家。对比图 6-13 的 (a) 和 (b) 可以看到,原本的专家 1 被进一步拆分为更小的专家 1 和 2。为了保持总计算量不变,激活的专家数量 K 也相应倍增(从 K=2 变为 K=4)。这种变化使得组合的可能性呈指数级增加,让模型能更灵活地组合不同的“知识碎片”来应对复杂输入,从而实现了更高的专家专业化。
-
共享专家隔离(Shared Expert Isolation): 这是 DeepSeekMoE 的核心创新。如图 6-13(c) 所示,专家 1 被指定为绿色的共享专家(Shared Expert)。它不再经过 Router 选择,而是通过一条独立的通路直接接收输入(Input Hidden),对所有 Token 总是被激活。Router 仅负责从剩余的路由专家中选择 K=3 个进行补充。
这种设计让共享专家负责捕获通用的、跨任务的知识(如语法),而路由专家则专注于特定的领域知识。通过这种“通用+专用”的分离,有效减少了路由专家中重复学习通用知识的冗余,显著提升了参数效率。

其中 S 代表共享专家集合(总是被激活),R 代表路由专家集合(仅选择性激活)。这种双路径结构是其区别于传统 MoE的关键。
3、MoE 的技术挑战与解决方案
- 专家崩塌(Expert Collapse) → 解决方案:辅助损失(Auxiliary Loss) + 带噪声的门控 + 负载均衡。
- 负载不均衡 → Capacity Factor(>1.0) + Token Dropping + Load Loss。
- 通信开销(分布式训练) → All-to-All 通信 + 专家分片策略。
- 训练不稳定 → Router z-loss、选择性精度、更小的初始化方差。
4、现代 MoE 核心设计要点
- 路由策略:Top-2 最常用(Mistral、GLaM);Top-1(Switch);细粒度 + Shared(DeepSeek)。
- 专家替换位置:通常替换 Transformer Block 中的 Feed-Forward Network (FFN)。
- 隔层替换:并非每层都用 MoE,常用隔层替换以平衡稳定性和容量。
- 活跃参数 vs 总参数:总参数巨大,活跃参数很少(Mistral 8x7B:47B → 13B)。
- 显存占用:推理时仍需加载所有专家参数,VRAM 接近总参数量。
5、代码实战核心知识点(Llama2 + MoE 实现)

实现思路:
- 保留原有 Attention、RMSNorm、RoPE 等组件。
- 将 FeedForward 替换为 MoE 类。
- MoE 类包含:
- gate:线性层(dim → num_experts),计算路由 logits。
- experts:nn.ModuleList 存放多个独立的 FeedForward(SwiGLU)。
- forward 流程(清晰但非最优):
- Gate 计算 logits。
- Top-k 选择专家 + Softmax 归一化权重。
- 对每个专家:收集分配给它的 tokens → 专家计算 → 加权 → index_add_ 累加。
关键超参数:
- num_experts = 8(Mistral 经典配置)
- top_k = 2
工程特点:
- 接口完全兼容原有 LlamaTransformer(输入输出形状不变)。
- 代码修改量极小(只需替换 FFN 层)。
# code/C6/MoE/src/ffn.py
# ... (保留原有的 FeedForward 类)
class MoE(nn.Module):
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float], num_experts: int = 8, top_k: int = 2):
super().__init__()
self.num_experts = num_experts
self.top_k = top_k
# 门控网络:决定每个 Token 去往哪个专家
# 通过 self.gate(x_flat) 计算每个 Token 对所有 8 个专家的打分(Logits)。
self.gate = nn.Linear(dim, num_experts, bias=False)
# 专家列表:创建 num_experts 个独立的 FeedForward 网络
self.experts = nn.ModuleList([
FeedForward(dim, hidden_dim, multiple_of, ffn_dim_multiplier)
for _ in range(num_experts)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (batch_size, seq_len, dim)
B, T, D = x.shape
x_flat = x.view(-1, D)
# 1. 门控网络
gate_logits = self.gate(x_flat) # (B*T, num_experts)
# 2. Top-k 路由
# 使用 torch.topk 选出每个 Token 分数最高的 k=2 个专家及其索引。并通过 Softmax 对这 k 个权重进行归一化,确保它们的和为 1。
weights, indices = torch.topk(gate_logits, self.top_k, dim=-1)
weights = F.softmax(weights, dim=-1) # 归一化权重
output = torch.zeros_like(x_flat)
# Dispatch(分发与计算): 这是 MoE 的核心。我们遍历每一个专家:
#通过 torch.where 找出所有被分配给当前专家的 Token 索引;
#将这些 Token 挑选出来(Index Select),送入对应的 expert 网络(即一个 SwiGLU FFN)进行计算。
for i, expert in enumerate(self.experts):
# 3. 找出所有选中当前专家 i 的 token 索引
batch_idx, k_idx = torch.where(indices == i)
if len(batch_idx) == 0:
continue
# 4. 取出对应的输入进行计算
expert_input = x_flat[batch_idx]
expert_out = expert(expert_input)
# 5. 获取对应的权重
expert_weights = weights[batch_idx, k_idx].unsqueeze(-1) # (num_selected, 1)
# 6. 将结果加权累加回输出张量
output.index_add_(0, batch_idx, expert_out * expert_weights)
return output.view(B, T, D)
6、学习总结与核心知识点
最重要概念:
- 稀疏激活:MoE 的本质——用海量参数换取极低的活跃计算量。
- Router(门控网络) + Experts:MoE 的两大核心组件。
- Top-k Routing:目前最主流的路由策略(尤其是 Top-2)。
- 负载均衡:MoE 训练中最关键的挑战,必须通过辅助损失解决。
- 共享专家(DeepSeek):现代 MoE 的重要创新方向。
MoE vs 稠密模型对比:
- 参数容量:MoE ≫ 稠密
- 推理计算量:MoE ≈ 小模型
- 训练难度:MoE > 稠密(需处理负载均衡、通信等)
- 显存占用:MoE ≈ 总参数量(所有专家仍需加载)
实际意义: MoE 是目前实现“更大、更强、更省”的最有效技术路径之一,已成为开源大模型的重要发展方向(Mistral、DeepSeek、Qwen2.5-MoE 等)。
三、手撕大模型生成策略
通过调试 Hugging Face Transformers 库中的 pipeline("text-generation") 和 model.generate(),彻底搞清楚从输入 Prompt 到输出文本的完整数据流,理解大模型“逐 token 生成”的底层工作原理,以及各种解码策略(Decoding Strategies)的实现机制和实际效果。
1、生成流程整体概览(5 个核心阶段)
一次完整的文本生成任务可以拆解为以下 5 个接力阶段:
- 预处理(Preprocess)
- tokenizer(prompt) → input_ids + attention_mask
- 输出形状示例:(batch_size=1, seq_len=4)
- 生成入口(model.generate())
- Pipeline 的 _forward() 调用 model.generate()
- 合并用户参数与 GenerationConfig 默认配置(max_new_tokens、temperature、do_sample 等)
- 策略选择与分发
- 根据 GenerationConfig 判断生成模式(GenerationMode)
- 常见模式:
- GREEDY_SEARCH(do_sample=False, num_beams=1)
- SAMPLE(do_sample=True, num_beams=1)← 最常用
- BEAM_SEARCH(do_sample=False, num_beams>1)
- BEAM_SAMPLE(do_sample=True, num_beams>1)
- 解码循环(Decoding Loop) —— 核心中的核心
- while 循环(直到满足停止条件)
- 每一步流程:
- prepare_inputs_for_generation()(处理 KV Cache)
- 模型 forward() → logits[:, -1, :](只取最后一个位置)
- logits_processor(...)(规则修正 logits)
- 选择 next_token(argmax 或 multinomial 采样)
- torch.cat() 将新 token 拼接到 input_ids
- stopping_criteria() 判断是否停止
- 后处理(Postprocess)
- tokenizer.decode() 将 token ids 转回文本
- 根据 return_type 处理 FULL_TEXT(带 prompt)或 NEW_TEXT(仅新增内容)
- 使用 prompt_length 裁剪 prompt 部分
2、关键技术组件详解(核心知识点)
1. KV Cache(past_key_values)
- 极大加速逐 token 生成的关键。
- 每生成一个新 token,只需计算新 token 的 Q/K/V,其余历史 K/V 直接复用。
- 在调试中看到 DynamicCache + 多层 DynamicLayer。
2. LogitsProcessorList(logits 规则链)
LogitsProcessorList 就是大模型生成时的“规则引擎”,它在每一步生成之前,对模型输出的原始 logits 进行各种加工、惩罚、过滤和扭曲,最终决定“模型这一步应该从哪些词里选下一个词,以及怎么选”。
- _get_logits_processor() 的核心作用:把生成参数翻译成一系列处理器(Processors / Warpers)。
- 常见 Warpers(按执行顺序):
- TemperatureLogitsWarper:logits / temperature(控制随机性)
- TopKLogitsWarper:只保留 top-k,其余置 -inf
- TopPLogitsWarper(Nucleus Sampling):累计概率 ≥ p 的最小集合
- TypicalLogitsWarper、EpsilonLogitsWarper、EtaLogitsWarper 等
温度(Temperature)作用:
- T < 1:分布变尖 → 更确定、保守
- T = 1:原始分布
- T > 1:分布变平 → 更多样性、随机
top_k vs top_p:
- top_k:固定候选数量(更稳定,但可能重复)
- top_p:动态候选集合(根据分布尖锐程度自适应,更常用)
3. 常用解码策略对比(重点掌握)
| 策略 | do_sample | num_beams | 特点 | 适用场景 | 优缺点 |
|---|---|---|---|---|---|
| Greedy | False | 1 | 每步选概率最大的 token | 需要强确定性任务 | 快、稳定,但易重复、局部最优 |
| Sampling | True | 1 | 从(处理后的)概率分布中随机采样 | 开放式生成、对话、创意写作 | 多样性好,是 LLM 最常用策略 |
| Beam Search | False | >1 | 维护多条候选路径,序列级优化 | 翻译、摘要等需要精确性的任务 | 更“全局”但保守、计算开销大 |
| Beam Sample | True | >1 | Beam 框架 + 采样随机性 | 需要一定多样性的搜索任务 | 折中方案,实际使用较少 |
实际推荐(大模型对话生成):
- 默认使用 Sampling + temperature + top_p(最平衡)
- Greedy 适合 baseline 测试
- Beam Search 适合机器翻译、摘要等结构化任务
4. Stopping Criteria(停止条件)
- 达到 max_new_tokens 或 max_length
- 遇到 eos_token_id
- 自定义停止字符串等
3、调试技巧总结(实用干货)
- 化繁为简:从 pipeline(...) 和 generator(prompt) 开始下断点。
- 善用 Ctrl + B(转到定义)回溯调用链。
- 关注关键变量:
- input_ids、attention_mask
- past_key_values(KV Cache)
- logits 形状(尤其是 [:, -1, :])
- prepared_logits_processor(规则链)
- generation_mode
- 步入 vs 步过:重要函数步入,一般胶水代码步过。
- 断点位置技巧:变量赋值行通常在“执行前”停住,需单步一次才能看到最终值。
4、整体调用链总结(一句话版)
Pipeline → preprocess()(tokenizer) → model.generate() → GenerationConfig + 策略选择 → _sample() / _greedy_search() 等解码循环 (prepare → forward → logits_processor → 采样/贪心 → cat → stopping_criteria) → postprocess()(decode + FULL/NEW 处理)
5、学习收获与核心知识点一览
最重要概念:
- 逐 token 自回归生成:大模型生成本质是循环调用模型,每次只预测下一个 token。
- KV Cache:推理加速的核心技术。
- LogitsProcessorList:所有生成控制参数(temperature, top_p, top_k 等)的实际执行者。
- Sampling 是主流:开放式生成中最常用、最灵活的策略。
- Greedy vs Sampling:确定性 vs 多样性 的权衡。
工程认知:
- model.generate() 封装了大量复杂逻辑,调试是理解它的最佳方式。
- 真正影响生成质量的往往不是模型架构,而是 生成参数 + logits 处理规则。
- Transformers 库的生成模块设计非常模块化(Processor、Warper、StoppingCriteria 等),易于扩展。
四、上下文学习与提示词技术
理解大语言模型在不更新任何权重的情况下,如何通过提示词(Prompt)实现“上下文学习”(In-Context Learning, ICL),掌握从 Zero-shot 到 Few-shot 的基本机制,以及进阶提示词技术(思维链、思维树、专用推理模型)的原理、效果与局限性。
1、上下文学习(In-Context Learning)核心概念
定义: 在推理阶段,仅通过在提示词中提供任务描述 + 少量示例(或无示例),让模型直接完成新任务,而不进行任何梯度更新或参数微调。
1. Zero-Shot Learning(零样本学习)
- 特点:只给出任务指令 + 输入,不提供任何输出示例。
- 依赖:
- 预训练阶段积累的广博知识与共享语义空间
- 对齐微调后形成的指令遵循能力(Instruction Following)
- 机制:模型将新任务映射到已学过的语义概念上,通过自回归生成直接输出结果。
- 优势:简单、快速
- 局限:对输出格式控制弱,复杂或反直觉任务效果不稳定。
示例:
阅读以下评论,并判断一下评论表达的情感:
评论:这部电影的剧情有些拖沓。
情感:
2. Few-Shot Learning(少样本学习)
- 特点:在提示词中提供若干输入-输出示例(示范),帮助模型理解任务格式和隐含规则。
- One-Shot:只给 1 个示例(Few-Shot 的特例)。
- 效果:显著提升输出格式一致性、稳定性,以及对非标准规则的遵循能力。
为什么 Few-Shot 比 Zero-Shot 强?
- 多个一致性示例在上下文窗口中构建了一个临时的任务概率分布。
- 模型通过注意力机制提炼并模仿示例中的模式,并将其泛化到当前输入。
- 能有效对抗模型自身的语言先验和对话习惯。
反直觉规则示例(证明 Few-Shot 的强大):
- 要求情感词倒序输出(正面 → 面正,负面 → 面负)。
- 1 个示例往往不够,3–4 个示例才能让模型确信这是一条必须遵循的规则。
2、上下文学习的内在机制(核心理论解释)
- 感应头机制(Induction Heads)
- Transformer 中存在特殊的注意力头,行为模式是“匹配并复制”。
- 当发现当前输入在上下文中出现过时,会把注意力投向前一个匹配项,并倾向于复制其后面的 token。
- 这是 Few-Shot 中模式匹配的微观基础。
- 隐式学习动力学(Implicit Weight Update)
- 前向计算过程本身可近似看作一种“瞬时适配”。
- 上下文中的示例信息通过自注意力写入激活,再经 MLP 处理,产生类似低秩适配(LoRA)的瞬时效果,影响后续 logits。
- (近似)贝叶斯视角
- 将上下文学习视为在上下文中对“潜在任务”进行后验推断的过程。
- Zero-Shot 更依赖先验,Few-Shot 通过示例更新后验。
总结:上下文学习是大模型在预训练中习得的元学习能力(learning to learn),通过注意力机制和激活传播实现动态模式匹配。
3、进阶提示词技术
1. 思维链(Chain-of-Thought, CoT)
- 核心思想:让模型显式输出中间推理步骤,而不是直接给出答案。
- 作用:
- 增加“思考时间”,把复杂问题拆解为多个简单逻辑步骤。
- 缓解“直觉短路”和局部最优问题。
- 演进路径:
- Few-Shot CoT:手动编写详细推理示例(最早、最稳)
- Zero-Shot CoT:在提示末尾加 “Let's think step by step”(最简便)
- 注意事项:
- 触发词效果与语言强相关(英文 “Let's think step by step” 往往优于中文直译)。
- 效果高度依赖模型的预训练数据分布(逻辑严密的数据越多越好)。
2. 思维树(Tree of Thoughts, ToT)
- 核心思想:将线性 CoT 扩展为树状多路径探索。
- 关键环节:
- 生成多个候选下一步(分支)
- 模型自我评估每个分支的潜力(打分)
- 使用搜索算法(BFS/DFS)选择最优路径 + 回溯机制
- 适用任务:需要全局规划、容易陷入死胡同的问题(如 24 点、填字游戏、复杂规划)。
- 优缺点:
- 优势:搜索空间更大,更接近系统性求解。
- 缺点:计算成本高、自我评估易偏差、搜索空间易爆炸。
- 工程实现:可用轻量 Python 脚本 + BFS/DFS + API 调用实现,无需复杂 Agent 框架。
3. 专用推理模型(Reasoning Models)
专用推理模型就是把“思考”这件事从提示词层面升级到了模型权重层面,通过强化学习让模型学会主动进行长时间、多步骤、自我纠错的深度推理,从而在高难度推理任务上实现质的飞跃。
- 代表:OpenAI o1 系列、DeepSeek-R1 等。
- 训练范式转变:
- 不再仅靠监督微调 + 人类偏好对齐。
- 大量使用强化学习(RL),以正确答案为奖励信号,鼓励模型延长推导链条、自我校验、纠错。
- 外部表现:
- 自动进行长链条思考(耗时数秒至数分钟)。
- 输出前有明显的“思考过程”。
- 工程建议:
- 使用路由机制(Router):简单任务用基础模型,复杂推理任务路由给推理模型,平衡质量与成本。
4、核心知识点一览(重点复习)
| 技术 | 示例数量 | 是否需要中间步骤 | 适用场景 | 主要机制 | 局限性 |
|---|---|---|---|---|---|
| Zero-Shot | 0 | 否 | 简单指令任务 | 指令遵循 + 先验知识 | 输出格式不稳定 |
| Few-Shot | 1~N | 否 | 需要特定格式或规则的任务 | 上下文模式匹配 + 感应头 | 示例过多会占用上下文 |
| Chain-of-Thought | 0~N | 是 | 逻辑推理、数学、规划 | 显式分步思考 | 触发词语言敏感 |
| Tree-of-Thoughts | - | 是(多路径) | 复杂规划、搜索类任务 | 多分支探索 + 自我评估 + 回溯 | 计算成本高 |
| Reasoning Models | - | 是(自动长链) | 高难度推理任务 | RL 强化 + 长链条自省 | 推理延迟大、成本高 |
最重要原理:
- 上下文学习本质是利用注意力机制在激活空间中进行动态模式匹配和隐式适配。
- 提示词技术的本质是通过精心设计的上下文,引导模型激活已有的能力或构建临时任务分布。
- 更强的推理能力越来越依赖后训练(尤其是强化学习) 而非单纯提示工程。
5、学习总结与实践启示
- 提示词不是魔法,而是激活模型已有能力的杠杆。
- 示例数量:简单任务 Zero-Shot 即可;复杂/反直觉任务需要更多 Few-Shot。
- CoT 是性价比最高的进阶技术,但触发词效果与模型训练数据分布强相关。
- ToT 适合需要系统性搜索的任务,但工程成本较高。
- 专用推理模型 代表未来趋势:把复杂推理能力内化到权重中,而非仅靠提示词临时激发。
实际使用建议:
- 常规任务 → Zero/Few-Shot + 清晰指令
- 需要推理 → 优先尝试 Zero-Shot CoT(加 “Let's think step by step”)
- 高难度规划 → ToT 或直接调用推理模型(如 o1、DeepSeek-R1)
- 生产系统 → 实现 Prompt Router,根据任务复杂度动态选择模型与提示策略
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)