第5讲:Transformer在视觉 领域的开山之作——ViT
第5讲:Transformer在视觉领域的开山之作——ViT
一、从一个拼图游戏开始
1.1 你玩过"九宫格拼图"吗?
想象一张猫的照片,被切成9块:
┌─────────┬─────────┬─────────┐
│ 猫耳朵 │ 猫额头 │ 猫耳朵 │
├─────────┼─────────┼─────────┤
│ 猫眼睛 │ 猫鼻子 │ 猫眼睛 │
├─────────┼─────────┼─────────┤
│ 猫胡须 │ 猫嘴巴 │ 猫胡须 │
└─────────┴─────────┴─────────┘
问题:如果我把这9块打乱顺序给你,你能认出这是猫吗?
当然可以! 因为:
- 每块本身包含信息(眼睛、鼻子、耳朵)
- 你知道这些块的相对位置(眼睛在鼻子上方)
ViT的核心思想:把图像切成这样的"补丁块",然后像处理句子里的"单词"一样,用Transformer处理这些"图像块"。
1.2 为什么要把图像当"句子"?
回忆第2-4讲:
- Transformer在NLP上非常成功(BERT、GPT)
- 它的核心是自注意力:每个词能看到所有其他词
- 但Transformer是为序列数据(句子)设计的
图像的问题:
- 图像是二维网格,不是一维序列
- 像素之间有空间关系(上下左右),不是前后关系
ViT的解决方案:
别管什么二维三维,切成块,排成一行,当成"句子"处理!
二、CNN vs ViT:两种看世界的哲学
2.1 CNN:像人类视觉一样"层层抽象"
人类看图片的过程:
第1层:看到边缘、颜色、纹理
第2层:组合成眼睛、鼻子、耳朵
第3层:组合成"猫脸"
第4层:识别出"这是一只猫"
CNN的卷积层:
第1层卷积 → 检测边缘(水平线、垂直线)
第2层卷积 → 检测纹理(斑点、条纹)
第3层卷积 → 检测部件(眼睛、轮子)
第4层卷积 → 检测整体(猫、汽车)
CNN的"归纳偏置"(Inductive Bias):
| 偏置类型 | 含义 | 好处 |
|---|---|---|
| 局部性 | 卷积核只覆盖3×3区域 | 关注局部细节,参数少 |
| 平移不变性 | 猫在图片左上角或右下角,都能识别 | 位置不影响判断 |
| 层次性 | 浅层学边缘,深层学语义 | 结构合理,学习高效 |
归纳偏置 = 模型"先入为主"的假设。CNN假设"图像有局部相关性",这让它在中小数据集上表现很好。
2.2 ViT:Transformer的"一视同仁"
ViT的处理方式:
原始图像(224×224像素,3通道)
↓
切成16×16的补丁块 → 共(224/16)² = 196块
↓
每块变成向量(16×16×3 = 768维)
↓
加上位置编码(告诉模型"这是第几块")
↓
输入Transformer编码器
↓
[CLS]标记的输出 → 分类结果
关键区别:
| CNN | ViT | |
|---|---|---|
| 处理方式 | 滑动窗口,层层卷积 | 切块,当成序列 |
| 感受野 | 逐渐扩大(从局部到全局) | 第一层就是全局(自注意力) |
| 归纳偏置 | 强(局部性、平移不变性) | 弱(几乎没有) |
| 数据需求 | 中小数据即可 | 需要大数据(ImageNet-21k或更大) |
| 计算效率 | 有下采样,计算量可控 | 自注意力O(n²),196×196也不小 |
2.3 为什么ViT需要大数据?
核心原因:没有归纳偏置,必须从零学起
CNN学图像:
"老师,我知道图像有局部特征,给我1000张图我就能学会"
ViT学图像:
"老师,我不知道图像有什么规律,先给我看1000万张图,我自己摸索"
具体解释:
-
局部性:CNN的3×3卷积核强制模型学局部模式。ViT的自注意力从全局开始,必须自己发现"相邻像素更相关"。
-
平移不变性:CNN的权重共享让模型自动获得平移不变性。ViT必须自己学会"左上角的猫耳朵和右下角的猫耳朵是同一个东西"。
-
层次性:CNN的层结构强制分层学习。ViT的所有层结构相同,必须自己组织出层次。
结论:ViT的灵活性更高(能学任意关系),但需要更多数据来"补偿"缺失的先验知识。
三、ViT架构详解
3.1 整体流程:从像素到分类
输入图像(224×224,RGB)
│
▼
┌─────────────────────────────┐
│ 步骤1:图像分块(Patchify) │
│ 16×16切块 → 196个补丁块 │
│ 每个块:16×16×3 = 768个像素值 │
└─────────────────────────────┘
│
▼
┌─────────────────────────────┐
│ 步骤2:线性投影(Patch Embedding)│
│ 768维像素 → 768维特征向量 │
│ (可学习的线性变换) │
└─────────────────────────────┘
│
▼
┌─────────────────────────────┐
│ 步骤3:加位置编码 │
│ 1D位置编码:第1块、第2块...第196块│
│ 2D位置编码:考虑行列位置(可选) │
└─────────────────────────────┘
│
▼
┌─────────────────────────────┐
│ 步骤4:加[CLS]标记 │
│ 特殊标记,用于最终分类 │
│ 类似BERT的[CLS] │
└─────────────────────────────┘
│
▼
┌─────────────────────────────┐
│ 步骤5:输入Transformer编码器 │
│ 12层(ViT-Base)或24层(ViT-Large)│
│ 每层:多头自注意力 + MLP │
└─────────────────────────────┘
│
▼
┌─────────────────────────────┐
│ 步骤6:分类头 │
│ [CLS]标记的输出 → MLP → 类别概率 │
└─────────────────────────────┘
3.2 关键组件详解
组件1:Patch Embedding(补丁嵌入)
原始图像块(16×16×3 = 768个数字):
[0.8, 0.2, 0.1, 0.9, 0.3, ..., 0.5] ← 原始像素值
线性投影(768×768矩阵):
可学习的权重矩阵 W(768, 768)
输出:
Patch Embedding = 像素向量 × W
形状:[196, 768] ← 196个块,每个768维
通俗理解:
把"像素值"(原始颜色数字)变成"语义向量"(模型能理解的特征)。就像把"RGB数值"翻译成"边缘、纹理、颜色"的描述。
组件2:位置编码(1D vs 2D)
1D位置编码(ViT原版用的):
把196个块按行优先排成一列:
块0 块1 块2 ... 块13
块14 块15 块16 ... 块27
...
块182 ... 块195
位置编码:给每个块一个唯一ID(0到195)
块0 → 位置向量0
块1 → 位置向量1
...
块195 → 位置向量195
问题:1D编码丢失了2D空间关系!
块1(第1行第2列)和块14(第2行第1列)在1D中距离很远
但实际上它们在图像中是对角相邻的!
2D位置编码(改进版):
给每个块两个坐标:
行位置:0到13(共14行)
列位置:0到13(共14列)
位置编码 = 行位置编码 + 列位置编码
实验发现:
1D和2D位置编码效果差不多!因为Transformer的自注意力能自己学到空间关系。但2D在某些任务上略好。
组件3:[CLS] Token(分类标记)
输入序列:
[CLS] [块0] [块1] [块2] ... [块195]
[CLS]是一个特殊的可学习向量,类似BERT
为什么不用全局平均池化?
BERT的经验:[CLS]经过Transformer层,自然聚合了全局信息
实验证明:[CLS]比全局池化效果更好
3.3 不同规模的ViT
| 模型 | 层数 | 隐藏维度 | MLP维度 | 头数 | 参数量 | 图像块大小 |
|---|---|---|---|---|---|---|
| ViT-Tiny | 12 | 192 | 768 | 3 | 5.7M | 16×16 |
| ViT-Small | 12 | 384 | 1536 | 6 | 22M | 16×16 |
| ViT-Base | 12 | 768 | 3072 | 12 | 86M | 16×16 |
| ViT-Large | 24 | 1024 | 4096 | 16 | 307M | 16×16 |
| ViT-Huge | 32 | 1280 | 5120 | 16 | 632M | 14×14 |
块大小越小,块数量越多,计算量越大(O(n²))。ViT-Huge用14×14,(224/14)² = 256块。
四、动手实验:用PyTorch实现ViT
4.1 从零手写ViT核心组件
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ============================================
# 步骤1:图像分块(Patch Embedding)
# ============================================
class PatchEmbed(nn.Module):
"""
把图像切成补丁块,并映射到指定维度
输入:图像 [B, C, H, W]
输出:补丁序列 [B, N, D] (N=H*W/(patch_size²), D=embed_dim)
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2 # 196块
# 用卷积实现分块+线性投影(高效!)
# 卷积核=patch_size,步长=patch_size → 天然实现不重叠分块
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size
)
def forward(self, x):
"""
x: [B, 3, 224, 224]
"""
B, C, H, W = x.shape
# 卷积分块: [B, 3, 224, 224] → [B, 768, 14, 14]
x = self.proj(x)
# 展平: [B, 768, 14, 14] → [B, 768, 196] → [B, 196, 768]
x = x.flatten(2).transpose(1, 2)
return x # [B, 196, 768]
# ============================================
# 步骤2:位置编码 + [CLS] Token
# ============================================
class ViTEmbedding(nn.Module):
"""
组合:Patch Embedding + [CLS] + 位置编码
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, dropout=0.1):
super().__init__()
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches # 196
# [CLS] token: 可学习的向量 [1, 1, embed_dim]
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# 位置编码: [num_patches + 1, embed_dim] (+1是因为[CLS])
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim)
)
self.dropout = nn.Dropout(dropout)
# 初始化
nn.init.normal_(self.cls_token, std=0.02)
nn.init.normal_(self.pos_embed, std=0.02)
def forward(self, x):
B = x.shape[0]
# 分块: [B, 196, 768]
x = self.patch_embed(x)
# 添加[CLS]: [B, 1, 768] + [B, 196, 768] = [B, 197, 768]
cls_tokens = self.cls_token.expand(B, -1, -1) # 复制B份
x = torch.cat([cls_tokens, x], dim=1)
# 加位置编码
x = x + self.pos_embed
return self.dropout(x)
# ============================================
# 步骤3:Transformer编码器块
# ============================================
class MultiHeadAttention(nn.Module):
"""
多头自注意力(复用第2讲代码,适配ViT)
"""
def __init__(self, dim, num_heads=12, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
# 一次性生成Q/K/V(效率更高)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
# QKV: [B, N, 3*dim] → [B, N, 3, num_heads, head_dim] → [3, B, num_heads, N, head_dim]
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # 每个: [B, num_heads, N, head_dim]
# 注意力
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, heads, N, N]
attn = F.softmax(attn, dim=-1)
attn = self.attn_drop(attn)
# 加权
x = (attn @ v).transpose(1, 2).reshape(B, N, C) # [B, N, dim]
x = self.proj(x)
x = self.proj_drop(x)
return x, attn
class MLP(nn.Module):
"""
Transformer中的前馈网络
"""
def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = nn.GELU() # ViT用GELU激活函数
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class TransformerBlock(nn.Module):
"""
ViT的Transformer块:注意力 + MLP + LayerNorm + 残差连接
"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = MultiHeadAttention(dim, num_heads, qkv_bias, attn_drop, drop)
self.norm2 = nn.LayerNorm(dim)
mlp_hidden_dim = int(dim * mlp_ratio) # 768 * 4 = 3072
self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
def forward(self, x, return_attention=False):
# 注意力子层(先LayerNorm,再注意力,再残差)
attn_out, attn_weights = self.attn(self.norm1(x))
x = x + attn_out
# MLP子层
x = x + self.mlp(self.norm2(x))
if return_attention:
return x, attn_weights
return x
# ============================================
# 步骤4:完整的ViT模型
# ============================================
class VisionTransformer(nn.Module):
"""
完整的Vision Transformer
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12, # Transformer层数
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.
):
super().__init__()
# 嵌入层
self.embedding = ViTEmbedding(
img_size, patch_size, in_chans, embed_dim, drop_rate
)
# Transformer编码器
self.blocks = nn.ModuleList([
TransformerBlock(
embed_dim, num_heads, mlp_ratio, qkv_bias,
drop_rate, attn_drop_rate
)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
# 分类头
self.head = nn.Linear(embed_dim, num_classes)
# 初始化
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, return_attention=False):
"""
x: [B, 3, 224, 224]
"""
# 嵌入: [B, 197, 768]
x = self.embedding(x)
# 通过Transformer层
attentions = []
for block in self.blocks:
if return_attention:
x, attn = block(x, return_attention=True)
attentions.append(attn)
else:
x = block(x)
# 最终LayerNorm
x = self.norm(x)
# 取[CLS]标记的输出做分类 [B, 768]
cls_output = x[:, 0]
# 分类头
logits = self.head(cls_output) # [B, num_classes]
if return_attention:
return logits, attentions
return logits
# ============================================
# 测试模型
# ============================================
def test_vit():
print("=" * 60)
print("【测试】Vision Transformer模型构建")
print("=" * 60)
# 创建ViT-Base模型
model = VisionTransformer(
img_size=224,
patch_size=16,
num_classes=10, # CIFAR-10
embed_dim=768,
depth=12,
num_heads=12
)
# 统计参数量
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"ViT-Base配置:")
print(f" 图像大小: 224×224")
print(f" 补丁大小: 16×16")
print(f" 补丁数量: 196")
print(f" 嵌入维度: 768")
print(f" Transformer层数: 12")
print(f" 注意力头数: 12")
print(f" 总参数量: {total_params / 1e6:.1f}M")
print(f" 可训练参数: {trainable_params / 1e6:.1f}M")
# 测试前向传播
dummy_input = torch.randn(2, 3, 224, 224) # 2张图片
output = model(dummy_input)
print(f"\n输入形状: {dummy_input.shape}")
print(f"输出形状: {output.shape}") # [2, 10]
print(f"输出示例: {output[0].detach().numpy().round(2)}")
# 测试注意力可视化
print("\n" + "=" * 60)
print("【测试】提取注意力权重")
print("=" * 60)
logits, attentions = model(dummy_input, return_attention=True)
print(f"注意力层数: {len(attentions)}")
print(f"每层注意力形状: {attentions[0].shape}")
# [B, num_heads, N, N] = [2, 12, 197, 197]
# 可视化第1层第1个头的注意力
attn = attentions[0][0, 0].detach().numpy() # 第1个样本,第1层,第1个头
print(f"\n第1层第1个头的注意力矩阵形状: {attn.shape}")
print(f"[CLS]对其他块的平均注意力: {attn[0, 1:].mean():.3f}")
print(f"块之间的平均注意力: {attn[1:, 1:].mean():.3f}")
return model
if __name__ == "__main__":
model = test_vit()
print("\n" + "=" * 60)
print("✅ ViT模型构建完成!")
print("=" * 60)
4.2 实验2:加载预训练ViT,观察注意力图
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
# 尝试加载timm库(如果安装了)
try:
import timm
print("=" * 60)
print("【实验】用timm加载预训练ViT")
print("=" * 60)
# 加载预训练的ViT-Base(在ImageNet-21k上预训练)
model = timm.create_model(
'vit_base_patch16_224',
pretrained=True, # 自动下载预训练权重
num_classes=1000
)
model.eval()
print(f"✅ 加载预训练ViT-Base成功!")
print(f"模型来源: Google Research (ImageNet-21k预训练)")
# 获取注意力权重
def get_attention_map(model, image):
"""
提取ViT的注意力图,可视化[CLS]关注图像的哪些区域
"""
# 前向传播,获取注意力
with torch.no_grad():
# timm模型支持return_attention
output = model.forward_features(image.unsqueeze(0))
# 注意:不同版本timm接口可能不同,这里用通用方法
# 替代方案:手动提取
x = model.patch_embed(image.unsqueeze(0))
cls_token = model.cls_token.expand(1, -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = x + model.pos_embed
x = model.pos_drop(x)
# 通过Transformer块,收集注意力
attentions = []
for block in model.blocks:
x, attn = block(x, return_attention=True)
attentions.append(attn)
return attentions
# 加载一张测试图片
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# 用CIFAR-10的一张图演示
cifar = datasets.CIFAR10(root='./data', train=False, download=True)
img, label = cifar[0] # 第1张图(通常是飞机)
img_tensor = transform(img)
# 获取注意力
attentions = get_attention_map(model, img_tensor)
# 可视化
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# 原图
ax = axes[0, 0]
img_np = np.array(img)
ax.imshow(img_np)
ax.set_title(f'Original Image (CIFAR-10: {cifar.classes[label]})')
ax.axis('off')
# 不同层的[CLS]注意力
layers = [0, 2, 5, 8, 11] # 第1,3,6,9,12层
for idx, layer_idx in enumerate(layers):
if idx >= 5:
break
ax = axes[idx // 3, idx % 3] if idx < 3 else axes[1, idx - 3]
# 取第layer_idx层的注意力
attn = attentions[layer_idx][0, 0, 0, 1:].reshape(14, 14).cpu().numpy()
# [CLS]对196个patch的注意力,reshape成14×14
# 上采样到224×224
attn_resized = np.kron(attn, np.ones((16, 16)))
ax.imshow(img_np)
ax.imshow(attn_resized, alpha=0.6, cmap='hot')
ax.set_title(f'Layer {layer_idx+1} [CLS] Attention')
ax.axis('off')
plt.tight_layout()
plt.savefig('/mnt/agents/output/vit_attention_maps.png', dpi=150)
plt.show()
print("\n📊 注意力图已保存!")
print("""
解读:
- 浅层(Layer 1-3):[CLS]关注分散,学局部特征
- 中层(Layer 6-9):开始关注语义相关区域
- 深层(Layer 12):高度聚焦目标物体,忽略背景
""")
except ImportError:
print("timm库未安装,运行: pip install timm")
print("下面用自实现版本继续实验")
4.3 实验3:CIFAR-10上的对比实验(ViT vs ResNet)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import time
print("=" * 60)
print("【实验3】ViT vs ResNet在CIFAR-10上的对比")
print("=" * 60)
# 数据预处理
# 注意:CIFAR-10图像是32×32,需要resize到224给ViT
transform_train = transforms.Compose([
transforms.Resize(224), # ViT需要224×224
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
transform_test = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
# 加载CIFAR-10
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform_train
)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform_test
)
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
# ============================================
# 模型1:自实现ViT-Tiny(小规模,适合CIFAR-10)
# ============================================
class ViT_Tiny(nn.Module):
"""
缩小版ViT,适合小数据集
"""
def __init__(self, num_classes=10):
super().__init__()
# 用更小的配置
self.vit = VisionTransformer(
img_size=224,
patch_size=16,
num_classes=num_classes,
embed_dim=192, # 缩小到192(原来是768)
depth=12,
num_heads=3, # 对应192/64=3
mlp_ratio=4.,
drop_rate=0.1
)
def forward(self, x):
return self.vit(x)
# ============================================
# 模型2:简单ResNet(对比用)
# ============================================
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet18(nn.Module):
def __init__(self, num_classes=10):
super().__init__()
self.in_planes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(64, 2, stride=1)
self.layer2 = self._make_layer(128, 2, stride=2)
self.layer3 = self._make_layer(256, 2, stride=2)
self.layer4 = self._make_layer(512, 2, stride=2)
self.linear = nn.Linear(512 * BasicBlock.expansion, num_classes)
def _make_layer(self, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(BasicBlock(self.in_planes, planes, stride))
self.in_planes = planes * BasicBlock.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
# ============================================
# 训练函数
# ============================================
def train_model(model, name, epochs=5):
"""
快速训练并评估模型
"""
print(f"\n{'='*50}")
print(f"训练 {name}")
print(f"{'='*50}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# 统计
history = {'train_loss': [], 'train_acc': [], 'test_acc': []}
for epoch in range(epochs):
# 训练
model.train()
running_loss = 0.0
correct = 0
total = 0
for i, (inputs, labels) in enumerate(trainloader):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
if (i + 1) % 100 == 0:
print(f" Epoch {epoch+1}, Batch {i+1}, "
f"Loss: {loss.item():.3f}, "
f"Acc: {100.*correct/total:.2f}%")
train_acc = 100. * correct / total
train_loss = running_loss / len(trainloader)
# 测试
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
test_acc = 100. * correct / total
history['train_loss'].append(train_loss)
history['train_acc'].append(train_acc)
history['test_acc'].append(test_acc)
print(f"Epoch {epoch+1} 完成: "
f"Train Loss={train_loss:.3f}, "
f"Train Acc={train_acc:.2f}%, "
f"Test Acc={test_acc:.2f}%")
return history
# ============================================
# 运行对比实验
# ============================================
print("注意:这个实验需要GPU,且训练时间较长(约30分钟-1小时)")
print("如果资源有限,可以减少epochs或batch_size")
# 由于CIFAR-10数据量小(5万张),ViT容易过拟合
# 这里只做1个epoch演示,实际建议用预训练权重
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 创建模型
vit_model = ViT_Tiny(num_classes=10)
resnet_model = ResNet18(num_classes=10)
# 统计参数量
vit_params = sum(p.numel() for p in vit_model.parameters())
resnet_params = sum(p.numel() for p in resnet_model.parameters())
print(f"\n模型对比:")
print(f" ViT-Tiny: {vit_params/1e6:.1f}M 参数")
print(f" ResNet-18: {resnet_params/1e6:.1f}M 参数")
# 训练(减少epoch数用于演示)
epochs = 2 if device.type == 'cpu' else 5
print(f"\n训练轮数: {epochs}(CPU用2轮,GPU用5轮)")
# 训练ResNet
resnet_history = train_model(resnet_model, "ResNet-18", epochs=epochs)
# 训练ViT
vit_history = train_model(vit_model, "ViT-Tiny", epochs=epochs)
# ============================================
# 可视化对比结果
# ============================================
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 训练准确率
ax = axes[0]
ax.plot(range(1, epochs+1), resnet_history['train_acc'], 'o-', label='ResNet-18', linewidth=2)
ax.plot(range(1, epochs+1), vit_history['train_acc'], 's-', label='ViT-Tiny', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Train Accuracy (%)')
ax.set_title('Training Accuracy Comparison')
ax.legend()
ax.grid(True, alpha=0.3)
# 测试准确率
ax = axes[1]
ax.plot(range(1, epochs+1), resnet_history['test_acc'], 'o-', label='ResNet-18', linewidth=2)
ax.plot(range(1, epochs+1), vit_history['test_acc'], 's-', label='ViT-Tiny', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Test Accuracy (%)')
ax.set_title('Test Accuracy Comparison')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('/mnt/agents/output/vit_vs_resnet_cifar10.png', dpi=150)
plt.show()
print("\n" + "=" * 60)
print("📊 对比图已保存!")
print("=" * 60)
print("""
预期结果分析:
在小数据集(CIFAR-10,5万张图)上:
ResNet-18: ✅ 通常表现更好
- 有归纳偏置(局部性、平移不变性)
- 适合小数据,不容易过拟合
ViT-Tiny: ⚠️ 可能表现较差或需要更长时间训练
- 没有归纳偏置,需要学"图像是什么"
- 容易过拟合(参数多,数据少)
- 需要预训练(ImageNet-21k)才能发挥优势
关键结论:
"ViT需要大数据"不是空话!在CIFAR-10上,
除非用ImageNet预训练权重,否则ResNet更实用。
""")
五、关键洞察:为什么ViT需要预训练?
5.1 数据量对比实验(概念性)
场景1:用ImageNet-1k训练(130万张图,1000类)
ResNet-50: ✅ 收敛快,最终精度高
ViT-Base: ⚠️ 收敛慢,精度不如ResNet
原因:ImageNet-1k对ViT来说还是"小数据"
场景2:用ImageNet-21k训练(1400万张图,21000类)
ResNet-50: ✅ 有提升,但边际效益递减
ViT-Base: ✅ 大幅超越ResNet!
ViT-Large: ✅ 更强!
原因:数据量突破临界点后,ViT的灵活性优势显现
场景3:用JFT-300M训练(3亿张图)
ViT-Huge: ✅ 碾压所有CNN!
在ImageNet上达到SOTA,且迁移学习极强
原因:海量数据让ViT学会所有需要的"归纳偏置"
5.2 预训练+微调:ViT的正确打开方式
实际应用流程:
1. 预训练(大数据,长耗时,高成本)
ImageNet-21k或更大 → 训练ViT-Base/Large
需要数百GPU训练数周
2. 微调(小数据,短耗时,低成本)
你的任务数据(如医学影像、工业质检)
加载预训练权重,只训练分类头或少量层
几小时到几天即可
代码示例:
import timm
# 加载预训练ViT
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
# 冻结Transformer层,只训练分类头(快速适应)
for param in model.blocks.parameters():
param.requires_grad = False
# 或者:用较小的学习率微调全部层(更好效果)
# optimizer = optim.Adam([
# {'params': model.blocks.parameters(), 'lr': 1e-5}, # 预训练层:小学习率
# {'params': model.head.parameters(), 'lr': 1e-3} # 新分类头:大学习率
# ])
六、核心总结
| 概念 | 一句话解释 |
|---|---|
| Patch Embedding | 把图像切块,用卷积映射成向量序列 |
| [CLS] Token | 特殊标记,聚合全局信息用于分类 |
| 1D位置编码 | 给每个块一个"序号",ViT原版方案 |
| 归纳偏置 | 模型内置的先验假设,CNN强、ViT弱 |
| 大数据需求 | ViT没有归纳偏置,需要更多数据"从零学" |
| 预训练+微调 | 在大数据上预训练,在小数据上微调,ViT的标准用法 |
七、面试高频题
Q1:ViT和CNN的本质区别是什么?
答:核心在于归纳偏置。CNN通过卷积核的局部连接和权重共享,内置了局部性和平移不变性假设,适合中小数据。ViT几乎没有归纳偏置,通过自注意力学习全局关系,需要大数据(如ImageNet-21k)才能发挥优势,但上限更高。
Q2:为什么ViT用[CLS]而不是全局平均池化?
答:[CLS]标记经过Transformer层,通过自注意力与所有patch交互,自然聚合了全局语义信息。实验表明[CLS]比全局池化效果更好,且与BERT的设计一致,便于统一架构理解。
Q3:ViT的位置编码为什么1D就够了?
答:虽然图像是2D的,但实验发现1D和2D位置编码效果相近。因为Transformer的自注意力能自动学习空间关系,且图像分块后的行优先顺序已经隐含了部分位置信息。
Q4:ViT在小数据集上表现不好,怎么解决?
答:三种方案:1)使用预训练权重(ImageNet-21k)再微调;2)数据增强(AutoAugment、Mixup);3)使用蒸馏(DeiT)或更小的模型(ViT-Tiny);4)结合CNN和Transformer(CoaT、BoTNet等混合架构)。
八、课后作业
作业1:观察不同patch_size的影响
# 尝试patch_size=8, 16, 32
# patch_size越小:
# - 块数量越多(64×64=4096块!)
# - 计算量越大(O(n²))
# - 但细粒度越好
作业2:可视化注意力演变
# 修改实验2代码,观察:
# 1. 浅层注意力:是否关注局部边缘?
# 2. 深层注意力:是否聚焦目标物体?
# 3. [CLS]的注意力如何从分散到聚焦?
作业3:思考混合架构
问题:有没有办法结合CNN和ViT的优点?
提示:一些后续工作如:
- CNN做特征提取 + ViT做全局建模(CoaT)
- 在ViT中加入卷积(CvT, CeiT)
- 金字塔结构(Swin Transformer, PVT)
九、下讲预告
第6讲:检测大模型——从DETR到RT-DETR/R-DETR
我们将:
- 理解"把检测变成集合预测"的革命性思想
- 学习匈牙利匹配算法
- 动手用ultralytics加载RT-DETR
- 对比Anchor-based和Anchor-free的优劣
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)