14 MAE Code Walkthrough: Masked Autoencoders Are Scalable Vision Learners
Preface (MAE code walkthrough)
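Masked Autoencoders (MAE) are a self-supervised method for learning visual representations. The core idea is to randomly mask a large part of the input image. The model then has to reconstruct the original image from what is left, which forces it to learn visual features with semantic content.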
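Unlike conventional supervised learning, MAE does not rely on human-annotated labels. Its training objective comes from an input reconstruction task, so the model can be pre-trained on large amounts of unlabeled data. MAE builds on the Vision Transformer (ViT) architecture and uses a high masking ratio (typically 75% or more). Because the encoder only processes the visible patches, this sharply reduces compute cost and also makes representation learning more efficient.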
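To make the compute saving concrete, here is a back-of-the-envelope sketch. It assumes MAE's usual setting: a 224×224 input, 16×16 patches and mask_ratio=0.75. It only counts encoder tokens and ignores constant factors and the lightweight decoder, so the ratios are rough estimates, not measured speedups.

```python
# Rough encoder token-count arithmetic (illustrative only; ignores the decoder
# and constant factors such as number of heads or layers).
img_size, patch = 224, 16
num_patches = (img_size // patch) ** 2          # 14 * 14 = 196 patches
mask_ratio = 0.75
len_keep = int(num_patches * (1 - mask_ratio))  # 49 visible patches

full_tokens = num_patches + 1   # standard ViT: 196 patches + cls token
mae_tokens = len_keep + 1       # MAE encoder: 49 visible patches + cls token

# Linear layers (qkv, proj, MLP) scale linearly with the token count,
# the attention score matrix scales quadratically.
linear_ratio = full_tokens / mae_tokens
attn_ratio = (full_tokens / mae_tokens) ** 2
print(num_patches, len_keep, full_tokens, mae_tokens)
print(f"linear ~{linear_ratio:.1f}x cheaper, attention ~{attn_ratio:.1f}x cheaper")
```

The quadratic attention term is why a high masking ratio pays off so strongly in the encoder.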
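This walkthrough is based on the official MAE implementation and breaks down its core modules from an engineering perspective. It covers the data pipeline, the model structure, the masking mechanism, the encoder-decoder structure, how positional embeddings are generated, and the training and optimization process.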
The code below is the model's entry point. It has three main parts: encoder, decoder and loss.
```python
def forward(self, imgs, mask_ratio=0.75):
    latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
    pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
    loss = self.forward_loss(imgs, pred, mask)
    return loss, pred, mask
```
1 forward_encoder:
```python
def forward_encoder(self, x, mask_ratio):
    # embed patches
    x = self.patch_embed(x)

    # add pos embed w/o cls token
    x = x + self.pos_embed[:, 1:, :]

    # masking: length -> length * (1 - mask_ratio) tokens are kept
    x, mask, ids_restore = self.random_masking(x, mask_ratio)

    # append cls token (with its own positional embedding)
    cls_token = self.cls_token + self.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    # apply Transformer blocks
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)

    return x, mask, ids_restore
```
The code above is built from these parts: patch_embed, pos_embed, random_masking, cls_token and blocks.
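The following NumPy sketch traces tensor shapes through the part of forward_encoder that runs before the Transformer blocks. Random arrays stand in for the learned patch_embed output, pos_embed and cls_token. The sizes assume ViT-L (D=1024), a 224×224 input and 16×16 patches, and the masking step is a condensed version of random_masking.

```python
import numpy as np

rng = np.random.default_rng(0)
N, num_patches, D = 2, 196, 1024   # batch, (224/16)^2 patches, ViT-L width
mask_ratio = 0.75

x = rng.standard_normal((N, num_patches, D))              # stand-in patch_embed output
pos_embed = rng.standard_normal((1, num_patches + 1, D))  # slot 0 belongs to cls
cls_token = rng.standard_normal((1, 1, D))

# Patch tokens get positional embeddings 1..196 (cls slot skipped).
x = x + pos_embed[:, 1:, :]

# Condensed random masking: keep a random 25% of the tokens per sample.
len_keep = int(num_patches * (1 - mask_ratio))
ids_shuffle = np.argsort(rng.random((N, num_patches)), axis=1)
ids_keep = ids_shuffle[:, :len_keep]
x = np.take_along_axis(x, ids_keep[:, :, None], axis=1)   # [N, 49, D]

# Append the (never masked) cls token with its own positional embedding.
cls = np.broadcast_to(cls_token + pos_embed[:, :1, :], (N, 1, D))
x = np.concatenate([cls, x], axis=1)                       # [N, 50, D]
print(x.shape)  # (2, 50, 1024)
```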
1.1 How patch_embed is implemented
Input image
[B, 3, H, W]
↓
(optional padding)
↓
Conv2D patch embedding
↓
[B, embed_dim, H/P, W/P]
↓
flatten + transpose
↓
[B, num_patches, embed_dim]
↓
norm (LayerNorm if a norm_layer is given, otherwise Identity)
↓
output tokens (fed into the Transformer)
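The "Conv2D patch embedding" step can be checked with a small NumPy sketch. It shows that a convolution with kernel_size = stride = P gives the same tokens as cutting the image into patches, flattening each one and applying a single linear layer. Both functions here are hypothetical illustrations, and the sizes are toy values.

```python
import numpy as np

def conv2d_stride_p(x, w, b, p):
    """Naive Conv2d with kernel_size = stride = p and no padding (reference)."""
    B, C, H, W = x.shape
    E = w.shape[0]
    out = np.zeros((B, E, H // p, W // p))
    for i in range(H // p):
        for j in range(W // p):
            window = x[:, :, i * p:(i + 1) * p, j * p:(j + 1) * p]   # [B, C, p, p]
            out[:, :, i, j] = np.einsum('bchw,echw->be', window, w) + b
    return out

def patch_embed_linear(x, w, b, p):
    """Cut into non-overlapping patches, flatten each, apply one Linear layer."""
    B, C, H, W = x.shape
    h, ww = H // p, W // p
    patches = x.reshape(B, C, h, p, ww, p).transpose(0, 2, 4, 1, 3, 5)  # [B, h, w, C, p, p]
    patches = patches.reshape(B, h * ww, C * p * p)                      # [B, L, C*p*p]
    return patches @ w.reshape(w.shape[0], -1).T + b                     # [B, L, E]

rng = np.random.default_rng(0)
B, C, H, W, p, E = 2, 3, 32, 32, 16, 8
x = rng.standard_normal((B, C, H, W))
w = rng.standard_normal((E, C, p, p))
b = rng.standard_normal(E)

# Conv output followed by flatten(2).transpose(1, 2), as in patch_embed.
conv_tokens = conv2d_stride_p(x, w, b, p).reshape(B, E, -1).transpose(0, 2, 1)
lin_tokens = patch_embed_linear(x, w, b, p)
print(conv_tokens.shape, np.allclose(conv_tokens, lin_tokens))  # (2, 4, 8) True
```

Because the patches do not overlap, the convolution is only an efficient way to apply one shared linear projection to every patch.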
```python
# forward() of timm's PatchEmbed; it relies on torch.nn.functional as F
# and on timm helpers (_assert, Format, nchw_to).
def forward(self, x):
    # x: input image tensor, shape = [B, C, H, W]
    # B: batch size
    # C: number of channels (3 for RGB)
    # H, W: image height and width
    B, C, H, W = x.shape

    # ===============================
    # 1. Validate the input size
    # ===============================
    if self.img_size is not None:
        # The model was built for a fixed input size (e.g. 224x224)
        if self.strict_img_size:
            # Strict mode: the size must match exactly
            _assert(
                H == self.img_size[0],
                f"Input height ({H}) doesn't match model ({self.img_size[0]})."
            )
            _assert(
                W == self.img_size[1],
                f"Input width ({W}) doesn't match model ({self.img_size[1]})."
            )
        elif not self.dynamic_img_pad:
            # Non-strict mode without padding: the size must be divisible by the patch size
            _assert(
                H % self.patch_size[0] == 0,
                f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
            )
            _assert(
                W % self.patch_size[1] == 0,
                f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
            )

    # ===============================
    # 2. Dynamic padding (engineering extension)
    # ===============================
    if self.dynamic_img_pad:
        # Pad the image up to a multiple of the patch size
        # Pixels to add along the height
        pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
        # Pixels to add along the width
        pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
        # For the last two dims F.pad takes (left, right, top, bottom);
        # zeros are added only on the right and at the bottom
        x = F.pad(x, (0, pad_w, 0, pad_h))

    # ===============================
    # 3. Patch embedding (core step)
    # ===============================
    # self.proj is a Conv2d with kernel_size = stride = patch_size:
    # it splits the image into non-overlapping patches and projects
    # each one into the embedding space
    # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
    x = self.proj(x)

    # ===============================
    # 4. Convert to the Transformer token layout
    # ===============================
    if self.flatten:
        # current x shape: [B, embed_dim, H/P, W/P]
        # flatten(2): merge the spatial dims -> [B, embed_dim, num_patches]
        # transpose(1, 2): standard Transformer input -> [B, num_patches, embed_dim]
        x = x.flatten(2).transpose(1, 2)
    elif self.output_fmt != Format.NCHW:
        # Otherwise convert to the requested output format (e.g. NHWC)
        x = nchw_to(x, self.output_fmt)

    # ===============================
    # 5. Normalization
    # ===============================
    # LayerNorm over each patch embedding if a norm_layer was given,
    # otherwise nn.Identity (the official MAE encoder builds PatchEmbed
    # without a norm_layer, so this is Identity there)
    x = self.norm(x)

    # ===============================
    # 6. Output
    # ===============================
    # [B, num_patches, embed_dim] (standard ViT tokens)
    return x
```
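The padding formula in step 2 is easy to misread, so this small pure-Python sketch evaluates it for a few example heights with a patch size of 16. The helper name pad_amount is hypothetical. Without the outer `% patch`, an input that is already divisible would get a whole extra patch of padding.

```python
def pad_amount(size, patch):
    """Pixels to add (right or bottom) so that `size` becomes a multiple of `patch`."""
    # Same formula as pad_h / pad_w in PatchEmbed.forward;
    # the outer % turns "a full patch of padding" into 0.
    return (patch - size % patch) % patch

for H in (224, 225, 230, 240):
    pad = pad_amount(H, 16)
    print(H, pad, (H + pad) // 16)   # height, padding, patches along that axis
```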
1.2 The random masking mechanism (random_masking)
```python
def random_masking(self, x, mask_ratio):
    """
    Core of MAE: mask every sample independently at random
    (no fixed mask pattern).

    x: [N, L, D]
        N = batch size
        L = sequence length (number of patches)
        D = embedding dimension
    mask_ratio:
        fraction of tokens to mask (e.g. 0.75)
    """
    # ===============================
    # 1. Read the input shape
    # ===============================
    N, L, D = x.shape  # batch, length, dim

    # ===============================
    # 2. Number of tokens to keep
    # ===============================
    len_keep = int(L * (1 - mask_ratio))
    # e.g. L=196, mask_ratio=0.75 -> len_keep=49

    # ===============================
    # 3. Random noise (key step)
    # ===============================
    noise = torch.rand(N, L, device=x.device)
    # shape [N, L]: one uniform random value in [0, 1) per token

    # ===============================
    # 4. Shuffle by sorting
    # ===============================
    ids_shuffle = torch.argsort(noise, dim=1)
    # argsort(noise) puts tokens with small noise first (kept)
    # and tokens with large noise last (masked),
    # i.e. an independent random permutation of the tokens per sample

    # ===============================
    # 5. Indices that undo the shuffle
    # ===============================
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    # ids_restore is the inverse permutation of ids_shuffle: for each
    # original position it stores where that token landed after shuffling.
    # The decoder uses it to put the mask tokens back in place.

    # ===============================
    # 6. Indices of the kept tokens
    # ===============================
    ids_keep = ids_shuffle[:, :len_keep]
    # keep the first len_keep tokens of the shuffled order
    # (the ones with the smallest noise)

    # ===============================
    # 7. Gather the kept tokens from x
    # ===============================
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    # x: [N, L, D], ids_keep: [N, len_keep]
    # gather picks tokens along the sequence dim (dim=1)
    # output x_masked: [N, len_keep, D]

    # ===============================
    # 8. Binary mask (0 = keep, 1 = masked)
    # ===============================
    mask = torch.ones([N, L], device=x.device)
    # the first len_keep entries are kept
    mask[:, :len_keep] = 0
    # Note: at this point the mask is still in shuffled order,
    # not in the original patch order.

    # ===============================
    # 9. Reorder the mask to the original patch order
    # ===============================
    mask = torch.gather(mask, dim=1, index=ids_restore)
    # Key step: shuffled order -> original patch order.
    # Now mask[:, i] refers to patch i of the original image.

    # ===============================
    # 10. Return
    # ===============================
    return x_masked, mask, ids_restore
```
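To check the shuffle/restore logic without PyTorch, here is a NumPy re-implementation of random_masking. The name random_masking_np is hypothetical, and np.take_along_axis plays the role of torch.gather. It confirms two things: each sample keeps exactly 49 of 196 tokens, and the positions where mask == 0 are exactly the kept token indices.

```python
import numpy as np

def random_masking_np(x, mask_ratio, rng):
    """NumPy sketch of MAE's random_masking (illustration, not the original code)."""
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = rng.random((N, L))
    ids_shuffle = np.argsort(noise, axis=1)        # small noise first = kept
    ids_restore = np.argsort(ids_shuffle, axis=1)  # inverse permutation
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = np.take_along_axis(x, ids_keep[:, :, None], axis=1)
    mask = np.ones((N, L))
    mask[:, :len_keep] = 0                                  # shuffled order
    mask = np.take_along_axis(mask, ids_restore, axis=1)    # original patch order
    return x_masked, mask, ids_restore, ids_keep

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 196, 8))
x_masked, mask, ids_restore, ids_keep = random_masking_np(x, 0.75, rng)
print(x_masked.shape, mask.sum(axis=1))   # (2, 49, 8) [147. 147.]
# The kept patch indices are exactly the positions where mask == 0.
print(all(set(ids_keep[n]) == set(np.flatnonzero(mask[n] == 0)) for n in range(2)))
```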
forward_encoder makes it clear that the cls_token is never masked: it is concatenated to the sequence only after random_masking.
patch_embed
↓
+ pos_embed (patch only)
↓
random_masking
↓
cls_token append
↓
Transformer Encoder
1.3 self.blocks
```text
ModuleList(
  (0-23): 24 x Block(
    (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (attn): Attention(
      (qkv): Linear(in_features=1024, out_features=3072, bias=True)
      (q_norm): Identity()
      (k_norm): Identity()
      (attn_drop): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (proj): Linear(in_features=1024, out_features=1024, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
    )
    (ls1): Identity()
    (drop_path1): Identity()
    (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=1024, out_features=4096, bias=True)
      (act): GELU(approximate='none')
      (drop1): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
      (drop2): Dropout(p=0.0, inplace=False)
    )
    (ls2): Identity()
    (drop_path2): Identity()
  )
)
```
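The number of Blocks depends on the model size. The printout above is ViT-Large (embed_dim=1024, 24 blocks), ViT-Base uses 12 blocks of width 768, and ViT-Huge uses 32 blocks of width 1280.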
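As a rough check on the printout above, this sketch counts the parameters of one Block from its Linear and LayerNorm shapes. The helper vit_block_params is hypothetical. It assumes qkv bias is on, mlp_ratio=4, and that ls1/ls2 and q_norm/k_norm are Identity, as printed. The totals exclude patch_embed, pos_embed, cls_token and the final norm.

```python
def vit_block_params(dim, mlp_ratio=4):
    """Parameters of one ViT Block as printed above (illustrative helper)."""
    hidden = dim * mlp_ratio
    qkv = dim * 3 * dim + 3 * dim      # Linear(dim, 3*dim) with bias
    proj = dim * dim + dim             # Linear(dim, dim) with bias
    fc1 = dim * hidden + hidden        # Linear(dim, hidden) with bias
    fc2 = hidden * dim + dim           # Linear(hidden, dim) with bias
    norms = 2 * 2 * dim                # norm1 + norm2, weight and bias each
    return qkv + proj + fc1 + fc2 + norms

for name, dim, depth in [("ViT-B", 768, 12), ("ViT-L", 1024, 24), ("ViT-H", 1280, 32)]:
    per_block = vit_block_params(dim)
    print(name, per_block, depth * per_block)
```

With mlp_ratio=4 this reduces to 12·dim² + 13·dim per block, so the 24 ViT-L blocks hold roughly 302M parameters.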
2. self.forward_decoder(latent, ids_restore)
```python
def forward_decoder(self, x, ids_restore):
    """
    MAE decoder forward pass.

    Inputs:
        x: encoder output tokens, shape = [B, 1 + visible_tokens, C_encoder]
            - token 0: cls token
            - remaining tokens: the visible patch tokens kept by the encoder
        ids_restore: indices that restore the original patch order,
            shape = [B, num_patches]

    Goal: use the visible tokens plus mask tokens
    to reconstruct the full patch sequence.
    """
    # ==========================================================
    # 1. Project the encoder output to the decoder width
    # ==========================================================
    # encoder_dim -> decoder_dim (the decoder is usually narrower, e.g. 512)
    x = self.decoder_embed(x)
    # [B, 1 + visible_tokens, encoder_dim] -> [B, 1 + visible_tokens, decoder_dim]

    # ==========================================================
    # 2. Build the mask tokens
    # ==========================================================
    mask_tokens = self.mask_token.repeat(
        x.shape[0],
        ids_restore.shape[1] + 1 - x.shape[1],
        1
    )
    # ids_restore.shape[1] = total number of patches (e.g. 196)
    # x.shape[1] = 1 + visible_tokens (includes the cls token)
    # so ids_restore.shape[1] + 1 - x.shape[1] = number of mask tokens needed.
    # Example: 196 patches, 49 visible, x length 50 (49 + cls)
    #   -> 196 + 1 - 50 = 147, mask_tokens: [B, 147, decoder_dim]

    # ==========================================================
    # 3. Concatenate visible tokens and mask tokens
    # ==========================================================
    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
    # x[:, 1:, :] drops the cls token: 49 visible + 147 mask tokens
    # -> x_: [B, 196, decoder_dim]
    # The order is still shuffled, because the visible tokens
    # come out of random_masking in shuffled order.

    # ==========================================================
    # 4. Restore the original patch order with ids_restore (key step)
    # ==========================================================
    x_ = torch.gather(
        x_,
        dim=1,
        index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
    )
    # This is the most important step of the decoder: ids_restore records
    # the mapping between shuffled and original order, so gather moves
    # visible and mask tokens back to their original patch positions.
    # x_: [B, 196, decoder_dim], now in original patch order.

    # ==========================================================
    # 5. Put the cls token back in front
    # ==========================================================
    x = torch.cat([x[:, :1, :], x_], dim=1)
    # x[:, :1, :] is the cls token -> x: [B, 197, decoder_dim]
    # layout: [CLS, P1, P2, ..., P196]

    # ==========================================================
    # 6. Add the decoder positional embedding
    # ==========================================================
    x = x + self.decoder_pos_embed
    # The full patch layout has been restored, so the decoder
    # needs positional embeddings for the whole sequence.
    # decoder_pos_embed: [1, 197, decoder_dim]

    # ==========================================================
    # 7. Decoder Transformer blocks
    # ==========================================================
    for blk in self.decoder_blocks:
        x = blk(x)
    # The decoder learns to infer the masked patches from the visible ones.

    # ==========================================================
    # 8. Decoder LayerNorm
    # ==========================================================
    x = self.decoder_norm(x)

    # ==========================================================
    # 9. Predict the pixel values of each patch
    # ==========================================================
    x = self.decoder_pred(x)
    # decoder_pred: Linear(decoder_dim -> patch_size**2 * 3)
    # e.g. patch_size = 16 -> 16 * 16 * 3 = 768 outputs,
    # i.e. every token predicts all pixels of its patch.

    # ==========================================================
    # 10. Remove the cls token
    # ==========================================================
    x = x[:, 1:, :]
    # The cls token is not used for reconstruction.
    # Output: [B, num_patches, patch_pixels], e.g. [B, 196, 768]

    # ==========================================================
    # 11. Return the decoder prediction
    # ==========================================================
    return x
```
visible patches
+
mask tokens
↓
restore the full patch sequence
↓
Transformer infers the missing regions
↓
predict the pixels of every patch
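This small NumPy sketch shows why the gather with ids_restore puts every token back in place. It mirrors steps 2 to 5 of forward_decoder and skips decoder_embed. Each visible token carries its original patch index as its value, and the cls and mask tokens are constant stand-ins (-1 and -2) rather than learned parameters, so the restored order can be checked directly. The sizes (16 patches, 4 visible) are arbitrary. np.tile plays the role of torch's .repeat, and np.take_along_axis plays the role of torch.gather.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L, D = 2, 16, 4
len_keep = 4

ids_shuffle = np.argsort(rng.random((N, L)), axis=1)
ids_restore = np.argsort(ids_shuffle, axis=1)

# Every patch token's value is its original index, so the order is easy to check.
tokens = np.broadcast_to(np.arange(L, dtype=float)[None, :, None], (N, L, D))
visible = np.take_along_axis(tokens, ids_shuffle[:, :len_keep, None], axis=1)

cls = np.full((N, 1, D), -1.0)               # stand-in cls token
x = np.concatenate([cls, visible], axis=1)   # "encoder output": [N, 1 + len_keep, D]
mask_token = np.full((1, 1, D), -2.0)        # stand-in for the learned mask_token

# Steps 2-5 of forward_decoder.
mask_tokens = np.tile(mask_token, (N, L + 1 - x.shape[1], 1))    # [N, 12, D]
x_ = np.concatenate([x[:, 1:, :], mask_tokens], axis=1)          # shuffled order
x_ = np.take_along_axis(x_, ids_restore[:, :, None], axis=1)     # original order
x = np.concatenate([x[:, :1, :], x_], axis=1)                    # [N, 1 + L, D]

restored = x[:, 1:, 0]                        # first channel of each patch token
is_masked = ids_restore >= len_keep           # same as mask == 1 in random_masking
positions = np.broadcast_to(np.arange(L, dtype=float), (N, L))
print(x.shape)                                                       # (2, 17, 4)
print(np.array_equal(restored[~is_masked], positions[~is_masked]))   # True
print(np.all(restored[is_masked] == -2.0))                           # True
```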
Printout of self.decoder_blocks (8 Blocks). Note: the official MAE configs use decoder_embed_dim=512, which would make these layers 512 wide (qkv 512→1536, fc1 512→2048). The 1024-wide layers printed here presumably come from a setup with a wider decoder.
```text
ModuleList(
  (0-7): 8 x Block(
    (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (attn): Attention(
      (qkv): Linear(in_features=1024, out_features=3072, bias=True)
      (q_norm): Identity()
      (k_norm): Identity()
      (attn_drop): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (proj): Linear(in_features=1024, out_features=1024, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
    )
    (ls1): Identity()
    (drop_path1): Identity()
    (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=1024, out_features=4096, bias=True)
      (act): GELU(approximate='none')
      (drop1): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
      (drop2): Dropout(p=0.0, inplace=False)
    )
    (ls2): Identity()
    (drop_path2): Identity()
  )
)
```
3. self.forward_loss
```python
def forward_loss(self, imgs, pred, mask):
    """
    MAE reconstruction loss.

    Arguments
    -------------------------------------------------
    imgs: original input images, shape = [N, 3, H, W]
    pred: decoder prediction, shape = [N, L, p*p*3]
        L: number of patches
        p*p*3: flattened pixels of one patch
            (e.g. patch_size = 16 -> 16 * 16 * 3 = 768)
    mask: patch mask, shape = [N, L]
        0 = visible patch (kept)
        1 = masked patch
    """
    # ==========================================================
    # 1. Split the original images into patches
    # ==========================================================
    target = self.patchify(imgs)
    # imgs: [N, 3, H, W] -> target: [N, L, patch_size**2 * 3]
    # e.g. input [2, 3, 224, 224] with patch_size=16:
    #   14 * 14 = 196 patches, 16 * 16 * 3 = 768 values each -> [2, 196, 768]

    # ==========================================================
    # 2. Optional: per-patch pixel normalization
    # ==========================================================
    if self.norm_pix_loss:
        # mean of the pixels in each patch
        mean = target.mean(dim=-1, keepdim=True)
        # variance of the pixels in each patch
        var = target.var(dim=-1, keepdim=True)
        # standardize
        target = (target - mean) / (var + 1.e-6)**.5
        # Normalization is per patch, not per image: it removes
        # brightness/contrast differences between patches. The MAE paper
        # reports that this normalized target improves representation quality.

    # ==========================================================
    # 3. Pixel-wise squared error
    # ==========================================================
    loss = (pred - target) ** 2
    # pred: predicted patches, target: ground-truth patches
    # squared error at every pixel position: [N, L, patch_dim]

    # ==========================================================
    # 4. Average inside each patch
    # ==========================================================
    loss = loss.mean(dim=-1)
    # mean over all pixels of a patch -> one loss value per patch: [N, L]

    # ==========================================================
    # 5. Only masked patches contribute (core design of MAE)
    # ==========================================================
    loss = (loss * mask).sum() / mask.sum()
    # mask is 0 for visible and 1 for masked patches, so visible patches
    # are multiplied by 0 and do not contribute; dividing by mask.sum()
    # gives the mean loss over the masked patches only.
    # This forces the model to infer the missing content from context.

    # ==========================================================
    # 6. Return the final loss
    # ==========================================================
    return loss
```
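Here is a NumPy sketch of the whole loss, including the patchify step that is called but not shown above. The patchify here is an assumption: it reproduces the reshape plus einsum('nchpwq->nhwpqc') layout used by MAE's patchify. mae_loss is a hypothetical helper name. ddof=1 matches torch's default unbiased var. The check shows that changing the prediction on visible patches leaves the loss unchanged.

```python
import numpy as np

def patchify(imgs, p):
    """[N, 3, H, W] -> [N, L, p*p*3] with (row, col, channel) order inside a patch."""
    N, C, H, W = imgs.shape
    h, w = H // p, W // p
    x = imgs.reshape(N, C, h, p, w, p)
    x = np.einsum('nchpwq->nhwpqc', x)
    return x.reshape(N, h * w, p * p * C)

def mae_loss(imgs, pred, mask, p, norm_pix_loss=True):
    target = patchify(imgs, p)
    if norm_pix_loss:
        mean = target.mean(axis=-1, keepdims=True)
        var = target.var(axis=-1, keepdims=True, ddof=1)   # unbiased, like torch.var
        target = (target - mean) / (var + 1e-6) ** 0.5
    loss = ((pred - target) ** 2).mean(axis=-1)            # [N, L] per-patch MSE
    return (loss * mask).sum() / mask.sum()                # mean over masked patches

rng = np.random.default_rng(0)
p = 16
imgs = rng.standard_normal((2, 3, 32, 32))
target = patchify(imgs, p)                                 # [2, 4, 768]
mask = np.array([[1, 0, 1, 0], [0, 0, 1, 1]], dtype=float) # 1 = masked
pred = rng.standard_normal(target.shape)

# Corrupting only the visible patches must not change the loss.
pred_vis_changed = pred + 100.0 * (1 - mask)[:, :, None]
print(target.shape)                                        # (2, 4, 768)
print(np.isclose(mae_loss(imgs, pred, mask, p),
                 mae_loss(imgs, pred_vis_changed, mask, p)))   # True
print(mae_loss(imgs, target, mask, p, norm_pix_loss=False))    # 0.0
```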
4 Downstream tasks:
```python
import torch
import timm.models.vision_transformer


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for global average pooling
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome
```
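In this stage:
nothing is masked any more, and the complete image is fed in;
the decoder is removed and only the ViT encoder is kept;
a classification head is added on top;
the cls_token matters again: with global_pool=False its normalized output x[:, 0] is the image feature, while with global_pool=True the patch tokens are averaged and passed through fc_norm instead.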
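The two pooling branches of forward_features can be compared with this NumPy sketch. It assumes an affine-free LayerNorm in place of norm/fc_norm and a random matrix head_w in place of the classification head. pool_features, layer_norm and head_w are hypothetical stand-ins. The sizes follow ViT-L with 1000 ImageNet classes.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """LayerNorm over the last axis without learned affine parameters."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def pool_features(tokens, global_pool, norm):
    """tokens: [B, 1 + L, D] output of the last block, cls token first."""
    if global_pool:
        return norm(tokens[:, 1:, :].mean(axis=1))   # average patch tokens, then fc_norm
    return norm(tokens)[:, 0]                        # norm every token, keep the cls token

rng = np.random.default_rng(0)
B, L, D, num_classes = 2, 196, 1024, 1000
tokens = rng.standard_normal((B, 1 + L, D))
head_w = rng.standard_normal((D, num_classes)) * 0.01   # stand-in classification head

for gp in (True, False):
    feats = pool_features(tokens, gp, layer_norm)
    logits = feats @ head_w
    print(gp, feats.shape, logits.shape)   # (2, 1024) (2, 1000)
```

Either way the head receives one D-dimensional vector per image. The branches differ only in whether that vector comes from the cls token or from the average of the patch tokens.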