顶会新热点!斯坦福全新架构TTT干翻Mamba和Transformer!
点击下方卡片,关注「3D视觉工坊」公众号
选择星标,干货第一时间送达
来源:3D视觉工坊
添加小助理:dddvision,备注:方向+学校/公司+昵称,拉你入群。文末附行业细分群
扫描下方二维码,加入3D视觉知识星球,星球内凝聚了众多3D视觉实战问题,以及各个模块的学习资料:近20门视频课程(星球成员免费学习)、最新顶会论文、计算机视觉书籍、优质3D视觉算法源码等。想要入门3D视觉、做项目、搞科研,欢迎扫码加入!
0. 这篇文章干了啥?
2020年,OpenAI的扩展定律论文(Kaplan等人)表明,LSTM(一种RNN)无法像Transformer那样进行扩展,也无法有效地利用长上下文。我们观察到Mamba——当今最流行的RNN之一——的扩展性与强大的Transformer相似,自2020年的LSTM以来取得了巨大进展。然而,我们观察到Mamba存在与Kaplan等人对LSTM的相同问题。序列中后续的标记在平均意义上应该更容易预测,因为它们基于更多的信息。这确实是Transformer的情况,其每个标记索引的平均困惑度在其32k上下文中逐渐降低。相比之下,Mamba在同一指标上在16k之后达到了平稳状态。
这一结果揭示了现有RNN的一个尴尬现实。一方面,RNN(与Transformer相比)的主要优势在于其线性(与二次方)复杂度。这种渐近优势实际上仅在长上下文中实现,这发生在8k之后。另一方面,一旦上下文足够长,现有的RNN(如Mamba)就难以实际利用额外的条件信息。
长上下文的困难是RNN层本质所固有的:与自注意力不同,RNN层必须将上下文压缩成固定大小的隐藏状态。作为一种压缩启发式方法,更新规则需要发现成千上万甚至数百万个标记之间的潜在结构和关系。在本文中,我们首先观察到自监督学习可以将庞大的训练集压缩成如大型语言模型(LLM)这样的模型权重,这些模型通常对其训练数据之间的语义联系有深刻的理解——这正是我们所需要的压缩启发式方法。
TTT 层。基于这一观察,我们设计了一类新的序列建模层,其中隐藏状态是一个模型,而更新规则是自监督学习的一个步骤。因为测试序列上隐藏状态的更新过程相当于在测试时训练一个模型,所以这类新的层被称为测试时训练(TTT)层。我们在这类层中引入了两个简单的实例:TTT-Linear 和 TTT-MLP,其中隐藏状态分别是线性模型和两层的多层感知机(MLP)。TTT 层可以集成到任何网络架构中,并像循环神经网络(RNNs)层和自注意力机制一样进行端到端的优化。
实际运行时间。虽然 TTT 层在浮点运算次数(FLOPs)上已经相当高效,但我们提出了两项实用的创新来进一步提高其在实际运行时间上的效率。首先,类似于在常规训练期间对序列的小批量进行梯度步长操作以实现更好的并行性,我们在 TTT 期间也使用标记的小批量。其次,我们为每个 TTT 小批量内的操作开发了一个对等形式,以更好地利用现代 GPU 和 TPU。对等形式在输出上与原始实现等价,但训练速度提高了 5 倍以上。
评估与开放问题。虽然我们在论文开头已经强调了一些 TTT-Linear 的结果,但论文第 3 节对 TTT-Linear 和 TTT-MLP 进行了更全面的评估,并指出了评估中暴露的开放问题。例如,我们按照 Chinchilla 配方进行的评估显示,即使是 Transformer 基线,也不完全符合线性扩展趋势。受我们学术资源的限制,我们鼓励社区与我们一同探索这些问题的解决方案。
下面一起来阅读一下这项工作~
1. 论文信息
标题:Learning to (Learn at Test Time): RNNs with Expressive Hidden States
作者:Yu Sun, Xinhao Li, Karan Dalal, Jiarui Xu, Arjun Vikram, Genghan Zhang, Yann Dubois, Xinlei Chen, Xiaolong Wang, Sanmi Koyejo, Tatsunori Hashimoto, Carlos Guestrin
机构:斯坦福大学、UC San Diego、UC伯克利、Meta AI
原文链接:https://arxiv.org/abs/2407.04620
2. 摘要
自注意力机制在长文本环境中表现优异,但具有二次复杂性。现有的循环神经网络(RNN)层具有线性复杂性,但其在长文本环境中的性能受到其隐藏状态表达能力的限制。我们提出了一类新的序列建模层,它们具有线性复杂性和表达能力强的隐藏状态。核心思想是将隐藏状态本身设计为一个机器学习模型,并且更新规则是自监督学习的一个步骤。由于隐藏状态即使在测试序列上也会通过训练进行更新,因此我们的层被称为测试时训练(TTT)层。我们考虑了两种实例化方式:TTT-Linear和TTT-MLP,它们的隐藏状态分别是线性模型和两层多层感知机(MLP)。我们在12500万至13亿参数的规模上评估了我们的实例化,并与强大的Transformer和现代RNN Mamba进行了比较。TTT-Linear和TTT-MLP的性能均达到或超过了基线模型。与Transformer类似,它们可以通过对更多标记进行条件处理来持续降低困惑度,而Mamba在16k上下文之后则无法做到这一点。在初步的系统优化后,TTT-Linear在8k上下文上已经比Transformer更快,并且在墙上时钟时间上与Mamba相当。TTT-MLP在内存I/O方面仍面临挑战,但在长文本环境中显示出更大的潜力,为未来的研究指明了一个有前途的方向。
3. 效果展示
所有序列建模层都可以表示为根据更新规则转换的隐藏状态。我们的核心思想是使隐藏状态本身成为一个权重为W的模型f,更新规则是自监督损失ℓ的梯度步长。因此,更新测试序列上的隐藏状态相当于在测试时训练模型f。这个过程被称为测试时间训练(TTT),被编程到我们的TTT层中。
与Mamba相比,TTT Linear具有更好的困惑度和更少的FLOP(左),以及更好地利用长上下文(右)。左:书籍上的缩放趋势,放大350M和1.3B参数。在760M和1.3B时,TTT Linear在使用较少FLOP的困惑度方面优于Mamba,在线性插值方面优于Transformer。右:Transformer和TTT Linear可以在更多令牌的条件下不断减少困惑,而Mamba在16k上下文后则不能。所有方法都将训练FLOP与曼巴1.4B相匹配。
随着上下文长度的变化,批大小16的每个令牌的转发时间(延迟)。所有型号均为1.3B(曼巴为1.4B)。随着上下文长度的增加,Transformer的每个令牌的前向时间呈线性增长,但对于其他两种方法,前向时间大致保持不变。TTT Linear在8k上下文下比Transformer快,并且与Mamba匹配。
4. 基本原理是啥?
所有序列建模层都可以从将历史上下文存储到隐藏状态的角度来观察,如图 4所示。例如,循环神经网络(RNN)层——如长短期记忆网络(LSTM)、RWKV和Mamba层——会将上下文压缩成跨时间的固定大小的状态。这种压缩有两个结果。一方面,将输入标记 xt 映射到输出标记 zt 是高效的,因为更新规则和输出规则对每个标记都采取恒定时间。另一方面,RNN 层在长上下文中的性能受到其隐藏状态 st 表达能力的限制。自注意力也可以从上述角度观察,但其隐藏状态,通常称为键值(KV)缓存,是一个随 t 线性增长的列表。其更新规则只是将当前的 KV 元组追加到这个列表中,而输出规则则扫描所有直到 t 的元组以形成注意力矩阵。隐藏状态明确存储了所有历史上下文而不进行压缩,这使得自注意力在长上下文中比 RNN 层更具表达能力。然而,扫描这个线性增长的隐藏状态也使得每个标记的处理时间线性增长。
为了在长上下文中既保持高效又保持表达能力,我们需要一个更好的压缩启发式方法。具体来说,我们需要将数千个甚至数百万个标记压缩成一个隐藏状态,该状态能够有效地捕捉它们的基础结构和关系。这听起来可能是一项艰巨的任务,但实际上我们都已经熟悉这样的启发式方法了。
5. 实验结果
Chinchilla的论文实证性地观察到,遵循其配方的计算最优模型在FLOP与困惑度的对数图中落在一条线上,就像标度定律实验的情况一样。然而,我们在图11或图12中没有观察到干净的线性拟合(书籍中的类似实验),即使是变压器也是如此。考虑到数据集、上下文长度、标记器和架构的差异,这并不奇怪。根据Mamba论文,由于误差较大,我们将点连接起来,而不是用线性回归拟合它们。
从图13中,我们可以观察到:
•TTT Linear和TTT-MLP这两种性能最好的方法几乎完全重叠。曼巴和TF微调的线条在1020次浮点运算后也大多重叠。
•TF微调的性能明显优于TF预训练,因为它受益于长期背景,而不会在训练FLOP时产生巨大的成本。请注意,TF微调和预训练的推断FLOP同样差,这在本图中没有反映出来。
•对于从头开始训练的所有方法(包括TF预训练),一旦上下文长度变得太大,困惑就会变得更糟。我们将对这一趋势的进一步调查留给未来的工作。
6. 总结 & 未来工作
我们已经将监督学习的经典问题重新表述为学习在测试时进行学习。我们的表述为构建传统上称为网络架构的内容提供了一个替代性的概念框架。在这个框架内,有效实例化的搜索空间是巨大的,而我们的论文只是迈出了一小步。幸运的是,如果我们的观点成立,那么常规训练中的启发式方法可以转移到测试时训练,并且搜索可以更加高效。接下来,我们概述了未来工作的一些特别有前景的方向。
• 外层循环参数化。有多种方式可以参数化一系列多视图重建任务,或者更一般的自监督任务家族。如果我们尝试的第一个方法就是最好的,那将是一个巨大的巧合。
• 系统优化。我们的系统优化充其量只是初步的,并且有很多方法可以对其进行改进。此外,通过时间的流水线并行可能允许我们在多个设备上共同处理数百万个标记的长序列。
• 更长的上下文和更大的模型。受学术资源的限制,我们尚未在数百万或数十亿长度的上下文中进行训练,这也将需要更大的模型。在更长的上下文中,TTT 层的优势应该更加明显。
• 关于f的更宏大的实例化。当上下文长度变得更长时,f也需要相应地变得更大。对于视频任务和具身代理(embodied agents),其上下文长度可以轻易地扩展到数百万或数十亿,此时f可以是一个卷积神经网络。
• 多级学习以学习。如果f本身是一个自注意力层,那么它可以被解释为嵌套在现有内部循环中的另一个内部循环。通过这种方式,我们可以潜在地构建多层嵌套的学习问题。
我们为什么研究TTT(测试时训练)?首先,从一个更基本的问题出发:我们为什么研究人工智能?对于我们中的一些人来说,人工智能是一个探索人类智能本质的试验场。先前的工作经常试图用机器学习来模拟人类学习,其中训练是在一个打乱的数据集上进行的,包含独立同分布的实例,而推理则是在一个单独的测试集上进行的。然而,人类并不是自然地用独立同分布的实例来学习,也没有明确的训练集和测试集之分。我们相信,人类学习与TTT,即我们的内部循环,有着更有前途的联系。TTT的数据可能是一个具有很强时间依赖性的非常长的序列,而且任何数据片段都可以同时用于训练和测试。这就是我们研究TTT的原因。
对更多实验结果和文章细节感兴趣的读者,可以阅读一下论文原文~
本文仅做学术分享,如有侵权,请联系删文。
3D视觉工坊交流群
目前我们已经建立了3D视觉方向多个社群,包括2D计算机视觉、大模型、工业3D视觉、SLAM、自动驾驶、三维重建、无人机等方向,细分群包括:
2D计算机视觉:图像分类/分割、目标/检测、医学影像、GAN、OCR、2D缺陷检测、遥感测绘、超分辨率、人脸检测、行为识别、模型量化剪枝、迁移学习、人体姿态估计等
大模型:NLP、CV、ASR、生成对抗大模型、强化学习大模型、对话大模型等
工业3D视觉:相机标定、立体匹配、三维点云、结构光、机械臂抓取、缺陷检测、6D位姿估计、相位偏折术、Halcon、摄影测量、阵列相机、光度立体视觉等。
SLAM:视觉SLAM、激光SLAM、语义SLAM、滤波算法、多传感器融合、多传感器标定、动态SLAM、MOT SLAM、NeRF SLAM、机器人导航等。
自动驾驶:深度估计、Transformer、毫米波|激光雷达|视觉摄像头传感器、多传感器标定、多传感器融合、自动驾驶综合群等、3D目标检测、路径规划、轨迹预测、3D点云分割、模型部署、车道线检测、Occupancy、目标跟踪等。
三维重建:3DGS、NeRF、多视图几何、OpenMVS、MVSNet、colmap、纹理贴图等
无人机:四旋翼建模、无人机飞控等
除了这些,还有求职、硬件选型、视觉产品落地、最新论文、3D视觉最新产品、3D视觉行业新闻等交流群
添加小助理: dddvision,备注:研究方向+学校/公司+昵称(如3D点云+清华+小草莓), 拉你入群。
3D视觉工坊知识星球
3D视觉从入门到精通知识星球、国内成立最早、6000+成员交流学习。包括:星球视频课程近20门(价值超6000)、项目对接、3D视觉学习路线总结、最新顶会论文&代码、3D视觉行业最新模组、3D视觉优质源码汇总、书籍推荐、编程基础&学习工具、实战项目&作业、求职招聘&面经&面试题等等。欢迎加入3D视觉从入门到精通知识星球,一起学习进步。
3DGS、NeRF、结构光、相位偏折术、机械臂抓取、点云实战、Open3D、缺陷检测、BEV感知、Occupancy、Transformer、模型部署、3D目标检测、深度估计、多传感器标定、规划与控制、无人机仿真、三维视觉C++、三维视觉python、dToF、相机标定、ROS2、机器人控制规划、LeGo-LAOM、多模态融合SLAM、LOAM-SLAM、室内室外SLAM、VINS-Fusion、ORB-SLAM3、MVSNet三维重建、colmap、线面结构光、硬件结构光扫描仪,无人机等。
3D视觉相关硬件
点这里👇关注我,记得标星哦~
一键三连「分享」、「点赞」和「在看」
3D视觉科技前沿进展日日相见 ~
更多推荐
所有评论(0)