【LLM】OPD
完整流程就是:
- Prompt 喂给 Student,Student 自回归生成输出序列
- 把 Prompt + Student生成的序列 整个喂给 Teacher 和 Student 各做一次前向传播
- 在每个生成的 token 位置上,拿到 Teacher 的 logits 和 Student 的 logits
- 计算 KL 散度,反向传播更新 Student
一、为什么需要 OPD?设计动机
要理解 OPD 的设计思路,先要理解它解决的是什么问题。
传统的知识蒸馏(Off-Policy)做法是:准备一批固定的文本数据,Teacher 和 Student 同时在这批数据上计算 logits,然后对齐。这个方式有一个根本性的缺陷,叫做分布偏移(Distribution Shift)。
想象一下 Student 在推理时的真实状态:它是自回归生成的,第 t 步的输入 prefix 是它自己在前 t-1 步生成的内容。但训练时,prefix 全部来自外部数据集,Student 从来没有在"自己生成的文字"上被训练过。
于是就出现了一个矛盾:训练时 Student 看到的是"完美的 prefix",推理时它看到的是"自己可能犯错的 prefix"。一旦 Student 在某一步生成了一个偏差词,后续所有 token 都处于训练时从未遇到过的分布之下,错误会滚雪球式累积。
OPD 的核心思想就是:既然推理时 Student 要面对自己生成的 prefix,那训练时就让它在自己生成的 prefix 上学习。让 Teacher 跟着 Student 走,而不是让 Student 跟着 Teacher 走。
二、具体实现流程
整体流程
给定一个 prompt xxx,完整的一个训练步骤如下:
Step 1 — Student 自回归采样
把 prompt 喂给 Student,让它自回归地生成完整的输出序列 y^=(y^1,y^2,…,y^T)\hat{y} = (\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_T)y^=(y^1,y^2,…,y^T)。注意此时是真实的采样或 greedy decode,不是 teacher-forcing。
Step 2 — 双向前向传播
把完整的序列 (x,y^)(x, \hat{y})(x,y^) 同时喂给 Teacher 和 Student,各做一次前向传播。由于 Transformer 的 causal mask,每个位置 t 只能看到它之前的 token,所以一次前向传播就能拿到所有位置的 logits,等价于逐个 prefix 单独输入,但效率高得多。
Step 3 — 计算每个位置的 KL 散度
在每个生成 token 的位置 t 上,分别拿到:
- Teacher 的概率分布:pT(⋅∣x,y^<t)p_T(\cdot \mid x, \hat{y}_{<t})pT(⋅∣x,y^<t)
- Student 的概率分布:pS(⋅∣x,y^<t)p_S(\cdot \mid x, \hat{y}_{<t})pS(⋅∣x,y^<t)
计算这两个分布之间的 KL 散度。注意 prompt 部分的位置不参与 loss 计算。
Step 4 — 汇总损失,更新 Student
把所有位置的 KL 累加,对 Student 的参数做反向传播:
LOPD=Ex∼D Ey^∼pS(⋅∣x)[∑t=1TDKL (pT(⋅∣x,y^<t) ∥ pS(⋅∣x,y^<t))]\mathcal{L}_{OPD} = \mathbb{E}_{x \sim \mathcal{D}}\, \mathbb{E}_{\hat{y} \sim p_S(\cdot|x)} \left[ \sum_{t=1}^{T} D_{KL}\!\left( p_T(\cdot \mid x, \hat{y}_{<t}) \;\Big\|\; p_S(\cdot \mid x, \hat{y}_{<t}) \right) \right]LOPD=Ex∼DEy^∼pS(⋅∣x)[t=1∑TDKL(pT(⋅∣x,y^<t) pS(⋅∣x,y^<t))]
重复以上步骤,每次迭代 Student 参数更新后,下一轮采样的 y^\hat{y}y^ 自然也会随之变化,这就是"On-Policy"的含义——训练分布始终跟随当前策略。
KL 方向的选择
KL 散度不对称,方向的选择有实质影响:
前向 KL,即 DKL(pT∥pS)D_{KL}(p_T \| p_S)DKL(pT∥pS),也叫 Mode-Covering。它惩罚 Student 对 Teacher 高概率区域给出低概率,驱使 Student 尽量覆盖 Teacher 的所有可能性。实现简单,训练稳定,大多数工作默认使用这个方向。
反向 KL,即 DKL(pS∥pT)D_{KL}(p_S \| p_T)DKL(pS∥pT),也叫 Mode-Seeking。它惩罚 Student 在 Teacher 低概率的地方给出高概率,驱使 Student 专注于模仿 Teacher 最主要的生成模式,可以忽略一些次要的长尾分布。MiniLLM 使用的就是这个方向,理论上更适合生成任务,但由于梯度需要用 REINFORCE 估计,训练方差较高。
三、与 Off-Policy 蒸馏的本质区别
两者的数学形式看起来类似,但期望的分布完全不同:
Off-Policy 计算的是:
Ey<t ∼ pdata[ DKL(⋯ ) ]\mathbb{E}_{y_{<t}\,\sim\, p_{\text{data}}}[\,D_{KL}(\cdots)\,]Ey<t∼pdata[DKL(⋯)]
OPD 计算的是:
Ey^<t ∼ pS(⋅∣x)[ DKL(⋯ ) ]\mathbb{E}_{\hat{y}_{<t}\,\sim\, p_S(\cdot|x)}[\,D_{KL}(\cdots)\,]Ey^<t∼pS(⋅∣x)[DKL(⋯)]
这一个下标的差异,决定了 Student 是否能在"自己真实会走到的状态"上得到指导。
四、优点
分布一致性强:训练和推理时 Student 面对的 prefix 分布相同,从根本上解决了 exposure bias 问题,在长文本生成任务上效果提升尤为明显。
错误恢复能力:Student 生成了错误的 prefix 时,Teacher 会在这个错误的上下文上给出指导,Student 能学到"走偏之后如何纠正",这是 Off-Policy 完全无法做到的。
无需高质量标注数据:训练数据完全由 Student 自己生成,只需要 prompt,不需要人工标注的 ground truth 回答,数据准备成本低。
兼容性好:可以和 SFT、RLHF 等训练范式结合,也可以作为任何序列生成模型的压缩手段。
五、缺点
计算开销大:每个训练步都需要 Student 做一次完整的自回归采样,然后 Teacher 和 Student 各做一次前向传播。相比 Off-Policy 只需要两次前向传播,训练成本大约高 2~3 倍,Teacher 越大开销越明显。
训练初期不稳定:Student 初期质量很差,采出来的序列往往是乱的,导致 Teacher 打分的上下文也很混乱,梯度信号噪声大。常见解决办法是先做 Off-Policy 热身(warm-up),让 Student 具备基本能力后再切换到 On-Policy。
梯度方差问题:尤其是使用反向 KL 时,梯度需要通过 REINFORCE 估计,方差较高,需要加 baseline 或 variance reduction 技术来稳定训练。
Teacher 必须在线可用:Teacher 需要在训练过程中实时做推理,无法提前把 Teacher 的输出缓存下来,这意味着 Teacher 必须全程驻留在显存中,对硬件资源要求较高。
六、应用场景
LLM 模型压缩:这是最主流的应用,用 70B、405B 级别的大模型作为 Teacher,蒸馏出 7B、13B 级别的小模型,使其在对话、推理等任务上尽量接近大模型的能力。MiniLLM、DistiLLM 都是这个场景下的代表工作。
长文本生成任务:摘要、故事续写、代码生成等任务中,输出序列很长,Off-Policy 的分布偏移问题会随序列长度成倍放大,OPD 的优势在这类任务上尤为突出。
推理与规划任务:对于数学解题、逻辑推理这类需要多步骤连贯输出的任务,中间步骤的错误会直接影响后续步骤,OPD 能让 Student 在自己真实会生成的推理链上得到纠正。
领域适配蒸馏:在医疗、法律、代码等垂直领域,收集高质量标注数据成本极高,而 OPD 只需要 prompt,由 Student 自己生成回答后再由领域 Teacher 打分指导,极大降低了数据门槛。
强化学习结合:OPD 和 RLHF/PPO 的结构非常相似(都是 on-policy 的策略优化),可以把 KL 蒸馏损失作为 reward shaping 的一部分,约束 RL 训练时 Student 不要偏离 Teacher 太远,防止 reward hacking。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)