Preface (MAE Code Walkthrough)

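Masked Autoencoders (MAE) are a self-supervised approach to visual representation learning. The core idea is to randomly mask patches of the input image and force the model to reconstruct the original image from the little information that remains, so that it learns semantically meaningful visual features.
Unlike conventional supervised learning, MAE does not rely on human-annotated labels. The training target comes from the reconstruction task itself, which lets the model pre-train on large amounts of unlabeled data. Built on the Vision Transformer (ViT) architecture, MAE uses a high masking ratio (typically 75% or even higher). Because the encoder only processes the visible patches, this substantially reduces compute while improving the efficiency of representation learning.
This walkthrough is based on the official MAE implementation and breaks down its core modules from an engineering perspective: the data processing pipeline, the model structure, the masking mechanism, the encoder-decoder design, how the position embeddings are generated, and the training and optimization process.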

The code below is the model's entry point. It has three main parts: the encoder, the decoder, and the loss.

    def forward(self, imgs, mask_ratio=0.75):
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask

1 forward_encoder:

    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * (1 - mask_ratio)
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

The code above is made up of several parts: patch_embed, pos_embed, random_masking, cls_token, and blocks.
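
To make the sequence lengths concrete, here is a minimal sketch in plain Python (no PyTorch required) that traces how many tokens actually reach the Transformer blocks. The 224x224 input and the 16x16 patch size are assumed values for illustration, and `encoder_token_count` is a hypothetical helper, not part of the MAE code.

```python
def encoder_token_count(img_size=224, patch_size=16, mask_ratio=0.75):
    """Trace token counts through forward_encoder (illustrative helper)."""
    num_patches = (img_size // patch_size) ** 2      # tokens after patch_embed
    len_keep = int(num_patches * (1 - mask_ratio))   # tokens left after random_masking
    encoder_len = len_keep + 1                       # plus the cls token
    return num_patches, len_keep, encoder_len


num_patches, len_keep, encoder_len = encoder_token_count()
print(num_patches, len_keep, encoder_len)  # 196 49 50

# Self-attention compares every token with every other token, so the number
# of token pairs shrinks quadratically with the sequence length.
full_len = num_patches + 1
pair_ratio = encoder_len ** 2 / full_len ** 2
print(f"attention pairs vs. unmasked input: {pair_ratio:.1%}")
```

With these assumed sizes the encoder sees 50 tokens instead of 197, which is where most of the compute saving of the high masking ratio comes from.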

1.1 How patch_embed is implemented

Input image
[B, 3, H, W]
        ↓
(optional padding)
        ↓
Conv2D Patch Embedding
        ↓
[B, embed_dim, H/P, W/P]
        ↓
flatten + transpose
        ↓
[B, num_patches, embed_dim]
        ↓
norm (LayerNorm if a norm_layer is configured, otherwise Identity)
        ↓
Output tokens (fed into the Transformer)

def forward(self, x):
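    # x: input image tensor
    # shape = [B, C, H, W]
    # B: batch size
    # C: number of channels (RGB = 3)
    # H, W: image height and width

    B, C, H, W = x.shape

    # ===============================
    # 1. Validate the input size
    # ===============================

    if self.img_size is not None:
        # The model was configured with a fixed input size (e.g. 224x224)

        if self.strict_img_size:
            # Strict mode: the input must match the configured size exactly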

            _assert(
                H == self.img_size[0],
                f"Input height ({H}) doesn't match model ({self.img_size[0]})."
            )

            _assert(
                W == self.img_size[1],
                f"Input width ({W}) doesn't match model ({self.img_size[1]})."
            )

        elif not self.dynamic_img_pad:
            # Non-strict mode without dynamic padding: the size must be divisible by the patch size

            _assert(
                H % self.patch_size[0] == 0,
                f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
            )

            _assert(
                W % self.patch_size[1] == 0,
                f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
            )

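    # ===============================
    # 2. Dynamic padding (an engineering convenience)
    # ===============================

    if self.dynamic_img_pad:
        # With dynamic padding enabled, pad the input up to a multiple of the patch size

        # Number of pixels to pad along the height
        pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]

        # Number of pixels to pad along the width
        pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]

        # F.pad order for the last two dims: (left, right, top, bottom)
        # Zeros are added only on the right and bottom edges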
        x = F.pad(x, (0, pad_w, 0, pad_h))

    # ===============================
    # 3. Patch embedding (the core step)
    # ===============================

    # self.proj is typically a Conv2d with
    # kernel_size = patch_size
    # stride = patch_size

    # Effect:
    # split the image into non-overlapping patches and project each one into the embedding space
    # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **dd)

    x = self.proj(x)

    # ===============================
    # 4. Convert to the Transformer token layout
    # ===============================

    if self.flatten:
        # Current shape of x:
        # [B, embed_dim, H/P, W/P]

        # flatten(2):
        # merge the spatial dims -> [B, embed_dim, num_patches]

        # transpose(1, 2):
        # switch to the standard Transformer input layout
        # [B, num_patches, embed_dim]
        x = x.flatten(2).transpose(1, 2)

    elif self.output_fmt != Format.NCHW:
        # If the requested output format is not NCHW, convert to it
        x = nchw_to(x, self.output_fmt)

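    # ===============================
    # 5. Normalization layer
    # ===============================

    # self.norm is a LayerNorm over each patch embedding when a norm_layer
    # is configured (which helps stabilize training); otherwise it is an
    # Identity and x passes through unchanged
    x = self.norm(x)

    # ===============================
    # 6. Output
    # ===============================

    # Output:
    # [B, num_patches, embed_dim] (standard ViT tokens)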
    return x
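
The shape arithmetic of the patch embedding can be checked without running a model. The sketch below reuses the padding formula from the listing above. `patch_embed_shape` is a hypothetical helper, the patch size of 16 and embedding width of 1024 are assumed values, and the strict size check is deliberately left out.

```python
def patch_embed_shape(H, W, patch=16, embed_dim=1024, dynamic_pad=False):
    """Return (pad_h, pad_w, grid, token_shape) for one image (illustrative)."""
    pad_h = pad_w = 0
    if dynamic_pad:
        # Same formula as in forward(): pad up to the next multiple of patch
        pad_h = (patch - H % patch) % patch
        pad_w = (patch - W % patch) % patch
    elif H % patch or W % patch:
        raise ValueError("image size must be divisible by the patch size")
    grid = ((H + pad_h) // patch, (W + pad_w) // patch)  # Conv2d output: H/P, W/P
    token_shape = (grid[0] * grid[1], embed_dim)          # after flatten + transpose
    return pad_h, pad_w, grid, token_shape


print(patch_embed_shape(224, 224))                    # (0, 0, (14, 14), (196, 1024))
print(patch_embed_shape(230, 225, dynamic_pad=True))  # (10, 15, (15, 15), (225, 1024))
```

The outer `% patch` in the padding formula is what keeps the padding at 0 when the size is already a multiple of the patch size.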

1.2 Random masking (random_masking)

def random_masking(self, x, mask_ratio):
    """
    Core of MAE: random masking is applied independently to each sample
    (no fixed mask pattern is used).

    x: [N, L, D]
        N = batch size
        L = sequence length (number of patches)
        D = embedding dimension

    mask_ratio:
        fraction of tokens to mask out (e.g. 0.75)
    """

    # ===============================
    # 1. Read the input shape
    # ===============================
    N, L, D = x.shape  # batch, length, dim

    # ===============================
    # 2. Number of tokens to keep
    # ===============================
    len_keep = int(L * (1 - mask_ratio))
    # e.g. L=196, mask_ratio=0.75 -> len_keep=49

    # ===============================
    # 3. Generate random noise (the key step)
    # ===============================
    noise = torch.rand(N, L, device=x.device)
    # shape: [N, L]
    # one random value in [0, 1) per token

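    # ===============================
    # 4. Shuffle by sorting the noise
    # ===============================
    ids_shuffle = torch.argsort(noise, dim=1)

    """
    Why this line matters:

    argsort(noise) sorts in ascending order:
    - small noise -> near the front (kept)
    - large noise -> near the back (masked)

    This is equivalent to drawing an independent random
    permutation of the tokens for each sample.
    """

    # ===============================
    # 5. Compute the restore indices
    # ===============================
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    """
    What ids_restore does:

    It is the inverse permutation of ids_shuffle:
    ids_restore[n, i] is the position of original token i
    in the shuffled sequence.

    The decoder uses it to put the mask tokens back
    at their original positions.
    """

    # ===============================
    # 6. Indices of the kept tokens
    # ===============================
    ids_keep = ids_shuffle[:, :len_keep]

    """
    Keep only the first len_keep tokens of the shuffled order
    (the ones with the smallest noise).
    """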

    # ===============================
    # 7. Gather the kept tokens from x
    # ===============================
    x_masked = torch.gather(
        x,
        dim=1,
        index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
    )

    """
    About torch.gather:

    x: [N, L, D]
    ids_keep: [N, len_keep]

    -> picks tokens by index along the sequence dimension (dim=1)

    Output:
    x_masked: [N, len_keep, D]
    """

    # ===============================
    # 8. Build the mask (0 = keep, 1 = masked)
    # ===============================
    mask = torch.ones([N, L], device=x.device)

    # Set the first len_keep entries to 0 (kept)
    mask[:, :len_keep] = 0

    """
    Note:
    At this point the mask is still in shuffled order
    (kept tokens first); it does not yet line up with
    the original patch positions.
    """

    # ===============================
    # 9. Unshuffle the mask back to the original token order
    # ===============================
    mask = torch.gather(mask, dim=1, index=ids_restore)

    """
    Key step:

    Map the mask from shuffled order back to the original patch order.

    After this:
    mask[n, i] refers to patch i of the original image.
    """

    # ===============================
    # 10. Return the results
    # ===============================
    return x_masked, mask, ids_restore
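
The double argsort is easier to follow on a single sample with plain Python lists. The sketch below mirrors the index logic of random_masking under that simplification: `sorted(..., key=...)` stands in for torch.argsort and a list comprehension stands in for torch.gather. `random_masking_ids` is a hypothetical name and the 8-patch sequence is an assumed toy size.

```python
import random


def random_masking_ids(L, mask_ratio, rng):
    """Argsort-based masking for a single sample, using plain lists."""
    len_keep = int(L * (1 - mask_ratio))
    noise = [rng.random() for _ in range(L)]
    # argsort(noise): indices of tokens ordered by ascending noise
    ids_shuffle = sorted(range(L), key=lambda i: noise[i])
    # argsort(ids_shuffle): rank of each original token in the shuffled order
    ids_restore = sorted(range(L), key=lambda r: ids_shuffle[r])
    ids_keep = ids_shuffle[:len_keep]
    # Mask in shuffled order (0 = keep, 1 = masked), then gather it back
    mask_shuffled = [0] * len_keep + [1] * (L - len_keep)
    mask = [mask_shuffled[ids_restore[i]] for i in range(L)]
    return ids_keep, ids_restore, mask


ids_keep, ids_restore, mask = random_masking_ids(L=8, mask_ratio=0.75, rng=random.Random(0))
print("kept patches:", sorted(ids_keep))
print("mask:", mask)
```

Whatever the random draw, the zeros in `mask` fall exactly on the indices listed in `ids_keep`, which is the property the loss relies on later.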

This step makes it clear that the cls_token is never masked: it is appended only after random_masking has run.

patch_embed
   ↓
+ pos_embed (patch only)
   ↓
random_masking
   ↓
cls_token append
   ↓
Transformer Encoder

1.3  self.blocks

ModuleList(
  (0-23): 24 x Block(
    (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (attn): Attention(
      (qkv): Linear(in_features=1024, out_features=3072, bias=True)
      (q_norm): Identity()
      (k_norm): Identity()
      (attn_drop): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (proj): Linear(in_features=1024, out_features=1024, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
    )
    (ls1): Identity()
    (drop_path1): Identity()
    (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=1024, out_features=4096, bias=True)
      (act): GELU(approximate='none')
      (drop1): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
      (drop2): Dropout(p=0.0, inplace=False)
    )
    (ls2): Identity()
    (drop_path2): Identity()
  )
)

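Different model variants contain different numbers of Blocks. The 24 blocks of width 1024 printed here correspond to ViT-Large; ViT-Base uses 12 blocks and ViT-Huge uses 32.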
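
As a rough size check, the parameter count of one Block can be derived from the layer shapes in the printout. This is a sketch that counts only the layers with weights (the Identity and Dropout entries have none), and `block_params` is a hypothetical helper.

```python
def block_params(dim=1024, mlp_ratio=4):
    """Count learnable parameters in one Block as printed above."""
    hidden = dim * mlp_ratio
    norm = 2 * dim                      # LayerNorm weight + bias
    qkv = dim * 3 * dim + 3 * dim       # Linear(dim -> 3*dim)
    proj = dim * dim + dim              # Linear(dim -> dim)
    fc1 = dim * hidden + hidden         # Linear(dim -> 4*dim)
    fc2 = hidden * dim + dim            # Linear(4*dim -> dim)
    return 2 * norm + qkv + proj + fc1 + fc2


per_block = block_params()
print(per_block)        # 12596224
print(24 * per_block)   # 302309376
```

The total covers the 24 blocks only. The patch embedding, position embeddings, cls token, and final norm are not included.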

2 self.forward_decoder(latent, ids_restore)

def forward_decoder(self, x, ids_restore):
    """
    Forward pass of the MAE decoder.

    Inputs:
    x:
        tokens produced by the encoder
        shape = [B, 1 + visible_tokens, C_encoder]

        contains:
        - token 0: the cls token
        - remaining tokens: the visible patch tokens kept by the encoder

    ids_restore:
        indices used to restore the original patch order
        shape = [B, num_patches]

    Goal of the decoder:
    use the visible tokens plus mask tokens
    to reconstruct the full patch sequence
    """

    # ==========================================================
    # 1. Project the encoder output to the decoder width
    # ==========================================================

    # encoder_dim -> decoder_dim
    # the decoder is usually narrower (e.g. 512)
    x = self.decoder_embed(x)

    """
    Input:
        [B, 1 + visible_tokens, encoder_dim]

    Output:
        [B, 1 + visible_tokens, decoder_dim]
    """

    # ==========================================================
    # 2. Build the mask tokens
    # ==========================================================

    mask_tokens = self.mask_token.repeat(
        x.shape[0],
        ids_restore.shape[1] + 1 - x.shape[1],
        1
    )

    """
    Explanation:

    ids_restore.shape[1]
        = total number of patches (e.g. 196)

    x.shape[1]
        = current number of tokens
        = 1 + visible_tokens
        (includes the cls token)

    So:
        ids_restore.shape[1] + 1 - x.shape[1]

    equals:
        the number of mask tokens that must be added

    Example:

        patches in the image = 196
        visible patches kept = 49
        current length of x = 50 (49 + cls)

        mask tokens needed:
        196 + 1 - 50 = 147

    Result:
        mask_tokens = [B, 147, decoder_dim]
    """

    # ==========================================================
    # 3. Concatenate visible tokens and mask tokens
    # ==========================================================

    x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)

    """
    Note:
    x[:, 1:, :]
        drops the cls token

    Now:
        visible tokens = 49
        mask tokens = 147

    After concatenation:
        x_ = [B, 196, decoder_dim]

    However:
        the order is still shuffled, because the visible tokens
        come out of random_masking in shuffled order
    """

    # ==========================================================
    # 4. Restore the original patch order with ids_restore (core step)
    # ==========================================================

    x_ = torch.gather(
        x_,
        dim=1,
        index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
    )

    """
    This is the most important step in the decoder:

    ids_restore:
        records the mapping between the original and shuffled orders

    gather:
        rearranges the visible tokens and mask tokens
        back into the patch order of the original image

    Result:
        x_ = [B, 196, decoder_dim]

    and:
        every token now sits at its original patch position
    """

    # ==========================================================
    # 5. Prepend the cls token again
    # ==========================================================

    x = torch.cat([x[:, :1, :], x_], dim=1)

    """
    x[:, :1, :]
        selects the cls token

    Result:
        x = [B, 197, decoder_dim]

    Layout:
        [CLS, P1, P2, ..., P196]
    """

    # ==========================================================
    # 6. Add the decoder position embedding
    # ==========================================================

    x = x + self.decoder_pos_embed

    """
    The decoder needs the full position embedding,
    because the complete patch sequence has been restored.

    decoder_pos_embed:
        shape = [1, 197, decoder_dim]
    """

    # ==========================================================
    # 7. Run the decoder Transformer blocks
    # ==========================================================

    for blk in self.decoder_blocks:
        x = blk(x)

    """
    Decoder Transformer:
        learns to infer the content of masked patches
        from the visible ones
    """

    # ==========================================================
    # 8. Decoder LayerNorm
    # ==========================================================

    x = self.decoder_norm(x)

    # ==========================================================
    # 9. Predict the pixel values of each patch
    # ==========================================================

    x = self.decoder_pred(x)

    """
    decoder_pred:
        Linear(decoder_dim -> patch_size^2 * 3)

    Example:
        patch_size = 16

    Output dimension:
        16 × 16 × 3 = 768

    That is:
        each token predicts all the pixels of one patch
    """

    # ==========================================================
    # 10. Drop the cls token
    # ==========================================================

    x = x[:, 1:, :]

    """
    The cls token takes no part in image reconstruction.

    Final output:
        [B, num_patches, patch_pixels]

    Example:
        [B, 196, 768]
    """

    # ==========================================================
    # 11. Return the decoder prediction
    # ==========================================================

    return x
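
The "concatenate, then gather with ids_restore" trick can be replayed on a toy sequence. The sketch below uses string labels instead of tensors and leaves the cls token out for brevity. `restore_sequence`, the 6-patch sequence, and the chosen shuffle are all assumptions made for illustration.

```python
def restore_sequence(visible, ids_restore, mask_token="M"):
    """Append mask tokens, then unshuffle back to the original patch order."""
    num_mask = len(ids_restore) - len(visible)
    shuffled = list(visible) + [mask_token] * num_mask   # torch.cat([...], dim=1)
    return [shuffled[r] for r in ids_restore]            # torch.gather(..., ids_restore)


# Assumed toy setup: 6 patches, the encoder kept patches 4 and 1 (in that order).
ids_shuffle = [4, 1, 0, 5, 2, 3]
ids_restore = sorted(range(6), key=lambda r: ids_shuffle[r])  # argsort(ids_shuffle)
visible = ["P4", "P1"]

print(ids_restore)                             # [2, 1, 4, 5, 0, 3]
print(restore_sequence(visible, ids_restore))  # ['M', 'P1', 'M', 'M', 'P4', 'M']
```

All mask tokens are copies of the same learned vector, so it does not matter which copy lands at which masked position. Only the visible tokens need to return to their exact slots.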

visible patches
      +
mask tokens
      ↓
restore the full patch sequence
      ↓
Transformer infers the missing regions
      ↓
predict the pixels of each patch

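Printed structure of self.decoder_blocks (8 blocks). Note that this printout shows a width of 1024, whereas the comments in forward_decoder describe the decoder as usually narrower than the encoder (e.g. 512):

ModuleList(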
  (0-7): 8 x Block(
    (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (attn): Attention(
      (qkv): Linear(in_features=1024, out_features=3072, bias=True)
      (q_norm): Identity()
      (k_norm): Identity()
      (attn_drop): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (proj): Linear(in_features=1024, out_features=1024, bias=True)
      (proj_drop): Dropout(p=0.0, inplace=False)
    )
    (ls1): Identity()
    (drop_path1): Identity()
    (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (mlp): Mlp(
      (fc1): Linear(in_features=1024, out_features=4096, bias=True)
      (act): GELU(approximate='none')
      (drop1): Dropout(p=0.0, inplace=False)
      (norm): Identity()
      (fc2): Linear(in_features=4096, out_features=1024, bias=True)
      (drop2): Dropout(p=0.0, inplace=False)
    )
    (ls2): Identity()
    (drop_path2): Identity()
  )
)

3 self.forward_loss(imgs, pred, mask)

def forward_loss(self, imgs, pred, mask):
    """
    MAE reconstruction loss

    Arguments
    -------------------------------------------------

    imgs:
        original input images
        shape = [N, 3, H, W]

    pred:
        decoder prediction
        shape = [N, L, p*p*3]

        L:
            number of patches

        p*p*3:
            flattened pixel vector of one patch

            Example:
                patch_size = 16

            gives:
                16 × 16 × 3 = 768

    mask:
        patch mask
        shape = [N, L]

        where:
            0 = visible patch (kept)
            1 = masked patch (hidden)
    """

    # ==========================================================
    # 1. Split the original image into patches
    # ==========================================================

    target = self.patchify(imgs)

    """
    After patchify:

    imgs:
        [N, 3, H, W]

    -> target:
        [N, L, patch_size^2 * 3]

    Example:

        input:
            [2, 3, 224, 224]

        patch_size = 16

    then:
        number of patches:
            14 × 14 = 196

        values per patch:
            16 × 16 × 3 = 768

    Output:
        [2, 196, 768]
    """

    # ==========================================================
    # 2. Optional: per-patch pixel normalization
    # ==========================================================

    if self.norm_pix_loss:

        # Mean of the pixels in each patch
        mean = target.mean(dim=-1, keepdim=True)

        # Variance of the pixels in each patch
        var = target.var(dim=-1, keepdim=True)

        # Standardize
        target = (target - mean) / (var + 1.e-6)**.5

    """
    Why normalize the pixels?

    Purpose:
        reduce brightness differences between patches

    Benefit:
        more stable MAE training

    Note:
        this is a per-patch standardization,
        not a normalization over the whole image
    """

    # ==========================================================
    # 3. Pixel-level squared error
    # ==========================================================

    loss = (pred - target) ** 2

    """
    pred:
        patches predicted by the decoder

    target:
        ground-truth patches

    This computes:
        the squared error at every pixel position

    Output shape:
        [N, L, patch_dim]
    """

    # ==========================================================
    # 4. Average within each patch
    # ==========================================================

    loss = loss.mean(dim=-1)

    """
    dim=-1:
        average over all pixel values inside a patch

    That is:

        one patch
        ↓
        one loss value

    Output:
        [N, L]

    meaning:
        one reconstruction error per patch
    """

    # ==========================================================
    # 5. Average the loss over masked patches only (core of MAE)
    # ==========================================================

    loss = (loss * mask).sum() / mask.sum()

    """
    This is one of the key design choices in MAE.

    mask:
        0 = visible
        1 = masked

    Therefore:

        loss of a visible patch:
            ×0 -> does not contribute to training

        loss of a masked patch:
            ×1 -> contributes to training

    Net effect:
        only the masked regions are optimized

    This forces the model to:
        infer the missing content from context
    """

    # ==========================================================
    # 6. Return the final loss
    # ==========================================================

    return loss
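
The listing calls self.patchify without showing it, so here is a minimal sketch of the whole loss path on a toy image, written with nested Python lists instead of tensors. Both functions are illustrative stand-ins rather than the official code. The (row, column, channel) ordering inside each patch is an assumption, and the single-channel 4x4 image is chosen only to keep the numbers readable.

```python
def patchify(img, p):
    """img: [C][H][W] nested lists -> list of L patches, each of length p*p*C."""
    C, H, W = len(img), len(img[0]), len(img[0][0])
    patches = []
    for gy in range(H // p):
        for gx in range(W // p):
            patches.append([img[c][gy * p + py][gx * p + px]
                            for py in range(p) for px in range(p) for c in range(C)])
    return patches


def masked_mse(pred, target, mask, norm_pix_loss=False):
    """Mean squared error averaged over masked patches only (mask: 1 = masked)."""
    total = 0.0
    for p_vec, t_vec, m in zip(pred, target, mask):
        if norm_pix_loss:
            mean = sum(t_vec) / len(t_vec)
            # Unbiased variance (divide by n - 1), matching the torch.var default
            var = sum((t - mean) ** 2 for t in t_vec) / (len(t_vec) - 1)
            t_vec = [(t - mean) / (var + 1e-6) ** 0.5 for t in t_vec]
        per_patch = sum((a - b) ** 2 for a, b in zip(p_vec, t_vec)) / len(t_vec)
        total += per_patch * m          # visible patches are multiplied by 0
    return total / sum(mask)


# Toy image: 1 channel, 4x4 pixels, 2x2 patches -> 4 patches of 4 values each.
img = [[[float(4 * y + x) for x in range(4)] for y in range(4)]]
target = patchify(img, p=2)
pred = [[0.0] * 4 for _ in target]   # a "model" that always predicts zeros
mask = [1, 0, 0, 1]                  # only patches 0 and 3 count

print(target[0])                       # [0.0, 1.0, 4.0, 5.0]
print(masked_mse(pred, target, mask))  # 85.5
```

Patches 1 and 2 have large errors too, but they are visible (mask = 0), so they leave the result unchanged.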

4 Downstream tasks:

class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for global average pooling
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)

            del self.norm  # remove the original norm

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

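In this stage:

the cls_token takes center stage again: when global_pool is off, the classification feature is x[:, 0]

nothing is masked; the full image is fed in

the decoder is dropped and only the ViT encoder is kept

a classification head is added on top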
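
The two feature-pooling branches of forward_features can be compared on a tiny example. The sketch below uses plain lists and omits the normalization layers (fc_norm and norm). `pool_features` is a hypothetical helper and the token values are made up.

```python
def pool_features(tokens, global_pool=False):
    """tokens: list of [cls, patch_1, ..., patch_L], each a list of floats."""
    if global_pool:
        patches = tokens[1:]                       # drop the cls token
        dim = len(patches[0])
        return [sum(t[d] for t in patches) / len(patches) for d in range(dim)]
    return tokens[0]                               # use the cls token itself


tokens = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0]]     # cls + two patch tokens
print(pool_features(tokens))                       # [9.0, 9.0]
print(pool_features(tokens, global_pool=True))     # [2.0, 3.0]
```

With global_pool enabled, the cls token is excluded from the feature entirely, so which branch is used decides whether the cls token or the averaged patch tokens feed the classification head.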
