完整流程就是:

  1. Prompt 喂给 Student,Student 自回归生成输出序列
  2. Prompt + Student生成的序列 整个喂给 Teacher 和 Student 各做一次前向传播
  3. 在每个生成的 token 位置上,拿到 Teacher 的 logits 和 Student 的 logits
  4. 计算 KL 散度,反向传播更新 Student

一、为什么需要 OPD?设计动机

要理解 OPD 的设计思路,先要理解它解决的是什么问题。

传统的知识蒸馏(Off-Policy)做法是:准备一批固定的文本数据,Teacher 和 Student 同时在这批数据上计算 logits,然后对齐。这个方式有一个根本性的缺陷,叫做分布偏移(Distribution Shift)

想象一下 Student 在推理时的真实状态:它是自回归生成的,第 t 步的输入 prefix 是它自己在前 t-1 步生成的内容。但训练时,prefix 全部来自外部数据集,Student 从来没有在"自己生成的文字"上被训练过。

于是就出现了一个矛盾:训练时 Student 看到的是"完美的 prefix",推理时它看到的是"自己可能犯错的 prefix"。一旦 Student 在某一步生成了一个偏差词,后续所有 token 都处于训练时从未遇到过的分布之下,错误会滚雪球式累积。

OPD 的核心思想就是:既然推理时 Student 要面对自己生成的 prefix,那训练时就让它在自己生成的 prefix 上学习。让 Teacher 跟着 Student 走,而不是让 Student 跟着 Teacher 走。


二、具体实现流程

整体流程

给定一个 prompt xxx,完整的一个训练步骤如下:

Step 1 — Student 自回归采样

把 prompt 喂给 Student,让它自回归地生成完整的输出序列 y^=(y^1,y^2,…,y^T)\hat{y} = (\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_T)y^=(y^1,y^2,,y^T)。注意此时是真实的采样或 greedy decode,不是 teacher-forcing。

Step 2 — 双向前向传播

把完整的序列 (x,y^)(x, \hat{y})(x,y^) 同时喂给 Teacher 和 Student,各做一次前向传播。由于 Transformer 的 causal mask,每个位置 t 只能看到它之前的 token,所以一次前向传播就能拿到所有位置的 logits,等价于逐个 prefix 单独输入,但效率高得多。

Step 3 — 计算每个位置的 KL 散度

在每个生成 token 的位置 t 上,分别拿到:

  • Teacher 的概率分布:pT(⋅∣x,y^<t)p_T(\cdot \mid x, \hat{y}_{<t})pT(x,y^<t)
  • Student 的概率分布:pS(⋅∣x,y^<t)p_S(\cdot \mid x, \hat{y}_{<t})pS(x,y^<t)

计算这两个分布之间的 KL 散度。注意 prompt 部分的位置不参与 loss 计算。

Step 4 — 汇总损失,更新 Student

把所有位置的 KL 累加,对 Student 的参数做反向传播:

LOPD=Ex∼D Ey^∼pS(⋅∣x)[∑t=1TDKL ⁣(pT(⋅∣x,y^<t)  ∥  pS(⋅∣x,y^<t))]\mathcal{L}_{OPD} = \mathbb{E}_{x \sim \mathcal{D}}\, \mathbb{E}_{\hat{y} \sim p_S(\cdot|x)} \left[ \sum_{t=1}^{T} D_{KL}\!\left( p_T(\cdot \mid x, \hat{y}_{<t}) \;\Big\|\; p_S(\cdot \mid x, \hat{y}_{<t}) \right) \right]LOPD=ExDEy^pS(x)[t=1TDKL(pT(x,y^<t) pS(x,y^<t))]

重复以上步骤,每次迭代 Student 参数更新后,下一轮采样的 y^\hat{y}y^ 自然也会随之变化,这就是"On-Policy"的含义——训练分布始终跟随当前策略。

KL 方向的选择

KL 散度不对称,方向的选择有实质影响:

前向 KL,即 DKL(pT∥pS)D_{KL}(p_T \| p_S)DKL(pTpS),也叫 Mode-Covering。它惩罚 Student 对 Teacher 高概率区域给出低概率,驱使 Student 尽量覆盖 Teacher 的所有可能性。实现简单,训练稳定,大多数工作默认使用这个方向。

反向 KL,即 DKL(pS∥pT)D_{KL}(p_S \| p_T)DKL(pSpT),也叫 Mode-Seeking。它惩罚 Student 在 Teacher 低概率的地方给出高概率,驱使 Student 专注于模仿 Teacher 最主要的生成模式,可以忽略一些次要的长尾分布。MiniLLM 使用的就是这个方向,理论上更适合生成任务,但由于梯度需要用 REINFORCE 估计,训练方差较高。


三、与 Off-Policy 蒸馏的本质区别

两者的数学形式看起来类似,但期望的分布完全不同:

Off-Policy 计算的是:

Ey<t ∼ pdata[ DKL(⋯ ) ]\mathbb{E}_{y_{<t}\,\sim\, p_{\text{data}}}[\,D_{KL}(\cdots)\,]Ey<tpdata[DKL()]

OPD 计算的是:

Ey^<t ∼ pS(⋅∣x)[ DKL(⋯ ) ]\mathbb{E}_{\hat{y}_{<t}\,\sim\, p_S(\cdot|x)}[\,D_{KL}(\cdots)\,]Ey^<tpS(x)[DKL()]

这一个下标的差异,决定了 Student 是否能在"自己真实会走到的状态"上得到指导。


四、优点

分布一致性强:训练和推理时 Student 面对的 prefix 分布相同,从根本上解决了 exposure bias 问题,在长文本生成任务上效果提升尤为明显。

错误恢复能力:Student 生成了错误的 prefix 时,Teacher 会在这个错误的上下文上给出指导,Student 能学到"走偏之后如何纠正",这是 Off-Policy 完全无法做到的。

无需高质量标注数据:训练数据完全由 Student 自己生成,只需要 prompt,不需要人工标注的 ground truth 回答,数据准备成本低。

兼容性好:可以和 SFT、RLHF 等训练范式结合,也可以作为任何序列生成模型的压缩手段。


五、缺点

计算开销大:每个训练步都需要 Student 做一次完整的自回归采样,然后 Teacher 和 Student 各做一次前向传播。相比 Off-Policy 只需要两次前向传播,训练成本大约高 2~3 倍,Teacher 越大开销越明显。

训练初期不稳定:Student 初期质量很差,采出来的序列往往是乱的,导致 Teacher 打分的上下文也很混乱,梯度信号噪声大。常见解决办法是先做 Off-Policy 热身(warm-up),让 Student 具备基本能力后再切换到 On-Policy。

梯度方差问题:尤其是使用反向 KL 时,梯度需要通过 REINFORCE 估计,方差较高,需要加 baseline 或 variance reduction 技术来稳定训练。

Teacher 必须在线可用:Teacher 需要在训练过程中实时做推理,无法提前把 Teacher 的输出缓存下来,这意味着 Teacher 必须全程驻留在显存中,对硬件资源要求较高。


六、应用场景

LLM 模型压缩:这是最主流的应用,用 70B、405B 级别的大模型作为 Teacher,蒸馏出 7B、13B 级别的小模型,使其在对话、推理等任务上尽量接近大模型的能力。MiniLLM、DistiLLM 都是这个场景下的代表工作。

长文本生成任务:摘要、故事续写、代码生成等任务中,输出序列很长,Off-Policy 的分布偏移问题会随序列长度成倍放大,OPD 的优势在这类任务上尤为突出。

推理与规划任务:对于数学解题、逻辑推理这类需要多步骤连贯输出的任务,中间步骤的错误会直接影响后续步骤,OPD 能让 Student 在自己真实会生成的推理链上得到纠正。

领域适配蒸馏:在医疗、法律、代码等垂直领域,收集高质量标注数据成本极高,而 OPD 只需要 prompt,由 Student 自己生成回答后再由领域 Teacher 打分指导,极大降低了数据门槛。

强化学习结合:OPD 和 RLHF/PPO 的结构非常相似(都是 on-policy 的策略优化),可以把 KL 蒸馏损失作为 reward shaping 的一部分,约束 RL 训练时 Student 不要偏离 Teacher 太远,防止 reward hacking。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐