第5讲:Transformer在视觉领域的开山之作——ViT

一、从一个拼图游戏开始

1.1 你玩过"九宫格拼图"吗?

想象一张猫的照片,被切成9块:

┌─────────┬─────────┬─────────┐
│  猫耳朵  │  猫额头  │  猫耳朵  │
├─────────┼─────────┼─────────┤
│  猫眼睛  │  猫鼻子  │  猫眼睛  │
├─────────┼─────────┼─────────┤
│  猫胡须  │  猫嘴巴  │  猫胡须  │
└─────────┴─────────┴─────────┘

问题:如果我把这9块打乱顺序给你,你能认出这是猫吗?

当然可以! 因为:

  1. 每块本身包含信息(眼睛、鼻子、耳朵)
  2. 你知道这些块的相对位置(眼睛在鼻子上方)

ViT的核心思想:把图像切成这样的"补丁块",然后像处理句子里的"单词"一样,用Transformer处理这些"图像块"。


1.2 为什么要把图像当"句子"?

回忆第2-4讲:

  • Transformer在NLP上非常成功(BERT、GPT)
  • 它的核心是自注意力:每个词能看到所有其他词
  • 但Transformer是为序列数据(句子)设计的

图像的问题

  • 图像是二维网格,不是一维序列
  • 像素之间有空间关系(上下左右),不是前后关系

ViT的解决方案

别管什么二维三维,切成块,排成一行,当成"句子"处理!


二、CNN vs ViT:两种看世界的哲学

2.1 CNN:像人类视觉一样"层层抽象"

人类看图片的过程:
  第1层:看到边缘、颜色、纹理
  第2层:组合成眼睛、鼻子、耳朵
  第3层:组合成"猫脸"
  第4层:识别出"这是一只猫"

CNN的卷积层:
  第1层卷积 → 检测边缘(水平线、垂直线)
  第2层卷积 → 检测纹理(斑点、条纹)
  第3层卷积 → 检测部件(眼睛、轮子)
  第4层卷积 → 检测整体(猫、汽车)

CNN的"归纳偏置"(Inductive Bias)

偏置类型 含义 好处
局部性 卷积核只覆盖3×3区域 关注局部细节,参数少
平移不变性 猫在图片左上角或右下角,都能识别 位置不影响判断
层次性 浅层学边缘,深层学语义 结构合理,学习高效

归纳偏置 = 模型"先入为主"的假设。CNN假设"图像有局部相关性",这让它在中小数据集上表现很好。


2.2 ViT:Transformer的"一视同仁"

ViT的处理方式:

原始图像(224×224像素,3通道)
         ↓
切成16×16的补丁块 → 共(224/16)² = 196块
         ↓
每块变成向量(16×16×3 = 768维)
         ↓
加上位置编码(告诉模型"这是第几块")
         ↓
输入Transformer编码器
         ↓
[CLS]标记的输出 → 分类结果

关键区别

CNN ViT
处理方式 滑动窗口,层层卷积 切块,当成序列
感受野 逐渐扩大(从局部到全局) 第一层就是全局(自注意力)
归纳偏置 强(局部性、平移不变性) 弱(几乎没有)
数据需求 中小数据即可 需要大数据(ImageNet-21k或更大)
计算效率 有下采样,计算量可控 自注意力O(n²),196×196也不小

2.3 为什么ViT需要大数据?

核心原因:没有归纳偏置,必须从零学起

CNN学图像:
  "老师,我知道图像有局部特征,给我1000张图我就能学会"

ViT学图像:
  "老师,我不知道图像有什么规律,先给我看1000万张图,我自己摸索"

具体解释

  1. 局部性:CNN的3×3卷积核强制模型学局部模式。ViT的自注意力从全局开始,必须自己发现"相邻像素更相关"。

  2. 平移不变性:CNN的权重共享让模型自动获得平移不变性。ViT必须自己学会"左上角的猫耳朵和右下角的猫耳朵是同一个东西"。

  3. 层次性:CNN的层结构强制分层学习。ViT的所有层结构相同,必须自己组织出层次。

结论:ViT的灵活性更高(能学任意关系),但需要更多数据来"补偿"缺失的先验知识。


三、ViT架构详解

3.1 整体流程:从像素到分类

输入图像(224×224,RGB)
         │
         ▼
┌─────────────────────────────┐
│  步骤1:图像分块(Patchify)  │
│  16×16切块 → 196个补丁块      │
│  每个块:16×16×3 = 768个像素值 │
└─────────────────────────────┘
         │
         ▼
┌─────────────────────────────┐
│  步骤2:线性投影(Patch Embedding)│
│  768维像素 → 768维特征向量      │
│  (可学习的线性变换)            │
└─────────────────────────────┘
         │
         ▼
┌─────────────────────────────┐
│  步骤3:加位置编码             │
│  1D位置编码:第1块、第2块...第196块│
│  2D位置编码:考虑行列位置(可选)  │
└─────────────────────────────┘
         │
         ▼
┌─────────────────────────────┐
│  步骤4:加[CLS]标记            │
│  特殊标记,用于最终分类          │
│  类似BERT的[CLS]               │
└─────────────────────────────┘
         │
         ▼
┌─────────────────────────────┐
│  步骤5:输入Transformer编码器   │
│  12层(ViT-Base)或24层(ViT-Large)│
│  每层:多头自注意力 + MLP        │
└─────────────────────────────┘
         │
         ▼
┌─────────────────────────────┐
│  步骤6:分类头                 │
│  [CLS]标记的输出 → MLP → 类别概率 │
└─────────────────────────────┘

3.2 关键组件详解

组件1:Patch Embedding(补丁嵌入)
原始图像块(16×16×3 = 768个数字):
  [0.8, 0.2, 0.1, 0.9, 0.3, ..., 0.5]  ← 原始像素值

线性投影(768×768矩阵):
  可学习的权重矩阵 W(768, 768)

输出:
  Patch Embedding = 像素向量 × W
  形状:[196, 768]  ← 196个块,每个768维

通俗理解

把"像素值"(原始颜色数字)变成"语义向量"(模型能理解的特征)。就像把"RGB数值"翻译成"边缘、纹理、颜色"的描述。


组件2:位置编码(1D vs 2D)

1D位置编码(ViT原版用的):

把196个块按行优先排成一列:
  块0  块1  块2  ... 块13
  块14 块15 块16 ... 块27
  ...
  块182 ...          块195

位置编码:给每个块一个唯一ID(0到195)
  块0 → 位置向量0
  块1 → 位置向量1
  ...
  块195 → 位置向量195

问题:1D编码丢失了2D空间关系!

块1(第1行第2列)和块14(第2行第1列)在1D中距离很远
但实际上它们在图像中是对角相邻的!

2D位置编码(改进版):

给每个块两个坐标:
  行位置:0到13(共14行)
  列位置:0到13(共14列)

位置编码 = 行位置编码 + 列位置编码

实验发现

1D和2D位置编码效果差不多!因为Transformer的自注意力能自己学到空间关系。但2D在某些任务上略好。


组件3:[CLS] Token(分类标记)
输入序列:
  [CLS]  [块0]  [块1]  [块2]  ...  [块195]

[CLS]是一个特殊的可学习向量,类似BERT

为什么不用全局平均池化?
  BERT的经验:[CLS]经过Transformer层,自然聚合了全局信息
  实验证明:[CLS]比全局池化效果更好

3.3 不同规模的ViT

模型 层数 隐藏维度 MLP维度 头数 参数量 图像块大小
ViT-Tiny 12 192 768 3 5.7M 16×16
ViT-Small 12 384 1536 6 22M 16×16
ViT-Base 12 768 3072 12 86M 16×16
ViT-Large 24 1024 4096 16 307M 16×16
ViT-Huge 32 1280 5120 16 632M 14×14

块大小越小,块数量越多,计算量越大(O(n²))。ViT-Huge用14×14,(224/14)² = 256块。


四、动手实验:用PyTorch实现ViT

4.1 从零手写ViT核心组件

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ============================================
# 步骤1:图像分块(Patch Embedding)
# ============================================

class PatchEmbed(nn.Module):
    """
    把图像切成补丁块,并映射到指定维度
    
    输入:图像 [B, C, H, W]
    输出:补丁序列 [B, N, D]  (N=H*W/(patch_size²), D=embed_dim)
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2  # 196块
        
        # 用卷积实现分块+线性投影(高效!)
        # 卷积核=patch_size,步长=patch_size → 天然实现不重叠分块
        self.proj = nn.Conv2d(
            in_chans, 
            embed_dim, 
            kernel_size=patch_size, 
            stride=patch_size
        )
    
    def forward(self, x):
        """
        x: [B, 3, 224, 224]
        """
        B, C, H, W = x.shape
        
        # 卷积分块: [B, 3, 224, 224] → [B, 768, 14, 14]
        x = self.proj(x)
        
        # 展平: [B, 768, 14, 14] → [B, 768, 196] → [B, 196, 768]
        x = x.flatten(2).transpose(1, 2)
        
        return x  # [B, 196, 768]


# ============================================
# 步骤2:位置编码 + [CLS] Token
# ============================================

class ViTEmbedding(nn.Module):
    """
    组合:Patch Embedding + [CLS] + 位置编码
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, dropout=0.1):
        super().__init__()
        
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches  # 196
        
        # [CLS] token: 可学习的向量 [1, 1, embed_dim]
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # 位置编码: [num_patches + 1, embed_dim]  (+1是因为[CLS])
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )
        
        self.dropout = nn.Dropout(dropout)
        
        # 初始化
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)
    
    def forward(self, x):
        B = x.shape[0]
        
        # 分块: [B, 196, 768]
        x = self.patch_embed(x)
        
        # 添加[CLS]: [B, 1, 768] + [B, 196, 768] = [B, 197, 768]
        cls_tokens = self.cls_token.expand(B, -1, -1)  # 复制B份
        x = torch.cat([cls_tokens, x], dim=1)
        
        # 加位置编码
        x = x + self.pos_embed
        
        return self.dropout(x)


# ============================================
# 步骤3:Transformer编码器块
# ============================================

class MultiHeadAttention(nn.Module):
    """
    多头自注意力(复用第2讲代码,适配ViT)
    """
    def __init__(self, dim, num_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        
        # 一次性生成Q/K/V(效率更高)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    
    def forward(self, x):
        B, N, C = x.shape
        
        # QKV: [B, N, 3*dim] → [B, N, 3, num_heads, head_dim] → [3, B, num_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # 每个: [B, num_heads, N, head_dim]
        
        # 注意力
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)
        
        # 加权
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # [B, N, dim]
        x = self.proj(x)
        x = self.proj_drop(x)
        
        return x, attn


class MLP(nn.Module):
    """
    Transformer中的前馈网络
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()  # ViT用GELU激活函数
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class TransformerBlock(nn.Module):
    """
    ViT的Transformer块:注意力 + MLP + LayerNorm + 残差连接
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, qkv_bias, attn_drop, drop)
        self.norm2 = nn.LayerNorm(dim)
        
        mlp_hidden_dim = int(dim * mlp_ratio)  # 768 * 4 = 3072
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
    
    def forward(self, x, return_attention=False):
        # 注意力子层(先LayerNorm,再注意力,再残差)
        attn_out, attn_weights = self.attn(self.norm1(x))
        x = x + attn_out
        
        # MLP子层
        x = x + self.mlp(self.norm2(x))
        
        if return_attention:
            return x, attn_weights
        return x


# ============================================
# 步骤4:完整的ViT模型
# ============================================

class VisionTransformer(nn.Module):
    """
    完整的Vision Transformer
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,  # Transformer层数
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_rate=0.,
        attn_drop_rate=0.
    ):
        super().__init__()
        
        # 嵌入层
        self.embedding = ViTEmbedding(
            img_size, patch_size, in_chans, embed_dim, drop_rate
        )
        
        # Transformer编码器
        self.blocks = nn.ModuleList([
            TransformerBlock(
                embed_dim, num_heads, mlp_ratio, qkv_bias, 
                drop_rate, attn_drop_rate
            )
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        
        # 分类头
        self.head = nn.Linear(embed_dim, num_classes)
        
        # 初始化
        self.apply(self._init_weights)
    
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    
    def forward(self, x, return_attention=False):
        """
        x: [B, 3, 224, 224]
        """
        # 嵌入: [B, 197, 768]
        x = self.embedding(x)
        
        # 通过Transformer层
        attentions = []
        for block in self.blocks:
            if return_attention:
                x, attn = block(x, return_attention=True)
                attentions.append(attn)
            else:
                x = block(x)
        
        # 最终LayerNorm
        x = self.norm(x)
        
        # 取[CLS]标记的输出做分类 [B, 768]
        cls_output = x[:, 0]
        
        # 分类头
        logits = self.head(cls_output)  # [B, num_classes]
        
        if return_attention:
            return logits, attentions
        return logits


# ============================================
# 测试模型
# ============================================

def test_vit():
    print("=" * 60)
    print("【测试】Vision Transformer模型构建")
    print("=" * 60)
    
    # 创建ViT-Base模型
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        num_classes=10,  # CIFAR-10
        embed_dim=768,
        depth=12,
        num_heads=12
    )
    
    # 统计参数量
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"ViT-Base配置:")
    print(f"  图像大小: 224×224")
    print(f"  补丁大小: 16×16")
    print(f"  补丁数量: 196")
    print(f"  嵌入维度: 768")
    print(f"  Transformer层数: 12")
    print(f"  注意力头数: 12")
    print(f"  总参数量: {total_params / 1e6:.1f}M")
    print(f"  可训练参数: {trainable_params / 1e6:.1f}M")
    
    # 测试前向传播
    dummy_input = torch.randn(2, 3, 224, 224)  # 2张图片
    output = model(dummy_input)
    
    print(f"\n输入形状: {dummy_input.shape}")
    print(f"输出形状: {output.shape}")  # [2, 10]
    print(f"输出示例: {output[0].detach().numpy().round(2)}")
    
    # 测试注意力可视化
    print("\n" + "=" * 60)
    print("【测试】提取注意力权重")
    print("=" * 60)
    
    logits, attentions = model(dummy_input, return_attention=True)
    
    print(f"注意力层数: {len(attentions)}")
    print(f"每层注意力形状: {attentions[0].shape}")
    # [B, num_heads, N, N] = [2, 12, 197, 197]
    
    # 可视化第1层第1个头的注意力
    attn = attentions[0][0, 0].detach().numpy()  # 第1个样本,第1层,第1个头
    
    print(f"\n第1层第1个头的注意力矩阵形状: {attn.shape}")
    print(f"[CLS]对其他块的平均注意力: {attn[0, 1:].mean():.3f}")
    print(f"块之间的平均注意力: {attn[1:, 1:].mean():.3f}")
    
    return model

if __name__ == "__main__":
    model = test_vit()
    print("\n" + "=" * 60)
    print("✅ ViT模型构建完成!")
    print("=" * 60)

4.2 实验2:加载预训练ViT,观察注意力图

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np

# 尝试加载timm库(如果安装了)
try:
    import timm
    
    print("=" * 60)
    print("【实验】用timm加载预训练ViT")
    print("=" * 60)
    
    # 加载预训练的ViT-Base(在ImageNet-21k上预训练)
    model = timm.create_model(
        'vit_base_patch16_224',
        pretrained=True,  # 自动下载预训练权重
        num_classes=1000
    )
    model.eval()
    
    print(f"✅ 加载预训练ViT-Base成功!")
    print(f"模型来源: Google Research (ImageNet-21k预训练)")
    
    # 获取注意力权重
    def get_attention_map(model, image):
        """
        提取ViT的注意力图,可视化[CLS]关注图像的哪些区域
        """
        # 前向传播,获取注意力
        with torch.no_grad():
            # timm模型支持return_attention
            output = model.forward_features(image.unsqueeze(0))
            # 注意:不同版本timm接口可能不同,这里用通用方法
            
            # 替代方案:手动提取
            x = model.patch_embed(image.unsqueeze(0))
            cls_token = model.cls_token.expand(1, -1, -1)
            x = torch.cat((cls_token, x), dim=1)
            x = x + model.pos_embed
            x = model.pos_drop(x)
            
            # 通过Transformer块,收集注意力
            attentions = []
            for block in model.blocks:
                x, attn = block(x, return_attention=True)
                attentions.append(attn)
            
            return attentions
    
    # 加载一张测试图片
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    # 用CIFAR-10的一张图演示
    cifar = datasets.CIFAR10(root='./data', train=False, download=True)
    img, label = cifar[0]  # 第1张图(通常是飞机)
    
    img_tensor = transform(img)
    
    # 获取注意力
    attentions = get_attention_map(model, img_tensor)
    
    # 可视化
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    
    # 原图
    ax = axes[0, 0]
    img_np = np.array(img)
    ax.imshow(img_np)
    ax.set_title(f'Original Image (CIFAR-10: {cifar.classes[label]})')
    ax.axis('off')
    
    # 不同层的[CLS]注意力
    layers = [0, 2, 5, 8, 11]  # 第1,3,6,9,12层
    
    for idx, layer_idx in enumerate(layers):
        if idx >= 5:
            break
        
        ax = axes[idx // 3, idx % 3] if idx < 3 else axes[1, idx - 3]
        
        # 取第layer_idx层的注意力
        attn = attentions[layer_idx][0, 0, 0, 1:].reshape(14, 14).cpu().numpy()
        # [CLS]对196个patch的注意力,reshape成14×14
        
        # 上采样到224×224
        attn_resized = np.kron(attn, np.ones((16, 16)))
        
        ax.imshow(img_np)
        ax.imshow(attn_resized, alpha=0.6, cmap='hot')
        ax.set_title(f'Layer {layer_idx+1} [CLS] Attention')
        ax.axis('off')
    
    plt.tight_layout()
    plt.savefig('/mnt/agents/output/vit_attention_maps.png', dpi=150)
    plt.show()
    
    print("\n📊 注意力图已保存!")
    print("""
解读:
- 浅层(Layer 1-3):[CLS]关注分散,学局部特征
- 中层(Layer 6-9):开始关注语义相关区域
- 深层(Layer 12):高度聚焦目标物体,忽略背景
    """)
    
except ImportError:
    print("timm库未安装,运行: pip install timm")
    print("下面用自实现版本继续实验")

4.3 实验3:CIFAR-10上的对比实验(ViT vs ResNet)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import time

print("=" * 60)
print("【实验3】ViT vs ResNet在CIFAR-10上的对比")
print("=" * 60)

# 数据预处理
# 注意:CIFAR-10图像是32×32,需要resize到224给ViT
transform_train = transforms.Compose([
    transforms.Resize(224),  # ViT需要224×224
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# 加载CIFAR-10
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck')

# ============================================
# 模型1:自实现ViT-Tiny(小规模,适合CIFAR-10)
# ============================================

class ViT_Tiny(nn.Module):
    """
    缩小版ViT,适合小数据集
    """
    def __init__(self, num_classes=10):
        super().__init__()
        # 用更小的配置
        self.vit = VisionTransformer(
            img_size=224,
            patch_size=16,
            num_classes=num_classes,
            embed_dim=192,      # 缩小到192(原来是768)
            depth=12,
            num_heads=3,       # 对应192/64=3
            mlp_ratio=4.,
            drop_rate=0.1
        )
    
    def forward(self, x):
        return self.vit(x)


# ============================================
# 模型2:简单ResNet(对比用)
# ============================================

class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet18(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.in_planes = 64
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)
        self.linear = nn.Linear(512 * BasicBlock.expansion, num_classes)
    
    def _make_layer(self, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(BasicBlock(self.in_planes, planes, stride))
            self.in_planes = planes * BasicBlock.expansion
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


# ============================================
# 训练函数
# ============================================

def train_model(model, name, epochs=5):
    """
    快速训练并评估模型
    """
    print(f"\n{'='*50}")
    print(f"训练 {name}")
    print(f"{'='*50}")
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    # 统计
    history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
    
    for epoch in range(epochs):
        # 训练
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            if (i + 1) % 100 == 0:
                print(f"  Epoch {epoch+1}, Batch {i+1}, "
                      f"Loss: {loss.item():.3f}, "
                      f"Acc: {100.*correct/total:.2f}%")
        
        train_acc = 100. * correct / total
        train_loss = running_loss / len(trainloader)
        
        # 测试
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        test_acc = 100. * correct / total
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        
        print(f"Epoch {epoch+1} 完成: "
              f"Train Loss={train_loss:.3f}, "
              f"Train Acc={train_acc:.2f}%, "
              f"Test Acc={test_acc:.2f}%")
    
    return history


# ============================================
# 运行对比实验
# ============================================

print("注意:这个实验需要GPU,且训练时间较长(约30分钟-1小时)")
print("如果资源有限,可以减少epochs或batch_size")

# 由于CIFAR-10数据量小(5万张),ViT容易过拟合
# 这里只做1个epoch演示,实际建议用预训练权重

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 创建模型
vit_model = ViT_Tiny(num_classes=10)
resnet_model = ResNet18(num_classes=10)

# 统计参数量
vit_params = sum(p.numel() for p in vit_model.parameters())
resnet_params = sum(p.numel() for p in resnet_model.parameters())

print(f"\n模型对比:")
print(f"  ViT-Tiny:   {vit_params/1e6:.1f}M 参数")
print(f"  ResNet-18:  {resnet_params/1e6:.1f}M 参数")

# 训练(减少epoch数用于演示)
epochs = 2 if device.type == 'cpu' else 5

print(f"\n训练轮数: {epochs}(CPU用2轮,GPU用5轮)")

# 训练ResNet
resnet_history = train_model(resnet_model, "ResNet-18", epochs=epochs)

# 训练ViT
vit_history = train_model(vit_model, "ViT-Tiny", epochs=epochs)

# ============================================
# 可视化对比结果
# ============================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# 训练准确率
ax = axes[0]
ax.plot(range(1, epochs+1), resnet_history['train_acc'], 'o-', label='ResNet-18', linewidth=2)
ax.plot(range(1, epochs+1), vit_history['train_acc'], 's-', label='ViT-Tiny', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Train Accuracy (%)')
ax.set_title('Training Accuracy Comparison')
ax.legend()
ax.grid(True, alpha=0.3)

# 测试准确率
ax = axes[1]
ax.plot(range(1, epochs+1), resnet_history['test_acc'], 'o-', label='ResNet-18', linewidth=2)
ax.plot(range(1, epochs+1), vit_history['test_acc'], 's-', label='ViT-Tiny', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Test Accuracy (%)')
ax.set_title('Test Accuracy Comparison')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('/mnt/agents/output/vit_vs_resnet_cifar10.png', dpi=150)
plt.show()

print("\n" + "=" * 60)
print("📊 对比图已保存!")
print("=" * 60)

print("""
预期结果分析:

在小数据集(CIFAR-10,5万张图)上:
  ResNet-18: ✅ 通常表现更好
    - 有归纳偏置(局部性、平移不变性)
    - 适合小数据,不容易过拟合
  
  ViT-Tiny: ⚠️ 可能表现较差或需要更长时间训练
    - 没有归纳偏置,需要学"图像是什么"
    - 容易过拟合(参数多,数据少)
    - 需要预训练(ImageNet-21k)才能发挥优势

关键结论:
  "ViT需要大数据"不是空话!在CIFAR-10上,
  除非用ImageNet预训练权重,否则ResNet更实用。
""")

五、关键洞察:为什么ViT需要预训练?

5.1 数据量对比实验(概念性)

场景1:用ImageNet-1k训练(130万张图,1000类)

ResNet-50:  ✅ 收敛快,最终精度高
ViT-Base:   ⚠️ 收敛慢,精度不如ResNet

原因:ImageNet-1k对ViT来说还是"小数据"
场景2:用ImageNet-21k训练(1400万张图,21000类)

ResNet-50:  ✅ 有提升,但边际效益递减
ViT-Base:   ✅ 大幅超越ResNet!
ViT-Large:  ✅ 更强!

原因:数据量突破临界点后,ViT的灵活性优势显现
场景3:用JFT-300M训练(3亿张图)

ViT-Huge:   ✅ 碾压所有CNN!
            在ImageNet上达到SOTA,且迁移学习极强

原因:海量数据让ViT学会所有需要的"归纳偏置"

5.2 预训练+微调:ViT的正确打开方式

实际应用流程:

1. 预训练(大数据,长耗时,高成本)
   ImageNet-21k或更大 → 训练ViT-Base/Large
   需要数百GPU训练数周

2. 微调(小数据,短耗时,低成本)
   你的任务数据(如医学影像、工业质检)
   加载预训练权重,只训练分类头或少量层
   几小时到几天即可

代码示例

import timm

# 加载预训练ViT
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)

# 冻结Transformer层,只训练分类头(快速适应)
for param in model.blocks.parameters():
    param.requires_grad = False

# 或者:用较小的学习率微调全部层(更好效果)
# optimizer = optim.Adam([
#     {'params': model.blocks.parameters(), 'lr': 1e-5},  # 预训练层:小学习率
#     {'params': model.head.parameters(), 'lr': 1e-3}      # 新分类头:大学习率
# ])

六、核心总结

概念 一句话解释
Patch Embedding 把图像切块,用卷积映射成向量序列
[CLS] Token 特殊标记,聚合全局信息用于分类
1D位置编码 给每个块一个"序号",ViT原版方案
归纳偏置 模型内置的先验假设,CNN强、ViT弱
大数据需求 ViT没有归纳偏置,需要更多数据"从零学"
预训练+微调 在大数据上预训练,在小数据上微调,ViT的标准用法

七、面试高频题

Q1:ViT和CNN的本质区别是什么?

:核心在于归纳偏置。CNN通过卷积核的局部连接和权重共享,内置了局部性和平移不变性假设,适合中小数据。ViT几乎没有归纳偏置,通过自注意力学习全局关系,需要大数据(如ImageNet-21k)才能发挥优势,但上限更高。

Q2:为什么ViT用[CLS]而不是全局平均池化?

:[CLS]标记经过Transformer层,通过自注意力与所有patch交互,自然聚合了全局语义信息。实验表明[CLS]比全局池化效果更好,且与BERT的设计一致,便于统一架构理解。

Q3:ViT的位置编码为什么1D就够了?

:虽然图像是2D的,但实验发现1D和2D位置编码效果相近。因为Transformer的自注意力能自动学习空间关系,且图像分块后的行优先顺序已经隐含了部分位置信息。

Q4:ViT在小数据集上表现不好,怎么解决?

:三种方案:1)使用预训练权重(ImageNet-21k)再微调;2)数据增强(AutoAugment、Mixup);3)使用蒸馏(DeiT)或更小的模型(ViT-Tiny);4)结合CNN和Transformer(CoaT、BoTNet等混合架构)。


八、课后作业

作业1:观察不同patch_size的影响

# 尝试patch_size=8, 16, 32
# patch_size越小:
#   - 块数量越多(64×64=4096块!)
#   - 计算量越大(O(n²))
#   - 但细粒度越好

作业2:可视化注意力演变

# 修改实验2代码,观察:
# 1. 浅层注意力:是否关注局部边缘?
# 2. 深层注意力:是否聚焦目标物体?
# 3. [CLS]的注意力如何从分散到聚焦?

作业3:思考混合架构

问题:有没有办法结合CNN和ViT的优点?

提示:一些后续工作如:
  - CNN做特征提取 + ViT做全局建模(CoaT)
  - 在ViT中加入卷积(CvT, CeiT)
  - 金字塔结构(Swin Transformer, PVT)

九、下讲预告

第6讲:检测大模型——从DETR到RT-DETR/R-DETR

我们将:

  • 理解"把检测变成集合预测"的革命性思想
  • 学习匈牙利匹配算法
  • 动手用ultralytics加载RT-DETR
  • 对比Anchor-based和Anchor-free的优劣

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐