模型蒸馏,以及softmax-T
在模型蒸馏中,蒸馏损失和学生损失是两个核心的损失函数组成部分。它们共同决定了学生模型的学习方向。
要理解它们,需要先明确蒸馏的基本设定:
-
教师模型:一个已经训练好的、参数量大、性能强但推理慢的模型。它的输出提供了丰富的“暗知识”(比如,分类猫时,它虽然预测“猫”的概率最高,但“老虎”的概率也比“汽车”高)。
-
学生模型:我们想要训练的小模型。它既要模仿教师的行为,也要学习真实的数据标签。
基于这个设定,蒸馏损失和学生损失的作用如下:
1. 学生损失 —— 学习标准答案
-
对象:学生模型的预测 ↔↔ 真实标签(Ground Truth,即硬目标)。
-
计算公式:通常是交叉熵损失。
-
目的:让学生模型不偏离正确答案。它确保学生模型最终的任务目标(例如正确分类猫和狗)不会出错。
-
通俗理解:这是常规训练中本来就有的损失。它告诉学生:“不管老师怎么教,这道题的正确答案就是‘猫’,你必须学会这个。”
2. 蒸馏损失 —— 模仿老师的思维方式
-
对象:学生模型的预测 ↔↔ 教师模型的预测(Soft Labels/Targets,即软目标)。
-
关键技巧:为了让教师模型输出更丰富的信息,通常会使用一个温度参数(T)软化概率分布。高温会让分布更平滑(类别间的差异变小,隐藏的关系暴露出来)。
-
计算公式:通常是KL散度或带温度的交叉熵损失。
-
目的:让学生模型模仿教师模型的泛化能力。通过匹配教师输出的概率分布,学生能学到类别间的相似性(比如,在教师看来,猫和老虎有点像,但猫和汽车一点都不像)。
-
通俗理解:这是蒸馏特有的损失。它告诉学生:“虽然答案是猫,但你要学老师的思维方式,知道猫在特征空间里离老虎近,离汽车远。”
它们如何协同工作?
在训练时,通常将这两个损失通过加权求和的方式结合起来,形成最终的损失函数:
L=α⋅Lsoft+(1−α)⋅LhardL=α⋅Lsoft+(1−α)⋅Lhard
其中:
-
LsoftLsoft = 蒸馏损失(学生vs教师),权重通常较高(如0.7)。
-
LhardLhard = 学生损失(学生vs真实标签),权重通常较低(如0.3)。
一个直观的例子:手写数字识别
假设任务是识别手写数字“2”。
-
教师模型输出:
-
数字2:90%
-
数字3:7%
-
数字7:2%
-
数字8:1%
(因为2和3、7在笔画上确实有点像,教师捕捉到了这种相似性)
-
-
真实标签(One-hot):
-
数字2:100%
-
其他:0%
-
两种损失的作用:
-
学生损失只关心:学生模型对“2”这个类别的预测值要达到接近100%。
-
蒸馏损失关心:学生模型的概率分布是否也像教师一样,对“3”有较高的值(7%),对“7”有稍低的值(2%)。
结果:
如果只用学生损失,学生只会学会把2识别成2,面对一个写得很潦草的、有点像3的2时,可能就会犯错。
如果只用蒸馏损失,学生可能会模仿得很好,但万一老师犯错了(比如老师把2误判成3的概率很高),学生也会跟着错。
总结:
-
学生损失是保底,让学生专注于主要任务。
-
蒸馏损失是提升,让学生从老师那里学到数据分布的更丰富信息(知识迁移),从而获得更好的泛化能力。
怎么理解模型蒸馏中的softmax-t公式
理解 Softmax-T(带温度参数 TT 的 Softmax)是理解模型蒸馏核心机制的关键。这个小小的 TT 实际上是信息放大镜和知识载体。
我们可以通过对比原始 Softmax 和带温度的 Softmax 来理解它的作用和必要性。
1. 原始 Softmax:输出“答案”
在常规的分类任务中,Softmax 的作用是将模型输出的“分数”(logits,即逻辑值)转换成概率分布。
pi=ezi∑jezjpi=∑jezjezi
-
特点:这是竞争性的。指数运算会放大差异。
-
结果:对于训练好的模型,正确答案的概率会非常接近 1,错误答案的概率会非常接近 0。
-
例子:手写数字 2。
-
模型最后输出的分数(logits):对类别 2 打 100 分,对类别 3 打 90 分,对类别 7 打 80 分。
-
经过原始 Softmax 计算:
-
P(2) ≈ 99.9%
-
P(3) ≈ 0.05%
-
P(7) ≈ 0.00%
-
-
问题:虽然分数上 100、90、80 差距没那么大,但 Softmax 让它们“赢家通吃”。虽然教师知道“2”和“3”很像(分数接近),但经过 Softmax 处理后,这个信息几乎丢失了。这种概率叫硬目标(Hard Target),对学生来说信息量太少。
-
2. Softmax-T:输出“思路”
为了让教师模型不仅仅告诉学生答案,还要告诉学生思考过程(即它对哪些相似类别感到犹豫),我们引入了温度 TT。
公式变为:
pi=ezi/T∑jezj/Tpi=∑jezj/Tezi/T
其中 TT 就是温度参数。
当 T=1 时:
就是标准的 Softmax。
当 T > 1 时(蒸馏常用 T=2 到 20):
指数运算内部除以了一个大于 1 的数,这会软化概率分布。
-
作用:抹平了分数差异带来的指数级差距。
-
结果:
-
原来 P(3) 只有 0.05%,现在可能上升到 15%。
-
原来 P(7) 几乎为 0%,现在可能上升到 5%。
-
原来 P(2) 从 99.9% 下降到 80%。
-
-
意义:现在我们得到的是一个软目标(Soft Target)。这个分布清晰地揭示了教师的内在知识:“2 有点像 3,一点点像 7,但完全不像 8。”
当 T -> ∞ 时:
所有类别的概率趋近于相等,分布变成均匀分布,信息消失。
当 T -> 0 时:
趋近于 one-hot 编码,放大正确答案的权重,退化为原始 Softmax。
3. 为什么蒸馏必须用 Softmax-T?
原因一:知识迁移的载体
学生模型学习,本质上是在学习训练数据中隐含的流形结构(即类别之间的关系)。
-
如果没有 Softmax-T,教师给学生的就是 one-hot 标签(猫是1,狗是0)。
-
有了 Softmax-T,教师可以告诉学生:“在我想来,猫和老虎的相似度是 0.3,猫和狗的相似度是 0.1,猫和汽车的相似度是 0.0001。”
-
这相当于把教师庞大的参数空间中蕴含的“暗知识”提取出来,浓缩在概率分布里喂给学生。
原因二:解决梯度消失,提供更丰富的训练信号
如果直接用原始 Softmax 的输出(0.99, 0.005, 0.005...)作为损失函数的目标:
-
对于负标签(类别3、7),预测值 0.005 和 0.001 的差别微乎其微。
-
反向传播时,这些负标签对应的权重几乎得不到更新。
如果使用带温度的 Softmax:
-
负标签的值变得显著(如 0.15, 0.05)。
-
学生模型如果预测 3 的概率低了,损失会明显增加,从而迫使学生的权重调整,去匹配教师的分布。
一个公式解释全过程
完整的蒸馏损失函数通常是让学生和教师都经过 Softmax-T 处理,然后计算两者的散度(如 KL 散度):
Loss=KL Divergence(Softmax(Teacher_logits/T) || Softmax(Student_logits/T))Loss=KL Divergence(Softmax(Teacher_logits/T) Softmax(Student_logits/T))
训练时:
-
教师和学生都使用 T>1 计算 Softmax。
-
计算两者的分布差异(蒸馏损失)。
-
反向传播更新学生的参数。
-
推理时:学生模型恢复使用 T=1 进行预测。
总结:Softmax-T 的本质
可以把 Softmax-T 理解为一个信息调节阀:
-
TT 越大,信息熵越大,负标签(错误答案)携带的信息量越大,学生能学到的类间关系越丰富。
-
TT 越小,信息熵越小,正标签(正确答案)的主导性越强,学生更专注于模仿教师的最终判断。
通过调整 TT,我们能够控制从教师那里提取知识的粒度。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)