环境声明

  • Python版本:Python 3.10+
  • PyTorch版本:PyTorch 2.0+
  • CUDA版本:CUDA 11.8+ (推荐)
  • 开发工具:PyCharm 或 VS Code
  • 操作系统:Windows / macOS / Linux (通用)
  • 依赖库:torchvision, timm, einops, numpy, matplotlib

学习目标

通过本章学习,你将掌握:

  1. 理解Vision Transformer(ViT)的核心思想和架构设计
  2. 掌握图像分块(Patch Embedding)和位置编码的实现
  3. 理解Swin Transformer的层次化设计和移位窗口机制
  4. 掌握DETR端到端目标检测的原理和集合预测
  5. 了解DeiT、BEiT、MAE等ViT改进模型
  6. 理解CNN与Transformer在视觉任务中的差异和取舍
  7. 了解2025年视觉Transformer的最新进展
  8. 能够使用PyTorch实现完整的ViT模型

摘要

2020年,Google提出的Vision Transformer(ViT)彻底改变了计算机视觉领域。它证明了纯Transformer架构无需卷积操作也能在图像分类任务上达到甚至超越CNN的效果。此后,Swin Transformer通过引入层次化特征和移位窗口注意力解决了ViT的计算效率问题,DETR则将Transformer应用于目标检测实现了真正的端到端检测。本章将深入解析这些革命性模型的工作原理,涵盖ViT、Swin、DETR及其改进版本,并提供完整的PyTorch实现代码。


1. Vision Transformer(ViT)原理

1.1 从CNN到Transformer的范式转变

在ViT出现之前,卷积神经网络(CNN)一直是计算机视觉领域的主流架构。从LeNet到AlexNet,从VGG到ResNet,CNN通过局部感受野、权值共享和层次化特征提取取得了巨大成功。

然而,CNN存在以下局限性:

  1. 局部性限制:卷积核的局部感受野难以捕捉全局依赖关系
  2. 归纳偏置:平移不变性等归纳偏置在某些任务中反而成为限制
  3. 扩展性:随着模型规模增大,CNN的性能提升逐渐饱和

Transformer在自然语言处理领域的成功启发了研究者:能否将Transformer直接应用于图像?

核心思想:将图像分割成一系列小块(Patch),将每个Patch视为一个"词",然后直接应用标准的Transformer编码器。

1.2 ViT架构详解

ViT的整体架构包含以下几个关键组件:

图像分块(Image Patching)

假设输入图像大小为 H×W×CH \times W \times CH×W×C,将其分割为大小为 P×PP \times PP×P 的小块:

  • 分块数量:N=H×WP2N = \frac{H \times W}{P^2}N=P2H×W
  • 每个Patch的维度:P2×CP^2 \times CP2×C

例如,对于224x224的图像,使用16x16的Patch大小:

  • 分块数量:(224/16)2=196(224/16)^2 = 196(224/16)2=196 个Patch
  • 每个Patch的维度:16×16×3=76816 \times 16 \times 3 = 76816×16×3=768
线性嵌入(Linear Embedding)

将每个展平的Patch通过线性变换映射到指定的嵌入维度 DDD

z0=[xclass;xp1E;xp2E;⋯ ;xpNE]+Epos\mathbf{z}_0 = [\mathbf{x}_{class}; \mathbf{x}_p^1\mathbf{E}; \mathbf{x}_p^2\mathbf{E}; \cdots; \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}z0=[xclass;xp1E;xp2E;;xpNE]+Epos

其中:

  • xpi\mathbf{x}_p^ixpi 是第 iii 个展平的Patch
  • E∈R(P2⋅C)×D\mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D}ER(P2C)×D 是嵌入矩阵
  • xclass\mathbf{x}_{class}xclass 是可学习的类别Token(CLS Token)
  • Epos\mathbf{E}_{pos}Epos 是位置嵌入
CLS Token

ViT借鉴了BERT的设计,在序列开头添加一个可学习的类别Token(CLS Token)。这个Token在Transformer层之间传递信息,最终的输出用于图像分类。

为什么使用CLS Token而不是全局平均池化?

Transformer中的自注意力机制是全局的,CLS Token可以与所有Patch Token交互,从而聚合全局信息。实验表明,CLS Token的性能略优于全局平均池化。

位置编码

由于Transformer本身不具备位置感知能力,需要添加位置编码来保留空间信息。ViT使用可学习的一维位置编码:

Epos∈R(N+1)×D\mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D}EposR(N+1)×D

研究表明,使用二维位置编码或相对位置编码可以带来轻微的性能提升,但一维位置编码已经足够有效。

Transformer编码器

ViT使用标准的Transformer编码器,由多头自注意力(MSA)和前馈网络(MLP)组成:

zl′=MSA(LN(zl−1))+zl−1\mathbf{z}'_l = \text{MSA}(\text{LN}(\mathbf{z}_{l-1})) + \mathbf{z}_{l-1}zl=MSA(LN(zl1))+zl1
zl=MLP(LN(zl′))+zl′\mathbf{z}_l = \text{MLP}(\text{LN}(\mathbf{z}'_l)) + \mathbf{z}'_lzl=MLP(LN(zl))+zl

其中LN表示层归一化(Layer Normalization)。

1.3 ViT的变体

Google提出了多个ViT变体,主要区别在于模型深度和嵌入维度:

模型 层数 嵌入维度 MLP维度 参数量 ImageNet Top-1
ViT-Tiny 12 192 768 5.7M 75.4%
ViT-Small 12 384 1536 22M 81.4%
ViT-Base 12 768 3072 86M 84.2%
ViT-Large 24 1024 4096 307M 85.2%
ViT-Huge 32 1280 5120 632M 86.9%

2. Swin Transformer:层次化视觉Transformer

2.1 ViT的计算效率问题

标准ViT使用全局自注意力,计算复杂度为 O(N2)O(N^2)O(N2),其中 NNN 是Patch数量。对于高分辨率图像,这会导致巨大的计算开销:

  • 224x224图像,16x16 Patch:N=196N = 196N=196,计算量可接受
  • 512x512图像,16x16 Patch:N=1024N = 1024N=1024,计算量显著增加
  • 目标检测、语义分割等密集预测任务需要高分辨率特征图

2.2 Swin Transformer的核心创新

Swin Transformer(Shifted Window Transformer)由微软亚洲研究院于2021年提出,通过以下创新解决了ViT的效率问题:

层次化特征图

与ViT保持单一尺度的特征图不同,Swin Transformer采用类似CNN的层次化结构:

  • Stage 1H/4×W/4H/4 \times W/4H/4×W/4 分辨率
  • Stage 2H/8×W/8H/8 \times W/8H/8×W/8 分辨率
  • Stage 3H/16×W/16H/16 \times W/16H/16×W/16 分辨率
  • Stage 4H/32×W/32H/32 \times W/32H/32×W/32 分辨率

这种设计使得Swin Transformer可以方便地应用于目标检测、语义分割等需要多尺度特征的任务。

窗口自注意力(Window Self-Attention)

Swin Transformer将特征图划分为不重叠的窗口,在每个窗口内独立计算自注意力:

  • 设窗口大小为 M×MM \times MM×M
  • 每个窗口包含 M2M^2M2 个Patch
  • 计算复杂度从 O(N2)O(N^2)O(N2) 降低到 O(N⋅M2)O(N \cdot M^2)O(NM2)

对于 h×wh \times wh×w 个Patch的特征图:

  • 全局MSA复杂度:O((hw)2)O((hw)^2)O((hw)2)
  • 窗口MSA复杂度:O((hw)⋅M2)O((hw) \cdot M^2)O((hw)M2)
移位窗口(Shifted Window)

仅使用窗口自注意力会限制跨窗口的信息交互。Swin Transformer引入移位窗口机制:

  • lll:使用规则窗口划分
  • l+1l+1l+1:窗口向右下角移动 M/2M/2M/2 个Patch

这种交替的窗口划分方式实现了跨窗口连接,同时保持计算效率。

循环移位(Cyclic Shift):为了高效实现移位窗口注意力,Swin Transformer使用循环移位将分散的窗口重新组织,使得可以使用批量矩阵乘法计算注意力。

相对位置偏置

Swin Transformer引入可学习的相对位置偏置:

Attention(Q,K,V)=SoftMax(QKTd+B)V\text{Attention}(Q, K, V) = \text{SoftMax}(\frac{QK^T}{\sqrt{d}} + B)VAttention(Q,K,V)=SoftMax(d QKT+B)V

其中 BBB 是相对位置偏置矩阵,相比绝对位置编码能更好地捕捉相对位置关系。

2.3 Swin Transformer架构

Swin Transformer的整体架构如下:

  1. Patch Partition:将输入图像分割为4x4的Patch,通过线性嵌入得到 H/4×W/4×CH/4 \times W/4 \times CH/4×W/4×C 的特征
  2. Stage 1-4:每个Stage包含多个Swin Transformer Block,通过Patch Merging层实现下采样
  3. Swin Transformer Block:交替使用窗口MSA(W-MSA)和移位窗口MSA(SW-MSA)

3. DETR:端到端目标检测

3.1 传统目标检测的局限性

传统目标检测器(如Faster R-CNN、YOLO、SSD)依赖于以下手工设计组件:

  1. 锚框(Anchor Boxes):需要预定义大量锚框
  2. 非极大值抑制(NMS):需要启发式后处理去除重复检测
  3. 多尺度特征金字塔:需要复杂的特征融合策略

这些手工设计使得检测流程复杂,且难以端到端优化。

3.2 DETR的核心思想

DETR(Detection Transformer)将目标检测视为集合预测问题,完全摒弃了锚框和NMS:

核心创新

  1. 集合预测:直接预测一组目标框和类别
  2. 二分匹配:使用匈牙利算法将预测与真实目标匹配
  3. Transformer架构:使用编码器-解码器结构

3.3 DETR架构详解

骨干网络

使用ResNet-50/101提取特征,输出 C=2048C = 2048C=2048 通道的特征图。

Transformer编码器
  1. 将特征图通过1x1卷积降维到 d=256d = 256d=256
  2. 将空间维度展平为序列
  3. 添加二维位置编码
  4. 通过多层Transformer编码器
Transformer解码器

解码器接收 NNN 个可学习的目标查询(Object Queries):

  • NNN 通常设置为100,表示最多检测100个目标
  • 每个目标查询是一个 d=256d = 256d=256 维的向量
  • 目标查询通过自注意力相互交互
  • 通过交叉注意力与编码器输出交互
预测头

每个解码器输出通过共享的前馈网络(FFN)预测:

  • 类别概率(包括"无目标"类)
  • 边界框坐标(中心坐标和宽高)

3.4 二分匹配损失

DETR使用匈牙利算法在预测集合和真实目标集合之间找到最优匹配:

σ^=arg⁡min⁡σ∑i=1NLmatch(yi,y^σ(i))\hat{\sigma} = \arg\min_{\sigma} \sum_{i=1}^N \mathcal{L}_{match}(y_i, \hat{y}_{\sigma(i)})σ^=argσmini=1NLmatch(yi,y^σ(i))

匹配损失包括:

  • 类别预测损失
  • 边界框L1损失
  • 边界框GIoU损失

找到最优匹配后,计算最终的检测损失。

3.5 DETR的优缺点

优点

  • 端到端训练,无需手工设计
  • 无需NMS后处理
  • 全局推理,能处理复杂场景

缺点

  • 小目标检测性能较差
  • 训练收敛慢(需要500个epoch)
  • 计算量大

3.6 DETR的改进

Deformable DETR

使用可变形注意力替代标准注意力,只关注参考点周围的一小部分采样点,显著降低计算量并加速收敛。

Conditional DETR

改进解码器的交叉注意力机制,将空间查询与内容查询分离,加速训练收敛。

DAB-DETR

使用动态锚框(Dynamic Anchor Boxes)作为查询,将查询显式地表示为4D锚框坐标。


4. ViT的改进与变体

4.1 DeiT:数据高效的图像Transformer

ViT的一个主要缺点是需要大规模数据集(如JFT-300M)进行预训练。DeiT(Data-efficient Image Transformer)通过知识蒸馏解决了这个问题:

核心创新

  1. 蒸馏Token(Distillation Token):在CLS Token之外添加一个蒸馏Token,用于学习教师模型(通常是CNN)的输出
  2. 硬蒸馏(Hard Distillation):使用教师模型的硬标签而非软标签
  3. 数据增强:使用AutoAugment、RandAugment等强数据增强

DeiT仅在ImageNet上训练就能达到与ViT相当的性能,无需大规模预训练数据。

4.2 BEiT:BERT风格的图像预训练

BEiT(Bidirectional Encoder Representations from Image Transformers)将BERT的掩码语言建模应用于图像:

核心思想

  1. 将图像离散化为视觉Token(使用预训练的VQ-VAE)
  2. 随机掩码部分图像Patch
  3. 预测被掩码Patch的视觉Token

这种预训练方式使ViT能够学习丰富的视觉表示,在下游任务上表现优异。

4.3 MAE:掩码自编码器

MAE(Masked Autoencoder)是一种简单的自监督预训练方法:

核心思想

  1. 随机掩码高比例(75%)的图像Patch
  2. 编码器只处理可见Patch(提高效率)
  3. 轻量级解码器重建完整图像
  4. 使用像素级MSE损失

MAE展示了ViT强大的表示学习能力,使用75%的掩码比例仍能有效重建图像。

4.4 2025年最新进展

EfficientViT

EfficientViT专为边缘计算设备设计,通过多尺度线性注意力模块在保持性能的同时显著提升推理速度。主要创新包括:

  • 多尺度线性注意力:降低注意力计算的复杂度
  • 三明治布局:在FFN层之间插入注意力层,提高效率
  • 轻量级设计:参数量和计算量大幅减少,适合移动端部署
SparseViT

AAAI 2025提出的SparseViT通过稀疏编码实现参数高效的图像处理,主要特点:

  • 稀疏注意力:只关注重要的Token
  • 非语义中心设计:不依赖语义信息,适用于图像篡改定位等任务
  • 参数效率:在保持性能的同时显著减少参数量
其他进展
  • CrossViT:结合不同尺度Patch的Transformer
  • PiT(Pooling-based Vision Transformer):使用池化操作替代Patch Merging
  • CvT(Convolutional Vision Transformer):在Transformer中引入卷积操作

5. CNN vs Transformer对比

5.1 归纳偏置的差异

特性 CNN Transformer
局部性 强归纳偏置(局部感受野) 无归纳偏置(全局注意力)
平移不变性 内置(权值共享) 需通过位置编码学习
尺度不变性 部分(多尺度设计) 需通过数据学习
全局依赖 需多层堆叠 单层即可捕捉

5.2 数据效率

CNN:由于归纳偏置,CNN在小数据集上表现更好,可以用较少的数据学习到有效的特征。

Transformer:Transformer缺乏归纳偏置,需要大规模数据才能发挥优势。DeiT、BEiT等技术正在缩小这一差距。

5.3 计算效率

模型 计算复杂度 内存占用 推理速度
ResNet-50 O(HW)O(HW)O(HW)
ViT-Base O((HW)2)O((HW)^2)O((HW)2)
Swin-Tiny O(HW⋅M2)O(HW \cdot M^2)O(HWM2) 中等
EfficientViT O(HW)O(HW)O(HW)

5.4 适用场景

选择CNN的场景

  • 资源受限的嵌入式设备
  • 小数据集(<10万张图像)
  • 需要极低延迟的实时应用

选择Transformer的场景

  • 大规模数据集
  • 需要全局上下文理解的任务
  • 多模态学习(文本+图像)

混合架构的趋势

  • ConvNeXt:将CNN设计现代化,借鉴Transformer的优点
  • CoAtNet:结合卷积和注意力
  • MaxViT:使用多轴注意力结合局部和全局信息

6. 避坑小贴士

图像分块大小选择

Patch大小对模型性能有显著影响:

  • 小Patch(8x8):更多Token,计算量大,但细粒度特征更丰富
  • 大Patch(16x16或32x32):Token数量少,计算效率高,但可能丢失细节

建议:图像分类任务使用16x16,密集预测任务(检测、分割)使用8x8或更小。

位置编码的重要性

ViT对位置编码敏感,移除位置编码会导致性能显著下降。如果模型性能不佳,检查位置编码是否正确添加。

预训练与微调

ViT通常需要预训练:

  • 有监督预训练:ImageNet-21k、JFT-300M等大规模数据集
  • 自监督预训练:MAE、BEiT、DINO等方法
  • 微调:在目标任务上微调,通常使用较小的学习率

训练稳定性

Transformer训练比CNN更不稳定:

  • 使用AdamW优化器,权重衰减0.05
  • 学习率预热(Warmup)很重要,通常预热5-10个epoch
  • 使用较大的batch size(4096或更大)
  • 使用随机深度(Stochastic Depth)正则化

内存优化

ViT的内存占用较大,可以尝试:

  • 使用梯度检查点(Gradient Checkpointing)
  • 使用混合精度训练(AMP)
  • 减小batch size,增加梯度累积步数

7. 完整代码实现

7.1 ViT的PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat


class PatchEmbedding(nn.Module):
    """图像分块嵌入层"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        
        # 使用卷积实现分块和线性嵌入
        self.proj = nn.Conv2d(
            in_channels, embed_dim, 
            kernel_size=patch_size, stride=patch_size
        )
        
    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)  # (B, embed_dim, H/P, W/P)
        x = rearrange(x, 'b e h w -> b (h w) e')  # (B, N, embed_dim)
        return x


class MultiHeadAttention(nn.Module):
    """多头自注意力机制"""
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "embed_dim必须能被num_heads整除"
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        B, N, C = x.shape
        
        # 生成Q、K、V
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # 计算注意力
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        
        # 加权求和
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        
        return x


class MLP(nn.Module):
    """前馈神经网络"""
    def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.0):
        super().__init__()
        hidden_features = hidden_features or in_features * 4
        out_features = out_features or in_features
        
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """Transformer编码器块"""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), dropout=dropout)
        
    def forward(self, x):
        # 预归一化(Pre-norm)
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer完整实现"""
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0
    ):
        super().__init__()
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        
        # Patch嵌入
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        num_patches = self.patch_embed.num_patches
        
        # CLS Token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # 位置编码
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_dropout = nn.Dropout(dropout)
        
        # Transformer编码器
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        
        # 分类头
        self.head = nn.Linear(embed_dim, num_classes)
        
        # 初始化
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embed, std=0.02)
        self.apply(self._init_weights)
        
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            
    def forward(self, x):
        B = x.shape[0]
        
        # Patch嵌入
        x = self.patch_embed(x)  # (B, N, embed_dim)
        
        # 添加CLS Token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, N+1, embed_dim)
        
        # 添加位置编码
        x = x + self.pos_embed
        x = self.pos_dropout(x)
        
        # Transformer编码器
        for block in self.blocks:
            x = block(x)
            
        x = self.norm(x)
        
        # 使用CLS Token进行分类
        cls_output = x[:, 0]
        logits = self.head(cls_output)
        
        return logits


# 测试代码
if __name__ == "__main__":
    # 创建模型
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12
    )
    
    # 测试前向传播
    x = torch.randn(2, 3, 224, 224)
    output = model(x)
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {output.shape}")
    
    # 计算参数量
    total_params = sum(p.numel() for p in model.parameters())
    print(f"总参数量: {total_params / 1e6:.2f}M")

7.2 使用预训练模型

import torch
import torchvision.transforms as transforms
from PIL import Image
import timm

# 加载预训练的ViT模型
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

# 图像预处理
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

# 加载类别标签
import json
import urllib.request

url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
labels = json.loads(urllib.request.urlopen(url).read())

# 推理
def predict(image_path):
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image).unsqueeze(0)
    
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        top5_prob, top5_catid = torch.topk(probabilities, 5)
        
    for i in range(5):
        print(f"{labels[top5_catid[i]]}: {top5_prob[i].item():.4f}")

7.3 微调ViT

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm

# 配置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 10
batch_size = 32
learning_rate = 1e-4
num_classes = 10  # 例如CIFAR-10

# 数据预处理
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

# 加载数据集(以CIFAR-10为例)
train_dataset = datasets.CIFAR10(root='./data', train=True, 
                                  download=True, transform=train_transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, 
                                download=True, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, 
                          shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, 
                        shuffle=False, num_workers=4)

# 加载预训练模型并修改分类头
model = timm.create_model('vit_base_patch16_224', pretrained=True, 
                          num_classes=num_classes)
model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)

# 学习率调度器
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# 训练循环
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        train_total += labels.size(0)
        train_correct += predicted.eq(labels).sum().item()
    
    # 验证
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()
    
    scheduler.step()
    
    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {train_loss/len(train_loader):.4f}, "
          f"Acc: {100.*train_correct/train_total:.2f}%")
    print(f"  Val Loss: {val_loss/len(val_loader):.4f}, "
          f"Acc: {100.*val_correct/val_total:.2f}%")

# 保存模型
torch.save(model.state_dict(), 'vit_finetuned.pth')

8. 本章小结

核心知识点回顾

  1. Vision Transformer(ViT)

    • 将图像分割为Patch,视为序列处理
    • 使用CLS Token进行图像分类
    • 需要大规模数据预训练
  2. Swin Transformer

    • 层次化特征图设计
    • 窗口自注意力降低计算复杂度
    • 移位窗口实现跨窗口信息交互
  3. DETR

    • 端到端目标检测,无需锚框和NMS
    • 集合预测和二分匹配
    • Transformer编码器-解码器架构
  4. 改进模型

    • DeiT:知识蒸馏实现数据高效训练
    • BEiT:BERT风格的掩码图像建模
    • MAE:高比例掩码的自监督预训练
  5. CNN vs Transformer

    • CNN具有强归纳偏置,数据效率高
    • Transformer具有全局建模能力,需要更多数据
    • 混合架构是未来的发展趋势

关键公式总结

ViT Patch嵌入
z0=[xclass;xp1E;⋯ ;xpNE]+Epos\mathbf{z}_0 = [\mathbf{x}_{class}; \mathbf{x}_p^1\mathbf{E}; \cdots; \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}z0=[xclass;xp1E;;xpNE]+Epos

多头自注意力
Attention(Q,K,V)=SoftMax(QKTdk)V\text{Attention}(Q, K, V) = \text{SoftMax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=SoftMax(dk QKT)V

DETR二分匹配
σ^=arg⁡min⁡σ∑i=1NLmatch(yi,y^σ(i))\hat{\sigma} = \arg\min_{\sigma} \sum_{i=1}^N \mathcal{L}_{match}(y_i, \hat{y}_{\sigma(i)})σ^=argσmini=1NLmatch(yi,y^σ(i))

进一步学习资源

  • 论文:An Image is Worth 16x16 Words(ViT)
  • 论文:Swin Transformer: Hierarchical Vision Transformer
  • 论文:End-to-End Object Detection with Transformers(DETR)
  • 代码库:timm(PyTorch Image Models)
  • 代码库:Hugging Face Transformers

学习建议:本章内容较为前沿,建议读者先确保理解Transformer的基础知识(第13章)。实践方面,可以从使用timm库加载预训练模型开始,逐步尝试微调和自定义实现。视觉Transformer是一个快速发展的领域,建议关注最新的顶会论文(CVPR、ICCV、NeurIPS等)了解最新进展。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐