8、BEiT-2 详解:用 VQ-KD 把 MIM 从像素重建升级到语义重建
BEiT-2 详解:用 VQ-KD 把 MIM 从像素重建升级到语义重建
BEiT-2(BEiT v2)是在 BEiT 的 Masked Image Modeling(MIM)范式上做出的关键升级:它不再让模型去“补像素”,而是让模型去“补语义 token”。其核心创新是 向量量化知识蒸馏(VQ-KD, Vector-Quantized Knowledge Distillation):用一个强教师(如 CLIP / DINO)提供的语义特征作为监督信号,训练出一个语义丰富的视觉 tokenizer(离散码本 + 量化器 + 解码器),将每个 patch 映射为紧凑的离散 token,使 MIM 的重构目标从像素级提升到语义级。
与此同时,BEiT-2 还引入 patch aggregation 来强化全局语义表达,缓解仅做 patch 级恢复导致的全局表征不足问题。
在下游任务上,BEiT-2 展示了强自监督表征能力:ImageNet-1K 上 base 微调可达约 85%+,large 可达约 87%+;线性探测可达 80% 左右;ADE20K 语义分割 large 可达 56%+ mIoU。这些结果反映出“语义级 MIM”对视觉表征学习的推动作用。
1. 背景:为什么像素级 MIM 不够?
1.1 MIM 的基本思想
MIM 预训练一般做的是:随机 mask 一部分图像 patch,让模型根据未被 mask 的上下文去预测被 mask 的内容。对 Transformer 来说,这和语言模型的 masked token prediction 在形式上很相似。
但一个关键差异是:语言 token 天然是离散语义单位,而图像像素/低层特征并不是语义单位。于是,早期 MIM 常见的目标是像素级或低层特征级重建(例如重建 RGB 或某个浅层特征),这会带来两个问题:
- 重建难度被细节主导:像素细节(纹理、噪声、光照)会让损失更关注局部外观,而非语义结构。
- 语义学习不直接:模型可能学会“补纹理”,但不一定学会“识别物体/场景/关系”。
1.2 一个直观例子:补纹理 vs 补语义
设想一张“狗在草地上”的图像被 mask 掉狗的头部区域:
- 像素级重建:最容易让损失下降的方式可能是生成“差不多的棕色毛发纹理”,并不强迫模型明确“这是狗的头部”。
- 语义级重建:如果目标是一个离散 token,且这个 token 是从教师语义特征离散化而来,那么模型更需要恢复“狗头部”这类语义一致的表示,才能预测正确 token。
BEiT-2 的动机就是把 MIM 的学习信号从“低层外观”转移到“高层语义”。
2. 总览:BEiT-2 的核心组件与训练目标
BEiT-2 可以拆成两步理解:
- 先训练一个语义视觉 tokenizer(VQ-KD):把图像 patch 的连续语义特征离散化成 token。
- 再做语义级 MIM:mask patch 后预测对应的离散 token(分类问题),而不是回归像素(回归问题)。
同时引入 patch aggregation,让全局 token(如 [CLS])更强地聚合语义。
3. 模型架构:ViT + 视觉 tokenizer + MIM Head
3.1 图像 patch 化与 ViT 表示
输入图像为
x∈RH×W×C. x \in \mathbb{R}^{H \times W \times C}. x∈RH×W×C.
将图像切分为大小为 P×PP \times PP×P 的 patch(常见 P=16P=16P=16),patch 数量为
N=HWP2. N = \frac{HW}{P^2}. N=P2HW.
每个 patch 展平后线性投影为 token embedding,形成序列输入 ViT,得到编码输出(按 patch 对应):
{hi}i=1N. \{h_i\}_{i=1}^{N}. {hi}i=1N.
其中 hih_ihi 表示第 iii 个 patch 的连续特征向量(可来自某一层输出)。
3.2 为什么需要“视觉 tokenizer”
在语言中,masked token prediction 的目标是词表中的离散 token。BEiT 系列的关键在于:为图像也构建类似“词表”的离散 token,使 MIM 变成“预测离散 token”的分类任务,从而让学习目标更语义化、更稳定。
BEiT-2 的 tokenizer 不再来自传统 dVAE 的像素压缩,而是来自 VQ-KD:从教师语义空间蒸馏并离散化。
4. VQ-KD:向量量化知识蒸馏如何构造语义 token?
4.1 VQ-KD 的目标:把连续语义空间离散化
VQ-KD 的目标是训练一个 tokenizer,使得每个 patch 的连续表示可以被映射到一个离散 code(视觉 token),而这个 token 对应的语义尽可能接近教师模型对该 patch 的语义理解。
VQ-KD 中关键对象:
-
Codebook(码本):包含 KKK 个离散向量,每个维度为 DDD:
V∈RK×D. \mathcal{V} \in \mathbb{R}^{K \times D}. V∈RK×D.
其中第 jjj 个码向量为 vjv_jvj。 -
量化器:将连续特征 hih_ihi 映射到最近的码向量索引 ziz_izi,形成离散 token。
-
Decoder:基于量化后的 token 表示去重建教师的语义特征,让 token 学到语义。
4.2 量化:最近邻查找得到离散 token
为提升稳定性与几何一致性,常使用 L2L_2L2 归一化后做最近邻匹配。对第 iii 个 patch:
zi=argminj∥ℓ2(hi)−ℓ2(vj)∥2, z_i = \arg\min_j \left\| \ell_2(h_i) - \ell_2(v_j) \right\|_2, zi=argjmin∥ℓ2(hi)−ℓ2(vj)∥2,
其中 ℓ2(⋅)\ell_2(\cdot)ℓ2(⋅) 表示 L2L_2L2 归一化。
直观理解:码本里的每个 vjv_jvj 对应一个“语义原型”,量化就是为每个 patch 选一个最像的语义原型编号 ziz_izi。这样就得到离散序列:
z=[z1;z2;⋯ ;zN]. z = [z_1; z_2; \cdots; z_N]. z=[z1;z2;⋯;zN].
4.3 教师语义:token 不重建像素,而重建教师特征
教师模型(如 CLIP/DINO)会为每个 patch 输出语义特征 tit_iti。VQ-KD 的 decoder 接收量化后的表示并输出 oio_ioi,目标是让 oio_ioi 与 tit_iti 在语义空间对齐,常用余弦相似度:
cos(oi,ti). \cos(o_i, t_i). cos(oi,ti).
这里的关键思想是:token 学到的是“能重建教师语义”的离散码,而不是“能重建像素”的离散码。
4.4 VQ-KD 的损失:语义对齐 + 码本/编码器一致性
可以将 VQ-KD 的训练目标理解为三部分:
- 语义对齐项:让 decoder 输出贴近教师语义
- 码本更新项:让码向量贴近编码器输出
- 编码器更新项:让编码器输出贴近码向量
一种常见的写法如下(使用 stop-gradient 记为 sg[⋅]\text{sg}[\cdot]sg[⋅]):
max∑x∈D∑i=1Ncos(oi,ti)−∥sg[ℓ2(hi)]−ℓ2(vzi)∥22−∥ℓ2(hi)−sg[ℓ2(vzi)]∥22. \max \sum_{x \in \mathcal{D}} \sum_{i=1}^N \cos(o_i, t_i) \mathrm{}- \left\|\text{sg}[\ell_2(h_i)] - \ell_2(v_{z_i})\right\|_2^2 \mathrm{}- \left\|\ell_2(h_i) - \text{sg}[\ell_2(v_{z_i})]\right\|_2^2. maxx∈D∑i=1∑Ncos(oi,ti)−∥sg[ℓ2(hi)]−ℓ2(vzi)∥22−∥ℓ2(hi)−sg[ℓ2(vzi)]∥22.
对每一项做非常具体的解释:
-
第一项:cos(oi,ti)\cos(o_i, t_i)cos(oi,ti)
- tit_iti 是教师模型对第 iii 个 patch 的语义特征。
- oio_ioi 是 decoder 基于量化 token 生成的输出特征。
- 最大化余弦相似度意味着:decoder 的输出方向尽量和教师语义一致,从而让 token 承载高层语义。
-
第二项:∥sg[ℓ2(hi)]−ℓ2(vzi)∥22\left\|\text{sg}[\ell_2(h_i)] - \ell_2(v_{z_i})\right\|_2^2∥sg[ℓ2(hi)]−ℓ2(vzi)∥22
- 这里 sg[ℓ2(hi)]\text{sg}[\ell_2(h_i)]sg[ℓ2(hi)] 表示把编码器输出当成常量,不让梯度回传到编码器。
- 该项让码本向量 vziv_{z_i}vzi 朝着 hih_ihi 靠拢,即更新码本以覆盖编码器输出的分布。
-
第三项:∥ℓ2(hi)−sg[ℓ2(vzi)]∥22\left\|\ell_2(h_i) - \text{sg}[\ell_2(v_{z_i})]\right\|_2^2∥ℓ2(hi)−sg[ℓ2(vzi)]∥22
- 这里把码本向量当成常量,不让梯度回传到码本。
- 该项推动编码器输出 hih_ihi 向已选定的码向量靠拢,让编码器学会产生“可量化”的特征。
这三者联合起来,就能学到一个“语义离散码本”,并让编码器输出与码本形成稳定的双向耦合。
4.5 非可微量化如何训练:直通梯度
量化操作包含 argmin\arg\minargmin 最近邻选择,本质是离散决策,不可微。反向传播会被阻断。
直通梯度(Straight-Through, ST)在这里的核心做法是:
- 前向:正常做最近邻选择得到 ziz_izi,取出码向量参与后续计算。
- 反向:把量化当作“恒等映射”近似处理,让梯度从 decoder 输入直接传回到编码器输出。
直观上,模型被允许“假装”量化步骤可微,以便端到端优化 encoder 与 decoder。
4.6 码本坍缩与 EMA 更新
向量量化常见问题是 码本坍缩:只有少数码向量被频繁选中,大部分码向量闲置,导致离散空间表达力不足。
EMA(指数移动平均)更新码本是一种常用稳健策略。其思想是用历史统计平滑更新码向量。一个简化写法是:
vt=α⋅vt−1+(1−α)⋅v^t, v_t = \alpha \cdot v_{t-1} + (1-\alpha)\cdot \hat{v}_t, vt=α⋅vt−1+(1−α)⋅v^t,
其中 v^t\hat{v}_tv^t 表示当前批次对码向量的估计更新,α∈(0,1)\alpha \in (0,1)α∈(0,1) 控制平滑程度。
EMA 的直观作用:
- 降低码本更新的方差,让码向量不会被单次 batch 的噪声剧烈拉动;
- 提升码本利用率,缓解坍缩。
5. 语义级 MIM 预训练:预测离散视觉 token
完成 VQ-KD 后,BEiT-2 进入 MIM 预训练阶段。此时每个 patch 都有一个离散目标 token ziz_izi(来自 tokenizer)。
5.1 Mask 策略与输入
定义 mask 位置集合为 M\mathcal{M}M,通常随机 mask 约 40% 的 patch。
将被 mask 的 patch 替换为 mask token,输入 ViT 编码得到对应特征 hih_ihi(mask 后上下文条件下的隐表示)。
5.2 MIM Head:对每个 masked patch 做分类
对 mask 位置 i∈Mi \in \mathcal{M}i∈M,用一个分类头预测其离散 token 分布:
p(zi∣hi)=softmaxzi(Wchi+bc). p(z_i \mid h_i) = \text{softmax}_{z_i}(W_c h_i + b_c). p(zi∣hi)=softmaxzi(Wchi+bc).
这里:
- Wc,bcW_c, b_cWc,bc 是分类器参数;
- softmax 的类别数对应码本大小 KKK;
- 任务是“在 KKK 个离散 token 中选对一个”,而不是回归像素。
5.3 MIM 损失:交叉熵只作用于 mask 位置
标准写法为负对数似然(交叉熵):
LMIM=−∑x∈D∑i∈Mlogp(zi∣xM). \mathcal{L}_{\text{MIM}} = \mathrm{}- \sum_{x \in \mathcal{D}} \sum_{i \in \mathcal{M}} \log p(z_i \mid x_{\mathcal{M}}). LMIM=−x∈D∑i∈M∑logp(zi∣xM).
为什么语义 token 预测对表征学习更友好:
- 分类目标更稳定,避免像素回归对细节过敏;
- token 来自教师语义蒸馏,逼迫模型学习语义一致性;
- mask 预测需要利用上下文结构,从而学到更强的关系建模能力。
6. Patch Aggregation:为什么要强化全局语义?
6.1 问题:仅做 patch 级预测可能忽略全局
MIM 的监督发生在 patch 位置上,模型可以在较大程度上依赖局部邻域信息完成预测,导致 [CLS] 等全局表征学习不充分。对于分类、检索、场景理解等任务,全局语义往往非常关键。
6.2 思路:让 [CLS] 参与预测并聚合中间层特征
BEiT-2 引入 patch aggregation:让最后层的 [CLS] token 与中间层 patch token 结合,进入一个浅层 decoder 做额外的 mask 预测,从而迫使 [CLS] 学到全局语义聚合能力。
设最后一层的 [CLS] 表示为 hCLSLh_{\text{CLS}}^LhCLSL,某个中间层 lll 的 patch 表示为 {hil}i=1N\{h_i^l\}_{i=1}^N{hil}i=1N,拼接得到:
S=[hCLSL;h1l;⋯ ;hNl]. S = [h_{\text{CLS}}^L; h_1^l; \cdots; h_N^l]. S=[hCLSL;h1l;⋯;hNl].
浅层 decoder 基于 SSS 输出对 token 的预测:
p(z∣S)=softmaxz(WcS+bc). p(z \mid S) = \text{softmax}_z(W_c S + b_c). p(z∣S)=softmaxz(WcS+bc).
最终训练时将该分支损失与主 MIM 损失联合优化(记额外项为 LMIMc\mathcal{L}_{\text{MIM}}^cLMIMc),整体目标是同时提升:
- patch 局部语义恢复能力;
- [CLS] 的全局聚合与表达能力。
6.3 一个直观例子:局部信息可解 vs 需要全局语义
考虑“这张图是室内还是室外”:
- 很多 patch(比如一小块墙面纹理)在局部上并不决定室内/室外;
- 需要聚合多个区域(天空/窗户/地面/光照)形成全局判断;
- 强化 [CLS] 聚合能让预训练更适配下游分类与场景任务。
7. 实验与结果解读:为什么 BEiT-2 强?
7.1 预训练设置的关键点
常见设置包括:
- 预训练数据:ImageNet-1K 去标签数据(仅用图像本身)
- 分辨率:224×224
- backbone:ViT-B/16、ViT-L/16
- 预训练 epoch:从几百到上千(例如 300e、1600e)
7.2 典型结果现象与原因
BEiT-2 的结果优势通常体现在三个层面:
-
微调分类更强
- 原因:语义 token 预测让 backbone 更早学到高层语义结构,对分类边界更友好。
-
线性探测更强
- 线性探测要求 backbone 输出本身就“可线性分割”,对表征质量要求高。
- 原因:token 来自教师语义蒸馏,提升了表征的语义可分性。
-
密集预测(如分割)更强
- 语义分割依赖每个位置的语义一致性与边界感知。
- 原因:token 级语义监督对 patch 表示更“语义对齐”,比像素级重建更贴近分割需求。
7.3 教师模型与 patch aggregation 的消融意义
-
教师模型选择(CLIP vs DINO)
- 教师提供的语义空间不同,会影响 token 的语义属性。
- CLIP 往往对语义概念与跨实例对齐更强,因此常见现象是 CLIP 教师更占优。
-
patch aggregation 的作用
- 常见现象是线性探测显著提升,说明全局语义聚合对“无监督表征可用性”非常关键。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)