在这里插入图片描述



摘要

尽管Transformer架构已成为NLP任务上的事实标准,但在计算机视觉中的应用仍然有限。在视觉中,注意力要么与CNN结合起来使用,要么替换CNN的部分组件,但保持其整体结构不变。我们表明,这种对CNN的依赖不再是必要的,一个纯Transformer直接应用到图像块组成的序列上,能够在图像分类任务上取得了非常好的性能。当在海量的数据上预训练,并迁移到多个中型或小型识别基准(如ImageNet,CIFAR-100,VTAB等)上,Vision Transformer即ViT能够取得媲美CNN的结果,同时需要更少的计算资源来训练。

1 介绍

自注意力模型,特别是Transformer(如Vaswani等2017提出的Attention is all you need)已经成为NLP中的首选模型。这个具有统治地位的方法,首先在一个很大的语料上预训练,然后在一个小型的面向任务的数据集上微调(Devlin等2019提出)。由于Transformers计算的有效性和可扩展性,人们已经训练出超过100B参数的前所未有大小的模型了。随着模型和数据持续增长,仍然没有饱和性能的迹象。

但是,在计算机视觉中,CNN保持着统治地位,如AlexNet,到ResNet等。受到NLP成功的启发,多个工作尝试将类似CNN的架构与自注意结合,还有些工作直接替换整个CNN。后者的模型虽然从理论上更有效,由于使用了专门的注意力模块,至今没有在现代硬件加速器上有效地扩展。因此,在大规模图像识别任务上,经典的类ResNet架构仍然是最领先的。

受到NLP中Transformer成功地扩展,我们采取实验,将标准的Transformer直接应用于图像,尽可能小的改动。因此,我们将一个图像切成块,然后将这些块线性嵌入的序列作为Transformer的输入。图像块就相当于NLP中的tokens(words)。我们在图像识别上以监督的方式来训练ViT模型。

当在中型大小的数据库如ImageNet上训练时,这样的模型产生比ResNet稍逊色的百分点。这个看似沮丧的结果可能是意料之中的:Transformers缺少CNN内置的归纳偏置(inductive biases),例如变换等效性和局部性,因此在不充足的数据上训练时不能很好泛化。

但是,在一个更大的数据库如14M-300M张图像上训练,这种情况改变了。我们发现,大规模的训练会胜过归纳偏置。我们的ViT获取了出色的结果,当在足够规模数据上预训练,并迁移到较少数据的任务上时。当在公共的ImageNet-21K或者室内JFT-300M数据集上,ViT在多个图像识别集上取得或打败了最领先的方法。特别的,最好的模型ImageNet取得了88.55%的识别率,ImageNet-Real上90.72%,CIFAR-100上94.55%等

2 相关工作

Transformers于2017被Vaswani等提出用于机器翻译任务,然后成为NLP众多任务中最领先的方法。大的Transformer为基准的模型通常都在大型的语料库上预训练,然后再用于手头的任务上微调:例如,BERT使用一个去噪的自监督预训练任务,而GPT使用语言模型作为预训练任务。

自注意的简单应用于图像可能需要每个像素关注其他每一个像素。由于像素数量的二次方的开销,在实际的输入尺寸上不容易扩展。因此,在应用Transformers到图像处理的上下文语境中,过去曾有几种近似的方法被尝试:Parmar于2018应用自注意,仅仅在局部的相邻像素上去询问像素,而不是全局。这种局部的多头点乘自注意力能够完全代替卷积。又如,Sparse Transformers为了用于图像,采用了可扩展的近似到全局自注意。另一个扩展注意力的方式是将之应用于不同尺度的图像块中,一种极端的情况就是单独应用在像素上。很多这类专门的注意力结构在计算机视觉任务上展示了鼓舞人心的结果,但是在硬件加速器上高效的实现需要复杂的工程。

也有很多的工作结合CNN和自注意力的形式,例如增强特征图(Bello 2019)或使用自注意力进一步处理CNN的输出用于目标检测(Hu 2018),视频处理,图像分类,无监督的物体发现或者统一的文本视觉任务。

我们不知道有先前工作将全局注意力的Transformers应用于全尺寸的图像。最接近我们模型的是iGPT,它在降低图像分辨率和颜色空间后,使用Transformers到图像像素。这个模型以一种无监督的方法作为一个产生式模型来训练的,学习到的表征能够用于图像分类的微调或者线性推测,在ImageNet上实现了72%的最高精度。

我们增加了更多论文来探索在更大规模相对于标准的ImageNet的图像识别。额外的数据源能够在标准集上获得最先进的性能。并且,Sun 2017研究了CNN性能如何随数据集大小而扩展。还有不少文章探索了CNN从更大规模数据集上,如ImageNet-21K或者室内JFT-300M,执行迁移学习的实证探索。我们同样关注后两者数据集,但训练的是Transformer,而不是前人工作中的ResNet类模型。

3 方法

在模型设计方面,我们尽可能跟随原始的Transformer,即Vaswani Transformer。这个有意简单的设置的优点在于,可扩展的NLPtransformer结构和其高效的实现,几乎开箱即用。

在这里插入图片描述
模型概述:我们将一张图像切割成一些固定尺寸的图像块,线性将他们嵌入,添加位置嵌入,然后将产生的向量序列输入到标准的Transformer encoder中。为了执行识别任务,我们添加了一个额外可学习的识别令牌到序列中。Transformer编码器的插图受Vaswani Transformer的启发。

Vision Transformer (ViT)

上图概述了ViT的结构。原始的Transformer接受一维令牌嵌入作为输入。为了处理2D图像,我们将图像 x ∈ R H × W × C x \in R^{H \times W \times C } xRH×W×C重塑成2D块拉伸后的序列即 x p ∈ R N × ( P 2 × C ) x_p \in R^{N \times (P^2 \times C ) } xpRN×(P2×C),其中 ( H , W ) (H, W) (H,W)是原图的分辨率, C C C是通道数, ( P , P ) (P,P) (P,P)是每一个图像块的分辨率, N = H W / P 2 N = HW/P^2 N=HW/P2是产生的图像块的个数,同时也作为Transformer的有效输入的序列长度。Transformer在其所有层中使用固定的隐向量大小D,因此我们将patch拉伸后然后通过可一个训练的线性映射,投影到D维上。我们将这个投影的输出称之为块嵌入,即patch embeddings.

与BERT中的类令牌即class tokens相似,我们在块嵌入的序列之前添加了一个可学习的令牌嵌入即 ( z 0 0 = x c l a s s ) (z^0_{0}=x_{class}) (z00=xclass),它在Transformer编码器的输出 ( z L 0 ) (z_L^0) (zL0)时的状态用作图像表示 y y y。在预训练和微调期间,一个类头,即classification head附加到 ( z L 0 ) (z_L^0) (zL0)。这个分类头,在预训练时由带有一个隐含层的MLP,或在微调时仅有一个线性层来实现的。

位置嵌入也添加到块嵌入中来保留位置信息。我们使用标准的可学习的1D位置嵌入,因为我们还没有观察到使用更先进的2D敏感的位置嵌入的显著提升。生成的嵌入向量序列作为编码器的输入。

Vaswani Transformer包含由多头注意力和MLP块的交替构成。Layernorm在每个block之前都使用了,每个块后使用残差连接。MLP包含两层,带有GELU的非线性激活。
在这里插入图片描述
其中, X p ∈ R N × ( P 2 C ) X_p \in R^{N \times (P^2 C ) } XpRN×(P2C)即所有N个patch块,patch块需要线性投影到D维空间上,故 E ∈ R ( P 2 C ˙ ) × D E \in R^{(P^2 \dot C ) \times D } ER(P2C˙)×D;由于class_token作为可学习的嵌入,与块投影一起合并,所以位置信息需要多一个,即 ( N + 1 ) (N+1) (N+1)

Hybrid Architecture混合结构. 作为原始图像块的替代,这个输入序列可以由CNN的特征图构成(LeNet1989)。在这个混合模型中,块嵌入投影被应用到从CNN特征图中提取的块。作为一种特殊情况,这个块可以由1×1的空间尺度,这意味着输入序列能够通过以下方式获取,即简单地将特征图的空间维度进行拉伸,并投影到Transformer的维度。分类输入嵌入class token和位置嵌入 position embedding都添加进去,如上所述。

微调与高分辨率

典型地,我们在大数据集上预训练ViT,然后再小一点的下游任务上微调。因此,我们移除了预训练的预测头,并添加了一个零初始化的 D × K D \times K D×K反馈层,其中 K K K是下游任务的标签数。文章Touvron2019等表明,通过在比预训练时采用更高的分辨率能够有好处。当输入更高分辨率的图像时,我们保持每个patch的尺度不变,这将导致一个更大的有效序列长度。ViT能够处理任意长度的序列,这取决于内存的大小,但是预训练的位置信息将不再有意义。因此,我们根据他们在原始图像中的位置,对预训练的位置嵌入采取了2D插值。注意到,分辨率调整和patch块提取是将图像的2D归纳偏置手动输入ViT的唯一点。

4 实验

我们评估了ResNet,ViT和混合模型的表征学习能力。为了了解每个模型的数据要求,我们在不同大小的数据集上预训练并在多个任务集上评估。当考虑预训练模型的运算开销,ViT表现非常出色,以较低的成本在多个识别基准上取得了领先的水平。最后,我们执行了一个使用自监督的小实验,表明自监督的ViT展现出未来的希望。
现在回头看这篇2019文章,发现其贡献了很多实在的东西,如ViT和为MAE铺垫的自监督学习任务。

4.1设置

数据集。 未来探索模型的可扩展性,我们使用了三组数据集:带有1.3M的ILSVRC-2012 ImageNet-1K,14M张图片的21K分类的ImageNet-21K,和303M高分辨率的JFT-18K。我们将在这几个数据集上训练好的模型迁移到几个基准上:带有原始验证标签和清理后的Real标签的ImageNet,CIFAR-10/100,Oxford-IIIT Pets和Oxford Flowers-102。采用koles2020中的预处理这些数据集。
我们还在19任务的VTAB分类套件上进行了评估。VTAB评估向不同任务传递的低级数据,每个任务使用1000个训练样本。这些任务被分成三组:自然组-任务如上,例如Pets,CIFAR等;专业组-例如医学或卫星成像;结构化-任务需要几何理解,如定位。

检查ViT

为了开始了解ViT如何处理图像数据,我们分析其内在的表征。ViT的第一层线性将拉伸后的图像块映射到一个低维空间,如公式1。以下图片展示了学习到的嵌入滤波器的主成分。这些成分类似于合理的基函数,用于每个块内精细结构的低维表征。
在这里插入图片描述
在这个投影之后,一个可学习的位置编码添加到这些块表征中。下图展示了模型学习图像内在位置嵌入上的距离编码,也就是说,更近的图像块倾向于有更多相似的位置嵌入。并且,这个行列结果出现,图像块在相同的行或列上有相似的嵌入。最后,在更大的格子上有时候回出现正弦结构。这个位置嵌入尝试去学习表征图像的2D拓扑结构,解释了为什么手工的2D敏感的嵌入变量不产生提升。
在这里插入图片描述
自注意力允许ViT去集成整个图片的信息,甚至在最低层上。我们调查这个网络在多大程度上去利用这个能力。特别地,我们根据注意力权重,去计算图像空间中信息集成的平均距离。如下图所示。这个注意力距离相当于CNN中的视觉感受野。
在这里插入图片描述
通过注意力头与网络深度关注到的区域大小。每个点表示在每一层一个注意力头在图像上的平均注意力距离。

我们发现有一些注意力头,关注了最底层中的大部分图像,表明全局集成信息的能力确实存在于这个模型中。其他一些注意力头,在底层的注意力距离上始终较小。这种高度局部化的注意力在混合模型即在ViT之前先ResNet中不太明显,表明它可能具有与CNN中早期卷积层相似的功能。此外,这个注意力距离随着层数的增加而增加。全局来说,我们发现这个模型关注为分类语义有关的图像区域。

4.6 自监督

Transformers在NLP任务中展示出非凡的性能。但是,很多他们的成功不仅源于其出色的可扩展性,也源于大规模的自监督预训练。我们同样进行的初步探索,就是对遮挡的图像块的自监督预测,模仿了BERT中对遮挡语言的建模行为。在自监督的预训练中,我们更小的ViT-B16模型在ImageNet获得了79。9%的角度,比从头开始训练提升了2%,但还是比有监督的预训练落后4%。

5 结论

我们已经探索了Transformer在图像识别上的直接应用。不像先前的工作使用自注意力用于CV,我们没有引入任何的针对图像的先导偏置到架构中。反而,我们试图将图像作为一个序列块,通过一个标准的Transformer编码器来处理,跟NLP使用的编码器一样。这个简单的可扩展的策略应用非常成功,当在大数据集上预训练的时候。因此,视觉Transformer匹配或者超过了很多数据集上领先的工作,但相对来说预训练花销更少。

虽然这个初步的结果比较鼓舞人心,很多挑战存在。一个是应用ViT到其他任务中,例如检测和分割。我们的工作,配合Carion202,为这个探索指出可取之处。另一个挑战就是继续探索自监督的预训练方法。我们初始的实验表明,来自自监督的预训练能够提升,但是在自监督和大规模的有监督训练之间仍然存在巨大的鸿沟。最后,ViT的进一步扩展将导致性能的提升。

Appendix附件

A. 多头注意力
标准的QKV自注意力是一个流行的构造块,对于神经架构而言。对于输入序列 z ∈ R N × D z \in R^{N \times D} zRN×D中每一个元素,我们计算一个在序列所有值 v v v的加权和。注意力权重 A i j A_{ij} Aij基于序列中两个元素之间的成对相似度和他们对应的询问query q i q^i qi和值 k j k^{j} kj的表征。
在这里插入图片描述
多头注意力,MSA是SA的一个扩展,其中我们执行 k k k次自注意力操作,也称为“Heads”,并行将他们的输出串接在一起。为了保持计算和参数的个数稳定,当改变K时, D h D_h Dh通常被设置成 D / k D/k D/k.

M S A ( z ) = [ S A 1 ( z ) ; S A 2 ( z ) ; …   ; S A k ( z ) ] U m s a , U m s a ∈ R k D h × D MSA(z) = [SA_1(z);SA_2(z); \dots;SA_k(z)] U_{msa}, U_{msa} \in R^{kD_h \times D} MSA(z)=[SA1(z);SA2(z);;SAk(z)]Umsa,UmsaRkDh×D

GitHub 加速计划 / vi / vision
15.85 K
6.89 K
下载
pytorch/vision: 一个基于 PyTorch 的计算机视觉库,提供了各种计算机视觉算法和工具,适合用于实现计算机视觉应用程序。
最近提交(Master分支:2 个月前 )
868a3b42 13 天前
e9a32135 22 天前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐