【深度学习精通】第20章 | 视觉Transformer - ViT、Swin与DETR
环境声明
- Python版本:Python 3.10+
- PyTorch版本:PyTorch 2.0+
- CUDA版本:CUDA 11.8+ (推荐)
- 开发工具:PyCharm 或 VS Code
- 操作系统:Windows / macOS / Linux (通用)
- 依赖库:torchvision, timm, einops, numpy, matplotlib
学习目标
通过本章学习,你将掌握:
- 理解Vision Transformer(ViT)的核心思想和架构设计
- 掌握图像分块(Patch Embedding)和位置编码的实现
- 理解Swin Transformer的层次化设计和移位窗口机制
- 掌握DETR端到端目标检测的原理和集合预测
- 了解DeiT、BEiT、MAE等ViT改进模型
- 理解CNN与Transformer在视觉任务中的差异和取舍
- 了解2025年视觉Transformer的最新进展
- 能够使用PyTorch实现完整的ViT模型
摘要
2020年,Google提出的Vision Transformer(ViT)彻底改变了计算机视觉领域。它证明了纯Transformer架构无需卷积操作也能在图像分类任务上达到甚至超越CNN的效果。此后,Swin Transformer通过引入层次化特征和移位窗口注意力解决了ViT的计算效率问题,DETR则将Transformer应用于目标检测实现了真正的端到端检测。本章将深入解析这些革命性模型的工作原理,涵盖ViT、Swin、DETR及其改进版本,并提供完整的PyTorch实现代码。
1. Vision Transformer(ViT)原理
1.1 从CNN到Transformer的范式转变
在ViT出现之前,卷积神经网络(CNN)一直是计算机视觉领域的主流架构。从LeNet到AlexNet,从VGG到ResNet,CNN通过局部感受野、权值共享和层次化特征提取取得了巨大成功。
然而,CNN存在以下局限性:
- 局部性限制:卷积核的局部感受野难以捕捉全局依赖关系
- 归纳偏置:平移不变性等归纳偏置在某些任务中反而成为限制
- 扩展性:随着模型规模增大,CNN的性能提升逐渐饱和
Transformer在自然语言处理领域的成功启发了研究者:能否将Transformer直接应用于图像?
核心思想:将图像分割成一系列小块(Patch),将每个Patch视为一个"词",然后直接应用标准的Transformer编码器。
1.2 ViT架构详解
ViT的整体架构包含以下几个关键组件:
图像分块(Image Patching)
假设输入图像大小为 H×W×CH \times W \times CH×W×C,将其分割为大小为 P×PP \times PP×P 的小块:
- 分块数量:N=H×WP2N = \frac{H \times W}{P^2}N=P2H×W
- 每个Patch的维度:P2×CP^2 \times CP2×C
例如,对于224x224的图像,使用16x16的Patch大小:
- 分块数量:(224/16)2=196(224/16)^2 = 196(224/16)2=196 个Patch
- 每个Patch的维度:16×16×3=76816 \times 16 \times 3 = 76816×16×3=768
线性嵌入(Linear Embedding)
将每个展平的Patch通过线性变换映射到指定的嵌入维度 DDD:
z0=[xclass;xp1E;xp2E;⋯ ;xpNE]+Epos\mathbf{z}_0 = [\mathbf{x}_{class}; \mathbf{x}_p^1\mathbf{E}; \mathbf{x}_p^2\mathbf{E}; \cdots; \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}z0=[xclass;xp1E;xp2E;⋯;xpNE]+Epos
其中:
- xpi\mathbf{x}_p^ixpi 是第 iii 个展平的Patch
- E∈R(P2⋅C)×D\mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D}E∈R(P2⋅C)×D 是嵌入矩阵
- xclass\mathbf{x}_{class}xclass 是可学习的类别Token(CLS Token)
- Epos\mathbf{E}_{pos}Epos 是位置嵌入
CLS Token
ViT借鉴了BERT的设计,在序列开头添加一个可学习的类别Token(CLS Token)。这个Token在Transformer层之间传递信息,最终的输出用于图像分类。
为什么使用CLS Token而不是全局平均池化?
Transformer中的自注意力机制是全局的,CLS Token可以与所有Patch Token交互,从而聚合全局信息。实验表明,CLS Token的性能略优于全局平均池化。
位置编码
由于Transformer本身不具备位置感知能力,需要添加位置编码来保留空间信息。ViT使用可学习的一维位置编码:
Epos∈R(N+1)×D\mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D}Epos∈R(N+1)×D
研究表明,使用二维位置编码或相对位置编码可以带来轻微的性能提升,但一维位置编码已经足够有效。
Transformer编码器
ViT使用标准的Transformer编码器,由多头自注意力(MSA)和前馈网络(MLP)组成:
zl′=MSA(LN(zl−1))+zl−1\mathbf{z}'_l = \text{MSA}(\text{LN}(\mathbf{z}_{l-1})) + \mathbf{z}_{l-1}zl′=MSA(LN(zl−1))+zl−1
zl=MLP(LN(zl′))+zl′\mathbf{z}_l = \text{MLP}(\text{LN}(\mathbf{z}'_l)) + \mathbf{z}'_lzl=MLP(LN(zl′))+zl′
其中LN表示层归一化(Layer Normalization)。
1.3 ViT的变体
Google提出了多个ViT变体,主要区别在于模型深度和嵌入维度:
| 模型 | 层数 | 嵌入维度 | MLP维度 | 参数量 | ImageNet Top-1 |
|---|---|---|---|---|---|
| ViT-Tiny | 12 | 192 | 768 | 5.7M | 75.4% |
| ViT-Small | 12 | 384 | 1536 | 22M | 81.4% |
| ViT-Base | 12 | 768 | 3072 | 86M | 84.2% |
| ViT-Large | 24 | 1024 | 4096 | 307M | 85.2% |
| ViT-Huge | 32 | 1280 | 5120 | 632M | 86.9% |
2. Swin Transformer:层次化视觉Transformer
2.1 ViT的计算效率问题
标准ViT使用全局自注意力,计算复杂度为 O(N2)O(N^2)O(N2),其中 NNN 是Patch数量。对于高分辨率图像,这会导致巨大的计算开销:
- 224x224图像,16x16 Patch:N=196N = 196N=196,计算量可接受
- 512x512图像,16x16 Patch:N=1024N = 1024N=1024,计算量显著增加
- 目标检测、语义分割等密集预测任务需要高分辨率特征图
2.2 Swin Transformer的核心创新
Swin Transformer(Shifted Window Transformer)由微软亚洲研究院于2021年提出,通过以下创新解决了ViT的效率问题:
层次化特征图
与ViT保持单一尺度的特征图不同,Swin Transformer采用类似CNN的层次化结构:
- Stage 1:H/4×W/4H/4 \times W/4H/4×W/4 分辨率
- Stage 2:H/8×W/8H/8 \times W/8H/8×W/8 分辨率
- Stage 3:H/16×W/16H/16 \times W/16H/16×W/16 分辨率
- Stage 4:H/32×W/32H/32 \times W/32H/32×W/32 分辨率
这种设计使得Swin Transformer可以方便地应用于目标检测、语义分割等需要多尺度特征的任务。
窗口自注意力(Window Self-Attention)
Swin Transformer将特征图划分为不重叠的窗口,在每个窗口内独立计算自注意力:
- 设窗口大小为 M×MM \times MM×M
- 每个窗口包含 M2M^2M2 个Patch
- 计算复杂度从 O(N2)O(N^2)O(N2) 降低到 O(N⋅M2)O(N \cdot M^2)O(N⋅M2)
对于 h×wh \times wh×w 个Patch的特征图:
- 全局MSA复杂度:O((hw)2)O((hw)^2)O((hw)2)
- 窗口MSA复杂度:O((hw)⋅M2)O((hw) \cdot M^2)O((hw)⋅M2)
移位窗口(Shifted Window)
仅使用窗口自注意力会限制跨窗口的信息交互。Swin Transformer引入移位窗口机制:
- 第 lll 层:使用规则窗口划分
- 第 l+1l+1l+1 层:窗口向右下角移动 M/2M/2M/2 个Patch
这种交替的窗口划分方式实现了跨窗口连接,同时保持计算效率。
循环移位(Cyclic Shift):为了高效实现移位窗口注意力,Swin Transformer使用循环移位将分散的窗口重新组织,使得可以使用批量矩阵乘法计算注意力。
相对位置偏置
Swin Transformer引入可学习的相对位置偏置:
Attention(Q,K,V)=SoftMax(QKTd+B)V\text{Attention}(Q, K, V) = \text{SoftMax}(\frac{QK^T}{\sqrt{d}} + B)VAttention(Q,K,V)=SoftMax(dQKT+B)V
其中 BBB 是相对位置偏置矩阵,相比绝对位置编码能更好地捕捉相对位置关系。
2.3 Swin Transformer架构
Swin Transformer的整体架构如下:
- Patch Partition:将输入图像分割为4x4的Patch,通过线性嵌入得到 H/4×W/4×CH/4 \times W/4 \times CH/4×W/4×C 的特征
- Stage 1-4:每个Stage包含多个Swin Transformer Block,通过Patch Merging层实现下采样
- Swin Transformer Block:交替使用窗口MSA(W-MSA)和移位窗口MSA(SW-MSA)
3. DETR:端到端目标检测
3.1 传统目标检测的局限性
传统目标检测器(如Faster R-CNN、YOLO、SSD)依赖于以下手工设计组件:
- 锚框(Anchor Boxes):需要预定义大量锚框
- 非极大值抑制(NMS):需要启发式后处理去除重复检测
- 多尺度特征金字塔:需要复杂的特征融合策略
这些手工设计使得检测流程复杂,且难以端到端优化。
3.2 DETR的核心思想
DETR(Detection Transformer)将目标检测视为集合预测问题,完全摒弃了锚框和NMS:
核心创新:
- 集合预测:直接预测一组目标框和类别
- 二分匹配:使用匈牙利算法将预测与真实目标匹配
- Transformer架构:使用编码器-解码器结构
3.3 DETR架构详解
骨干网络
使用ResNet-50/101提取特征,输出 C=2048C = 2048C=2048 通道的特征图。
Transformer编码器
- 将特征图通过1x1卷积降维到 d=256d = 256d=256
- 将空间维度展平为序列
- 添加二维位置编码
- 通过多层Transformer编码器
Transformer解码器
解码器接收 NNN 个可学习的目标查询(Object Queries):
- NNN 通常设置为100,表示最多检测100个目标
- 每个目标查询是一个 d=256d = 256d=256 维的向量
- 目标查询通过自注意力相互交互
- 通过交叉注意力与编码器输出交互
预测头
每个解码器输出通过共享的前馈网络(FFN)预测:
- 类别概率(包括"无目标"类)
- 边界框坐标(中心坐标和宽高)
3.4 二分匹配损失
DETR使用匈牙利算法在预测集合和真实目标集合之间找到最优匹配:
σ^=argminσ∑i=1NLmatch(yi,y^σ(i))\hat{\sigma} = \arg\min_{\sigma} \sum_{i=1}^N \mathcal{L}_{match}(y_i, \hat{y}_{\sigma(i)})σ^=argσmini=1∑NLmatch(yi,y^σ(i))
匹配损失包括:
- 类别预测损失
- 边界框L1损失
- 边界框GIoU损失
找到最优匹配后,计算最终的检测损失。
3.5 DETR的优缺点
优点:
- 端到端训练,无需手工设计
- 无需NMS后处理
- 全局推理,能处理复杂场景
缺点:
- 小目标检测性能较差
- 训练收敛慢(需要500个epoch)
- 计算量大
3.6 DETR的改进
Deformable DETR
使用可变形注意力替代标准注意力,只关注参考点周围的一小部分采样点,显著降低计算量并加速收敛。
Conditional DETR
改进解码器的交叉注意力机制,将空间查询与内容查询分离,加速训练收敛。
DAB-DETR
使用动态锚框(Dynamic Anchor Boxes)作为查询,将查询显式地表示为4D锚框坐标。
4. ViT的改进与变体
4.1 DeiT:数据高效的图像Transformer
ViT的一个主要缺点是需要大规模数据集(如JFT-300M)进行预训练。DeiT(Data-efficient Image Transformer)通过知识蒸馏解决了这个问题:
核心创新:
- 蒸馏Token(Distillation Token):在CLS Token之外添加一个蒸馏Token,用于学习教师模型(通常是CNN)的输出
- 硬蒸馏(Hard Distillation):使用教师模型的硬标签而非软标签
- 数据增强:使用AutoAugment、RandAugment等强数据增强
DeiT仅在ImageNet上训练就能达到与ViT相当的性能,无需大规模预训练数据。
4.2 BEiT:BERT风格的图像预训练
BEiT(Bidirectional Encoder Representations from Image Transformers)将BERT的掩码语言建模应用于图像:
核心思想:
- 将图像离散化为视觉Token(使用预训练的VQ-VAE)
- 随机掩码部分图像Patch
- 预测被掩码Patch的视觉Token
这种预训练方式使ViT能够学习丰富的视觉表示,在下游任务上表现优异。
4.3 MAE:掩码自编码器
MAE(Masked Autoencoder)是一种简单的自监督预训练方法:
核心思想:
- 随机掩码高比例(75%)的图像Patch
- 编码器只处理可见Patch(提高效率)
- 轻量级解码器重建完整图像
- 使用像素级MSE损失
MAE展示了ViT强大的表示学习能力,使用75%的掩码比例仍能有效重建图像。
4.4 2025年最新进展
EfficientViT
EfficientViT专为边缘计算设备设计,通过多尺度线性注意力模块在保持性能的同时显著提升推理速度。主要创新包括:
- 多尺度线性注意力:降低注意力计算的复杂度
- 三明治布局:在FFN层之间插入注意力层,提高效率
- 轻量级设计:参数量和计算量大幅减少,适合移动端部署
SparseViT
AAAI 2025提出的SparseViT通过稀疏编码实现参数高效的图像处理,主要特点:
- 稀疏注意力:只关注重要的Token
- 非语义中心设计:不依赖语义信息,适用于图像篡改定位等任务
- 参数效率:在保持性能的同时显著减少参数量
其他进展
- CrossViT:结合不同尺度Patch的Transformer
- PiT(Pooling-based Vision Transformer):使用池化操作替代Patch Merging
- CvT(Convolutional Vision Transformer):在Transformer中引入卷积操作
5. CNN vs Transformer对比
5.1 归纳偏置的差异
| 特性 | CNN | Transformer |
|---|---|---|
| 局部性 | 强归纳偏置(局部感受野) | 无归纳偏置(全局注意力) |
| 平移不变性 | 内置(权值共享) | 需通过位置编码学习 |
| 尺度不变性 | 部分(多尺度设计) | 需通过数据学习 |
| 全局依赖 | 需多层堆叠 | 单层即可捕捉 |
5.2 数据效率
CNN:由于归纳偏置,CNN在小数据集上表现更好,可以用较少的数据学习到有效的特征。
Transformer:Transformer缺乏归纳偏置,需要大规模数据才能发挥优势。DeiT、BEiT等技术正在缩小这一差距。
5.3 计算效率
| 模型 | 计算复杂度 | 内存占用 | 推理速度 |
|---|---|---|---|
| ResNet-50 | O(HW)O(HW)O(HW) | 低 | 快 |
| ViT-Base | O((HW)2)O((HW)^2)O((HW)2) | 高 | 慢 |
| Swin-Tiny | O(HW⋅M2)O(HW \cdot M^2)O(HW⋅M2) | 中 | 中等 |
| EfficientViT | O(HW)O(HW)O(HW) | 低 | 快 |
5.4 适用场景
选择CNN的场景:
- 资源受限的嵌入式设备
- 小数据集(<10万张图像)
- 需要极低延迟的实时应用
选择Transformer的场景:
- 大规模数据集
- 需要全局上下文理解的任务
- 多模态学习(文本+图像)
混合架构的趋势:
- ConvNeXt:将CNN设计现代化,借鉴Transformer的优点
- CoAtNet:结合卷积和注意力
- MaxViT:使用多轴注意力结合局部和全局信息
6. 避坑小贴士
图像分块大小选择
Patch大小对模型性能有显著影响:
- 小Patch(8x8):更多Token,计算量大,但细粒度特征更丰富
- 大Patch(16x16或32x32):Token数量少,计算效率高,但可能丢失细节
建议:图像分类任务使用16x16,密集预测任务(检测、分割)使用8x8或更小。
位置编码的重要性
ViT对位置编码敏感,移除位置编码会导致性能显著下降。如果模型性能不佳,检查位置编码是否正确添加。
预训练与微调
ViT通常需要预训练:
- 有监督预训练:ImageNet-21k、JFT-300M等大规模数据集
- 自监督预训练:MAE、BEiT、DINO等方法
- 微调:在目标任务上微调,通常使用较小的学习率
训练稳定性
Transformer训练比CNN更不稳定:
- 使用AdamW优化器,权重衰减0.05
- 学习率预热(Warmup)很重要,通常预热5-10个epoch
- 使用较大的batch size(4096或更大)
- 使用随机深度(Stochastic Depth)正则化
内存优化
ViT的内存占用较大,可以尝试:
- 使用梯度检查点(Gradient Checkpointing)
- 使用混合精度训练(AMP)
- 减小batch size,增加梯度累积步数
7. 完整代码实现
7.1 ViT的PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
class PatchEmbedding(nn.Module):
"""图像分块嵌入层"""
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# 使用卷积实现分块和线性嵌入
self.proj = nn.Conv2d(
in_channels, embed_dim,
kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
# x: (B, C, H, W)
x = self.proj(x) # (B, embed_dim, H/P, W/P)
x = rearrange(x, 'b e h w -> b (h w) e') # (B, N, embed_dim)
return x
class MultiHeadAttention(nn.Module):
"""多头自注意力机制"""
def __init__(self, embed_dim, num_heads, dropout=0.0):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert embed_dim % num_heads == 0, "embed_dim必须能被num_heads整除"
self.qkv = nn.Linear(embed_dim, embed_dim * 3)
self.proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B, N, C = x.shape
# 生成Q、K、V
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, num_heads, N, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]
# 计算注意力
attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
attn = F.softmax(attn, dim=-1)
attn = self.dropout(attn)
# 加权求和
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.dropout(x)
return x
class MLP(nn.Module):
"""前馈神经网络"""
def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.0):
super().__init__()
hidden_features = hidden_features or in_features * 4
out_features = out_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_features, out_features)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.dropout(x)
x = self.fc2(x)
x = self.dropout(x)
return x
class TransformerBlock(nn.Module):
"""Transformer编码器块"""
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout=dropout)
def forward(self, x):
# 预归一化(Pre-norm)
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
class VisionTransformer(nn.Module):
"""Vision Transformer完整实现"""
def __init__(
self,
img_size=224,
patch_size=16,
in_channels=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
dropout=0.0
):
super().__init__()
self.num_classes = num_classes
self.embed_dim = embed_dim
# Patch嵌入
self.patch_embed = PatchEmbedding(
img_size, patch_size, in_channels, embed_dim
)
num_patches = self.patch_embed.num_patches
# CLS Token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# 位置编码
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_dropout = nn.Dropout(dropout)
# Transformer编码器
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
# 分类头
self.head = nn.Linear(embed_dim, num_classes)
# 初始化
nn.init.normal_(self.cls_token, std=0.02)
nn.init.normal_(self.pos_embed, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x):
B = x.shape[0]
# Patch嵌入
x = self.patch_embed(x) # (B, N, embed_dim)
# 添加CLS Token
cls_tokens = self.cls_token.expand(B, -1, -1) # (B, 1, embed_dim)
x = torch.cat([cls_tokens, x], dim=1) # (B, N+1, embed_dim)
# 添加位置编码
x = x + self.pos_embed
x = self.pos_dropout(x)
# Transformer编码器
for block in self.blocks:
x = block(x)
x = self.norm(x)
# 使用CLS Token进行分类
cls_output = x[:, 0]
logits = self.head(cls_output)
return logits
# 测试代码
if __name__ == "__main__":
# 创建模型
model = VisionTransformer(
img_size=224,
patch_size=16,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12
)
# 测试前向传播
x = torch.randn(2, 3, 224, 224)
output = model(x)
print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
# 计算参数量
total_params = sum(p.numel() for p in model.parameters())
print(f"总参数量: {total_params / 1e6:.2f}M")
7.2 使用预训练模型
import torch
import torchvision.transforms as transforms
from PIL import Image
import timm
# 加载预训练的ViT模型
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()
# 图像预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载类别标签
import json
import urllib.request
url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
labels = json.loads(urllib.request.urlopen(url).read())
# 推理
def predict(image_path):
image = Image.open(image_path).convert('RGB')
input_tensor = transform(image).unsqueeze(0)
with torch.no_grad():
output = model(input_tensor)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(5):
print(f"{labels[top5_catid[i]]}: {top5_prob[i].item():.4f}")
7.3 微调ViT
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm
# 配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 10
batch_size = 32
learning_rate = 1e-4
num_classes = 10 # 例如CIFAR-10
# 数据预处理
train_transform = transforms.Compose([
transforms.Resize(224),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载数据集(以CIFAR-10为例)
train_dataset = datasets.CIFAR10(root='./data', train=True,
download=True, transform=train_transform)
val_dataset = datasets.CIFAR10(root='./data', train=False,
download=True, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size,
shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
shuffle=False, num_workers=4)
# 加载预训练模型并修改分类头
model = timm.create_model('vit_base_patch16_224', pretrained=True,
num_classes=num_classes)
model = model.to(device)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)
# 学习率调度器
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
# 训练循环
for epoch in range(num_epochs):
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
train_total += labels.size(0)
train_correct += predicted.eq(labels).sum().item()
# 验证
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for images, labels in val_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, predicted = outputs.max(1)
val_total += labels.size(0)
val_correct += predicted.eq(labels).sum().item()
scheduler.step()
print(f"Epoch [{epoch+1}/{num_epochs}]")
print(f" Train Loss: {train_loss/len(train_loader):.4f}, "
f"Acc: {100.*train_correct/train_total:.2f}%")
print(f" Val Loss: {val_loss/len(val_loader):.4f}, "
f"Acc: {100.*val_correct/val_total:.2f}%")
# 保存模型
torch.save(model.state_dict(), 'vit_finetuned.pth')
8. 本章小结
核心知识点回顾
-
Vision Transformer(ViT)
- 将图像分割为Patch,视为序列处理
- 使用CLS Token进行图像分类
- 需要大规模数据预训练
-
Swin Transformer
- 层次化特征图设计
- 窗口自注意力降低计算复杂度
- 移位窗口实现跨窗口信息交互
-
DETR
- 端到端目标检测,无需锚框和NMS
- 集合预测和二分匹配
- Transformer编码器-解码器架构
-
改进模型
- DeiT:知识蒸馏实现数据高效训练
- BEiT:BERT风格的掩码图像建模
- MAE:高比例掩码的自监督预训练
-
CNN vs Transformer
- CNN具有强归纳偏置,数据效率高
- Transformer具有全局建模能力,需要更多数据
- 混合架构是未来的发展趋势
关键公式总结
ViT Patch嵌入:
z0=[xclass;xp1E;⋯ ;xpNE]+Epos\mathbf{z}_0 = [\mathbf{x}_{class}; \mathbf{x}_p^1\mathbf{E}; \cdots; \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}z0=[xclass;xp1E;⋯;xpNE]+Epos
多头自注意力:
Attention(Q,K,V)=SoftMax(QKTdk)V\text{Attention}(Q, K, V) = \text{SoftMax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=SoftMax(dkQKT)V
DETR二分匹配:
σ^=argminσ∑i=1NLmatch(yi,y^σ(i))\hat{\sigma} = \arg\min_{\sigma} \sum_{i=1}^N \mathcal{L}_{match}(y_i, \hat{y}_{\sigma(i)})σ^=argσmini=1∑NLmatch(yi,y^σ(i))
进一步学习资源
- 论文:An Image is Worth 16x16 Words(ViT)
- 论文:Swin Transformer: Hierarchical Vision Transformer
- 论文:End-to-End Object Detection with Transformers(DETR)
- 代码库:timm(PyTorch Image Models)
- 代码库:Hugging Face Transformers
学习建议:本章内容较为前沿,建议读者先确保理解Transformer的基础知识(第13章)。实践方面,可以从使用timm库加载预训练模型开始,逐步尝试微调和自定义实现。视觉Transformer是一个快速发展的领域,建议关注最新的顶会论文(CVPR、ICCV、NeurIPS等)了解最新进展。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)