ICLR 2021
Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby
https://arxiv.org/abs/2010.11929

一、简介

在视觉领域,注意力机制要么与卷积结合使用,要么用来替换卷积网络的某些组件,同时保持其整体结构不变。

我们提出了ViT(Vision Transformer),一种仅仅使用transformer来进行图像分类任务。我们将图像分割成多个块,并提供这些块的线性嵌入序列,然后直接输入到标准的transformer进行图像分类学习。

该方法在中等数据集上(例如ImageNet)表现不如ResNet,但是当数据集规模扩大时,其效果要接近甚至好于目前的一些SOTA结果。作者认为是大规模的训练可以鼓励transformer学到CNN结构所拥有的translation equivariance和locality。

二、模型

在这里插入图片描述

在这里插入图片描述

具体流程:

token表示一个句子中的一个一个的单词嵌入。

标准的Transformer的输入是一个一维token序列而我们要处理的是二维图像,因此,我们将图像划分为多个patch子图,每个子图相当于一个token。图中将图片分成了9个patch,然后通过一个线性变换层将二维patch变为一个D维的嵌入特征。

然后在D维嵌入特征中加入位置嵌入以保留位置信息,这有利于模型正确评估注意力权重。实验使用标准的可学习的一维位置嵌入,因为在实验中发现二维位置嵌入并没有显著提升。

另外在最前面添加了一个learnable embedding(记作Xclass),这个可学习嵌入经过encoder之后对整个图作整体表示。因为这个token是没有语义信息的,所以不会出现其整体表示偏向于指定embedding信息。

这个learnable embedding经过Transformer得到的输出嵌入将会经过一个隐层MLP后给出分类预测结果。

Fine-tuning过程中高分辨率图像的处理:

在Fine-tuning到下游任务时,当图像的分辨率增大时(即图像的长和宽增大时),如果保持patch大小不变,得到的patch个数将由 N 变为 N’ 个。但是由于在pretrain时,position embedding的个数是N。则多出来的 N’-N 个positioin embedding在pretrain中是无意义的。

为了解决这个问题,文章中提出用2D插值的方法,基于原图中的位置信息,将pretrain中的 N 个position embedding插值成 N’ 个。从而保证了position embedding的语义信息。

三、实验

在这里插入图片描述

在这里插入图片描述

在JFT-300M上预先训练的VIT-L/16模型在所有任务上都优于BIT-L(在相同的数据集上预先训练),同时需要的计算资源要少得多。

在这里插入图片描述

上图显示,通过大量的预训练数据集进行预训练才能发挥出Transformer的性能。

四、理解Vision Transformer

在这里插入图片描述

左图是每个patch的低维表示。

中间的图是位置嵌入的相似性,越接近的patch具有更为相似的位置嵌入。我们发现,同一行/列的patch具有相似的位置嵌入。

右图是根据注意力权重计算图像空间中信息集成的平均距离。我们发现,有一些Head已经有很大的注意力距离,这表明了模型确实使用了全局集成信息的能力。有些Head的注意力距离一直很小,这表明这些Head可能起到与CNN类似的功能,但随着网络深度的增加,注意力距离会变大。

在这里插入图片描述

从整体上看,我们发现,该模型关注的是语义上与分类相关的图像区域。

在这里插入图片描述

如图所示,虽然没有位置嵌入的模型和有位置嵌入的模型的性能之间存在很大的差距,但是不同的位置信息编码方式之间几乎没有差别。

可能的原因是Transformer的输入是patch,而不是像素级别的输入,空间维度比原始像素级别输入小得多,学习以分辨率表示空间关系容易的多。

GitHub 加速计划 / tra / transformers
130.24 K
25.88 K
下载
huggingface/transformers: 是一个基于 Python 的自然语言处理库,它使用了 PostgreSQL 数据库存储数据。适合用于自然语言处理任务的开发和实现,特别是对于需要使用 Python 和 PostgreSQL 数据库的场景。特点是自然语言处理库、Python、PostgreSQL 数据库。
最近提交(Master分支:2 个月前 )
13493215 * remove v4.44 deprecations * PR comments * deprecations scheduled for v4.50 * hub version update * make fiuxp --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> 14 小时前
8d50fda6 * Remove FSDP wrapping from sub-models. * solve conflict trainer.py * make fixup * add unit test for fsdp_auto_wrap_policy when using auto_find_batch_size * put back extract_model_from_parallel * use transformers unwrap_model 14 小时前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐