摘要

这篇论文探讨了数据混合策略(例如CutMix)在提高卷积神经网络(CNNs)性能方面的有效性,并指出这些策略在视觉Transformer(ViTs)上同样有效。然而,发现了一个“token fluctuation phenomenon”,这限制了数据混合策略在ViTs上的潜力。具体来说,输入token在前向传播过程中的贡献会出现波动,可能导致输出token的混合比例与预期不同,从而使得原始数据混合策略计算出的训练目标不准确,影响训练效果。为了解决这个问题,论文提出了一种名为Token-Label Alignment (TL-Align) 的方法,通过追踪变换后的token与原始token之间的对应关系,为每个token保持标签。TL-Align方法通过重用每层计算出的注意力来高效地进行token-label对齐,仅引入了微小的额外训练成本。广泛的实验表明,该方法在图像分类、语义分割、目标检测和迁移学习任务上提高了ViTs的性能。

拓展阅读:

  • Mixup:将随机的两张样本按比例混合,分类的结果按比例分配;
  • Cutout:随机的将样本中的部分区域cut掉,并且填充0像素值,分类的结果不变;
  • CutMix:就是将一部分区域cut掉但不填充0像素而是随机填充训练集中的其他数据的区域像素值,分类结果按一定的比例分配

拟解决的问题

 论文主要解决的问题是ViTs在使用数据混合策略时出现的token波动现象,这导致了训练目标的不准确,影响了模型的训练效果。

创新之处

  • 提出了Token-Label Alignment (TL-Align) 方法,这是一种新的训练策略,用于解决ViTs在数据混合策略下的token波动问题。
  • TL-Align通过追踪输入token和变换后token之间的对应关系,为每个token动态地对齐标签,而不是简单地使用输入时的标签。
  • 该方法重用了每层的注意力计算,使得对齐过程的额外训练成本非常小。

方法论

(a)类似 CutMix 的方法广泛用于模型训练,它在空间上混合了标记及其在输入空间中的标签。

(b) 它们最初是为 CNN 设计的,并假设处理后的标记与输入标记在空间上对齐。我们表明,由于全局感受野和自适应权重,它对 ViT 不成立。

(c) 与现有方法相比,我们的方法可以有效且高效地对齐标记和标签,而无需预训练的教师网络。

  • 问题识别在ViT中,自注意力机制会导致输入token的波动,即token在网络前向传播过程中的重要性发生变化,这可能导致输出token的混合比例与原始图像不一致
  • 影响: 这种波动会影响训练目标的准确性,因为原始的数据混合策略可能无法正确反映经过自注意力处理后的token分布。

框架图:

方法的核心:通过追踪输入token和变换后token之间的对应关系,为每个输出token获得对齐的标签。 

工作流程:

1. 划分token,投影到适合的维度并添加位置嵌入

 2. 给每个令牌Z_{i}分配一个嵌入的标签:

3. 以分层方式执行 TL-Align,并根据令牌的操作计算对齐标签。形式上,ViTs 使用 self-attention 来执行输入标记 Z 的空间混合:

 为了对齐标签,我们使用相同的注意力矩阵 A(Q, K) 更新标签嵌入 Y:

对于多头注意力机制: 对齐方式通过简单地取所有注意力矩阵的平均值来对齐来调整我们的标签对齐到 MSA:

4. 逐层同步地将标签与处理后的标记对齐,得到对齐的标记 Z^{L} 和标签 Y^{L}。 图像z的最终表示为类标记Z_{cls}^{L}(DeiT)或所有空间标记的平均池化\frac{1}{N}\sum_{i=1}^{N}z_{i}^{L}(Swin Transformer)。根据具体的模型,图像的对齐标签y_{align}y_{cls}^{L}\frac{1}{N}\sum_{i=1}^{N}y_{i}^{L}。然后我们采用对齐后的标签y_{align}来训练网络,它可以适应不同的损失函数和训练方案:

不会通过对齐的标签反向传播,因为它们只能作为更准确的目标。在逐层传播期间自适应地调整每个标记的标签,并在整个前向过程中保留标记和标签之间的对齐。TL-Align 只在训练期间使用,并且在推理时不会引入额外的计算成本。 

具体来说,使用得到的对齐标签作为损失函数的输入,与模型的原始输出一起计算损失。损失函数通常是一个分类损失,如交叉熵损失。根据计算得到的损失,通过反向传播算法更新模型的参数。TL-Align方法的一个关键特点是对齐标签在训练中是停止梯度的,这意味着不会将梯度传回标签嵌入的初始化过程。

结论: 论文的实验结果表明,TL-Align能够一致性地提高不同ViT模型的性能,并且在各种下游任务中验证了其鲁棒性和泛化能力。此外,作者们还探讨了将TL-Align方法应用于其他架构(如MLP-like模型)的可能性,指出这是未来研究的一个有前景的方向。

GitHub 加速计划 / tra / transformers
130.24 K
25.88 K
下载
huggingface/transformers: 是一个基于 Python 的自然语言处理库,它使用了 PostgreSQL 数据库存储数据。适合用于自然语言处理任务的开发和实现,特别是对于需要使用 Python 和 PostgreSQL 数据库的场景。特点是自然语言处理库、Python、PostgreSQL 数据库。
最近提交(Master分支:3 个月前 )
13493215 * remove v4.44 deprecations * PR comments * deprecations scheduled for v4.50 * hub version update * make fiuxp --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> 1 天前
8d50fda6 * Remove FSDP wrapping from sub-models. * solve conflict trainer.py * make fixup * add unit test for fsdp_auto_wrap_policy when using auto_find_batch_size * put back extract_model_from_parallel * use transformers unwrap_model 1 天前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐