在模型蒸馏中,蒸馏损失学生损失是两个核心的损失函数组成部分。它们共同决定了学生模型的学习方向。

要理解它们,需要先明确蒸馏的基本设定:

  • 教师模型:一个已经训练好的、参数量大、性能强但推理慢的模型。它的输出提供了丰富的“暗知识”(比如,分类猫时,它虽然预测“猫”的概率最高,但“老虎”的概率也比“汽车”高)。

  • 学生模型:我们想要训练的小模型。它既要模仿教师的行为,也要学习真实的数据标签。

基于这个设定,蒸馏损失学生损失的作用如下:


1. 学生损失 —— 学习标准答案

  • 对象:学生模型的预测 ↔↔ 真实标签(Ground Truth,即硬目标)。

  • 计算公式:通常是交叉熵损失。

  • 目的:让学生模型不偏离正确答案。它确保学生模型最终的任务目标(例如正确分类猫和狗)不会出错。

  • 通俗理解:这是常规训练中本来就有的损失。它告诉学生:“不管老师怎么教,这道题的正确答案就是‘猫’,你必须学会这个。”

2. 蒸馏损失 —— 模仿老师的思维方式

  • 对象:学生模型的预测 ↔↔ 教师模型的预测(Soft Labels/Targets,即软目标)。

  • 关键技巧:为了让教师模型输出更丰富的信息,通常会使用一个温度参数(T)软化概率分布。高温会让分布更平滑(类别间的差异变小,隐藏的关系暴露出来)。

  • 计算公式:通常是KL散度或带温度的交叉熵损失。

  • 目的:让学生模型模仿教师模型的泛化能力。通过匹配教师输出的概率分布,学生能学到类别间的相似性(比如,在教师看来,猫和老虎有点像,但猫和汽车一点都不像)。

  • 通俗理解:这是蒸馏特有的损失。它告诉学生:“虽然答案是猫,但你要学老师的思维方式,知道猫在特征空间里离老虎近,离汽车远。”


它们如何协同工作?

在训练时,通常将这两个损失通过加权求和的方式结合起来,形成最终的损失函数:

L=α⋅Lsoft+(1−α)⋅LhardL=α⋅Lsoft​+(1−α)⋅Lhard​

其中:

  • LsoftLsoft​ = 蒸馏损失(学生vs教师),权重通常较高(如0.7)。

  • LhardLhard​ = 学生损失(学生vs真实标签),权重通常较低(如0.3)。

一个直观的例子:手写数字识别

假设任务是识别手写数字“2”。

  • 教师模型输出

    • 数字2:90%

    • 数字3:7%

    • 数字7:2%

    • 数字8:1%
      (因为2和3、7在笔画上确实有点像,教师捕捉到了这种相似性)

  • 真实标签(One-hot)

    • 数字2:100%

    • 其他:0%

两种损失的作用:
  1. 学生损失只关心:学生模型对“2”这个类别的预测值要达到接近100%。

  2. 蒸馏损失关心:学生模型的概率分布是否也像教师一样,对“3”有较高的值(7%),对“7”有稍低的值(2%)。

结果
如果只用学生损失,学生只会学会把2识别成2,面对一个写得很潦草的、有点像3的2时,可能就会犯错。
如果只用蒸馏损失,学生可能会模仿得很好,但万一老师犯错了(比如老师把2误判成3的概率很高),学生也会跟着错。

总结

  • 学生损失保底,让学生专注于主要任务。

  • 蒸馏损失提升,让学生从老师那里学到数据分布的更丰富信息(知识迁移),从而获得更好的泛化能力。

怎么理解模型蒸馏中的softmax-t公式

理解 Softmax-T(带温度参数 TT 的 Softmax)是理解模型蒸馏核心机制的关键。这个小小的 TT 实际上是信息放大镜知识载体

我们可以通过对比原始 Softmax 和带温度的 Softmax 来理解它的作用和必要性。


1. 原始 Softmax:输出“答案”

在常规的分类任务中,Softmax 的作用是将模型输出的“分数”(logits,即逻辑值)转换成概率分布。

pi=ezi∑jezjpi​=∑j​ezj​ezi​​

  • 特点:这是竞争性的。指数运算会放大差异。

  • 结果:对于训练好的模型,正确答案的概率会非常接近 1,错误答案的概率会非常接近 0。

  • 例子:手写数字 2。

    • 模型最后输出的分数(logits):对类别 2 打 100 分,对类别 3 打 90 分,对类别 7 打 80 分。

    • 经过原始 Softmax 计算:

      • P(2) ≈ 99.9%

      • P(3) ≈ 0.05%

      • P(7) ≈ 0.00%

    • 问题:虽然分数上 100、90、80 差距没那么大,但 Softmax 让它们“赢家通吃”。虽然教师知道“2”和“3”很像(分数接近),但经过 Softmax 处理后,这个信息几乎丢失了。这种概率叫硬目标(Hard Target),对学生来说信息量太少。


2. Softmax-T:输出“思路”

为了让教师模型不仅仅告诉学生答案,还要告诉学生思考过程(即它对哪些相似类别感到犹豫),我们引入了温度 TT。

公式变为:

pi=ezi/T∑jezj/Tpi​=∑j​ezj​/Tezi​/T​

其中 TT 就是温度参数

当 T=1 时:

就是标准的 Softmax。

当 T > 1 时(蒸馏常用 T=2 到 20):

指数运算内部除以了一个大于 1 的数,这会软化概率分布。

  • 作用:抹平了分数差异带来的指数级差距。

  • 结果

    • 原来 P(3) 只有 0.05%,现在可能上升到 15%

    • 原来 P(7) 几乎为 0%,现在可能上升到 5%

    • 原来 P(2) 从 99.9% 下降到 80%

  • 意义:现在我们得到的是一个软目标(Soft Target)。这个分布清晰地揭示了教师的内在知识:“2 有点像 3,一点点像 7,但完全不像 8。”

当 T -> ∞ 时:

所有类别的概率趋近于相等,分布变成均匀分布,信息消失。

当 T -> 0 时:

趋近于 one-hot 编码,放大正确答案的权重,退化为原始 Softmax。


3. 为什么蒸馏必须用 Softmax-T?

原因一:知识迁移的载体

学生模型学习,本质上是在学习训练数据中隐含的流形结构(即类别之间的关系)。

  • 如果没有 Softmax-T,教师给学生的就是 one-hot 标签(猫是1,狗是0)。

  • 有了 Softmax-T,教师可以告诉学生:“在我想来,猫和老虎的相似度是 0.3,猫和狗的相似度是 0.1,猫和汽车的相似度是 0.0001。”

  • 这相当于把教师庞大的参数空间中蕴含的“暗知识”提取出来,浓缩在概率分布里喂给学生。

原因二:解决梯度消失,提供更丰富的训练信号

如果直接用原始 Softmax 的输出(0.99, 0.005, 0.005...)作为损失函数的目标:

  • 对于负标签(类别3、7),预测值 0.005 和 0.001 的差别微乎其微。

  • 反向传播时,这些负标签对应的权重几乎得不到更新。

如果使用带温度的 Softmax:

  • 负标签的值变得显著(如 0.15, 0.05)。

  • 学生模型如果预测 3 的概率低了,损失会明显增加,从而迫使学生的权重调整,去匹配教师的分布。

一个公式解释全过程

完整的蒸馏损失函数通常是让学生和教师都经过 Softmax-T 处理,然后计算两者的散度(如 KL 散度):

Loss=KL Divergence(Softmax(Teacher_logits/T) || Softmax(Student_logits/T))Loss=KL Divergence(Softmax(Teacher_logits/T) ​​ Softmax(Student_logits/T))

训练时

  1. 教师学生都使用 T>1 计算 Softmax。

  2. 计算两者的分布差异(蒸馏损失)。

  3. 反向传播更新学生的参数。

  4. 推理时:学生模型恢复使用 T=1 进行预测。

总结:Softmax-T 的本质

可以把 Softmax-T 理解为一个信息调节阀

  • TT 越大,信息熵越大,负标签(错误答案)携带的信息量越大,学生能学到的类间关系越丰富。

  • TT 越小,信息熵越小,正标签(正确答案)的主导性越强,学生更专注于模仿教师的最终判断。

通过调整 TT,我们能够控制从教师那里提取知识的粒度

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐