Titans: 在测试时学习记忆

Paper链接:Titans

摘要

在过去十多年里,研究界一直在探索如何更有效地利用循环模型与注意力机制。循环模型试图将数据压缩到固定大小的记忆中(即隐藏状态),而注意力机制允许模型关注整个上下文窗口,从而捕获所有 token 之间的直接依赖关系。然而,这种更精确的依赖建模需要二次复杂度,因此模型只能处理固定长度的上下文。

本文提出了一种新的神经长时记忆模块,它能够学习记忆历史上下文,并帮助注意力在关注当前上下文的同时利用久远过去的信息。我们表明,这种神经记忆在保持快速推理的同时,还具备可并行化、训练速度快的优势。

从记忆视角来看,我们认为:由于注意力的上下文有限但依赖建模精确,因此它更像是短时记忆;而神经记忆由于能够记住数据,因此更像是长期且更持久的记忆。基于这两个模块,我们提出了一类新架构,称为 Titans,并给出了三种变体来回答如何有效地将记忆并入架构这一问题。

在语言建模、常识推理、基因组学和时间序列任务上的实验表明,Titans 比 Transformer 以及近期现代线性循环模型更有效。此外,在“大海捞针”任务中,Titans 可以有效扩展到超过 2M 的上下文窗口,并取得比基线更高的准确率。


1 Introduction

“记忆的真正艺术,就是注意力的艺术!”
—— Samuel Johnson, 1787

Transformer 作为一种纯注意力架构,已经成为序列建模中的事实标准。这主要归功于它的上下文学习能力以及大规模学习能力。Transformer 的核心构件——注意力模块——可以被看作一种联想记忆块:它学习存储 key-value 对,并通过计算 query 与 key 的两两相似度来完成检索。

因此,从设计上说,Transformer 的输出只依赖于当前上下文窗口内 token 的直接依赖关系。这种依赖建模虽然准确,但其时间和内存复杂度都关于上下文长度呈二次增长。在语言建模、视频理解、长时序预测等真实任务中,上下文窗口可能极大,从而使 Transformer 的应用受到限制。

为了解决 Transformer 的可扩展性问题,近期工作尝试设计各种线性 Transformer:用核函数替换 softmax,以显著降低内存消耗。尽管这类方法在效率和长上下文扩展性上有优势,但它们通常无法达到 Transformer 的性能,因为核技巧会使模型退化为线性循环网络,将数据压缩到矩阵值状态中。于是就出现了一个矛盾:我们引入线性模型,是为了在很长上下文下获得效率优势;但真正很长的上下文,又无法被小的向量或矩阵状态充分压缩。

进一步地,除了效率之外,从 Hopfield Network、LSTM 到 Transformer,多数现有架构在泛化、长度外推与推理方面仍然存在困难。作者认为,这些架构虽然受人脑启发,但往往缺失以下内容之一:

  1. 学习过程中至关重要的组成部分,例如短时记忆、长时记忆、元记忆、对当前上下文的注意等;
  2. 这些组成部分如何彼此连接,同时又能独立运作;
  3. 主动从数据中学习,并把过去历史的抽象存入记忆的能力。

作者主张:高效的学习范式应当像人脑一样,由若干彼此区分但又互相连接的模块组成,每个模块负责学习过程中的一个关键部分。

Memory Perspective

记忆是人类学习不可分割的一部分。没有正常工作的记忆系统,人和动物将只能进行基础反射和刻板行为。受此启发,机器学习中许多经典模型都可以从“记忆”的角度来理解。

从神经心理学中关于记忆与学习的常见定义出发,可以把学习看作:在给定目标下,获取有效且有用的记忆的过程。按照这个视角,RNN 可以看作带有向量值记忆模块 M M M 的模型,并具有两个步骤:

  1. 更新记忆:用新输入 x t x_t xt 更新 M t − 1 M_{t-1} Mt1
  2. 读取记忆:从更新后的记忆中提取与输入对应的信息。

Transformer 同样可以视为具有“增长型记忆”的结构,其中 key-value 对构成其记忆,模型通过追加方式更新记忆,并通过 query-key 相似度来读取对应信息。

这一视角帮助我们提出五个核心问题:

  • 什么样的记忆结构才是好的?
  • 什么样的记忆更新机制才是合适的?
  • 什么样的记忆读取过程才是好的?
  • 如何设计一个能融合多种记忆模块的高效架构?
  • 为了有效存储长久过去,是否需要深层记忆模块

Contributions and Roadmap

本文试图通过设计一种能够在测试时高效学习记忆的长时神经记忆模块,来回答上述五个问题,并进一步讨论如何将其纳入整体架构中。

作者的主要贡献包括:

  • 提出一种深层神经长时记忆
  • 给出其快速、可并行的训练方法
  • 基于短时记忆、长时记忆和持久记忆,提出 Titans 架构家族;
  • 在多种任务中验证其有效性。

2 Preliminaries

本节主要给出符号说明,并回顾注意力、线性注意力和现代线性循环模型的背景。这里仅保留与后续方法直接相关的核心公式。

2.1 Backgrounds

Attention

给定输入 x ∈ R N × d i n x \in \mathbb{R}^{N \times d_{in}} xRN×din,因果注意力首先计算 query、key、value:

Q = x W Q , K = x W K , V = x W V (1) Q = xW_Q,\quad K = xW_K,\quad V = xW_V \tag{1} Q=xWQ,K=xWK,V=xWV(1)

随后第 i i i 个位置的输出为:

y i = ∑ j = 1 i exp ⁡ ( Q i ⊤ K j / d i n ) V j ∑ ℓ = 1 i exp ⁡ ( Q i ⊤ K ℓ / d i n ) (2) y_i= \sum_{j=1}^{i}\frac{\exp \left(Q_i^\top K_j/\sqrt{d_{in}}\right)V_j} {\sum_{\ell=1}^{i}\exp \left(Q_i^\top K_\ell/\sqrt{d_{in}}\right)} \tag{2} yi=j=1i=1iexp(QiK/din )exp(QiKj/din )Vj(2)

其中, W Q , W K , W V ∈ R d i n × d i n W_Q, W_K, W_V \in \mathbb{R}^{d_{in}\times d_{in}} WQ,WK,WVRdin×din 为可学习参数。

Efficient Attentions

在线性注意力中,标准注意力中的 softmax 被核函数 ϕ ( ⋅ , ⋅ ) \phi(\cdot,\cdot) ϕ(,) 替换,且满足:

ϕ ( x , y ) = ϕ ( x ) ϕ ( y ) \phi(x,y)=\phi(x)\phi(y) ϕ(x,y)=ϕ(x)ϕ(y)

于是注意力可写为:

$$
y_i

=
\sum_{j=1}{i}\frac{\phi(Q_i\top K_j)V_j}
{\sum_{\ell=1}{i}\phi(Q_i\top K_\ell)}

=

\frac{\phi(Q_i)^\top \sum_{j=1}^{i}\phi(K_j)V_j}
{\phi(Q_i)^\top \sum_{\ell=1}^{i}\phi(K_\ell)}
\tag{3}
$$

当核函数取恒等映射时,这一形式还能改写成循环形式:

M t = M t − 1 + K t ⊤ V t (4) M_t = M_{t-1}+K_t^\top V_t \tag{4} Mt=Mt1+KtVt(4)

y t = Q t M t (5) y_t = Q_t M_t \tag{5} yt=QtMt(5)

这样就能实现高效推理。

Modern Linear Models and Their Memory Perspective

从记忆视角出发,RNN 的一般形式可以写成:

M t = f ( M t − 1 , x t ) (写操作) (6) M_t = f(M_{t-1}, x_t) \qquad \text{(写操作)} \tag{6} Mt=f(Mt1,xt)(写操作)(6)

y t = g ( M t , x t ) (读操作) (7) y_t = g(M_t, x_t) \qquad \text{(读操作)} \tag{7} yt=g(Mt,xt)(读操作)(7)

其中, M t M_t Mt 表示时刻 t t t 的记忆状态。作者指出,线性 Transformer 的记忆写入本质上是把历史 key-value 对加性压缩进一个矩阵值记忆单元中,这在超长上下文下容易产生记忆溢出。现有工作主要从两条路改进:

  1. 加入遗忘机制;
  2. 改进写操作本身。

3 Learning to Memorize at Test Time

本节是论文核心。作者提出一种测试时学习记忆的神经长时记忆模块。它本质上是一个元模型:在测试阶段仍持续更新,从而把过去历史的抽象编码进自身参数中。

3.1 Long-term Memory

为了设计长时神经记忆模块,我们需要一种模型,它能把过去历史的抽象编码到参数中。一个直觉想法是:直接训练一个神经网络,让它记住训练数据。但“记忆训练数据”通常被认为是坏现象,因为它会限制泛化,还可能导致隐私问题。

作者认为,真正需要的是一个在线元模型:它在测试时学习如何记住/遗忘数据。这样,模型学到的是“如何记忆”的函数,而不是简单地对训练集过拟合。

Learning Process and Surprise Metric

训练长时记忆的关键思想,是把它看作一个在线学习问题:希望把过去信息 x 1 , … , x t − 1 x_1,\dots,x_{t-1} x1,,xt1 压缩到长时神经记忆模块 M t M_t Mt 的参数中。

受人类记忆机制启发,违反预期的事件更令人难忘。因此,作者将“惊讶度”定义为模型关于输入损失的梯度。梯度越大,说明当前输入与过去数据越不同,于是越值得被记忆。

基本更新公式为:

M t = M t − 1 − θ t ∇ ℓ ( M t − 1 ; x t ) (8) M_t = M_{t-1} - \theta_t \nabla \ell(M_{t-1}; x_t) \tag{8} Mt=Mt1θt(Mt1;xt)(8)

其中,梯度项表示 surprise。

但单纯依赖这个惊讶度会丢失某些重要信息:当经历几个强惊讶时刻后,梯度可能迅速变小,模型会停留在平坦区域,导致后续信息难以继续被记住。为解决这个问题,作者把惊讶度拆分成两部分:

  1. 过去惊讶:衡量最近过去一段时间的惊讶积累;
  2. 瞬时惊讶:衡量当前输入的惊讶。

对应更新为:

M t = M t − 1 + S t (9) M_t = M_{t-1}+S_t \tag{9} Mt=Mt1+St(9)

S t = η t S t − 1 − θ t ∇ ℓ ( M t − 1 ; x t ) (10) S_t = \eta_t S_{t-1} - \theta_t \nabla \ell(M_{t-1}; x_t) \tag{10} St=ηtSt1θt(Mt1;xt)(10)

这与带动量的梯度下降非常相似,其中 S t S_t St 相当于动量项。因此,动量在这里扮演的是跨时间传播惊讶的记忆。其中:

  • η t \eta_t ηt 是数据依赖的惊讶衰减项,用于控制过去惊讶随时间如何衰减;
  • θ t \theta_t θt 控制当前瞬时惊讶被纳入最终惊讶度的比例。

这种数据依赖性非常关键:如果上下文发生切换,则模型可以令 η t → 0 \eta_t\to 0 ηt0 来忽略上一个惊讶;如果当前 token 与最近过去强相关,则可以令 η t → 1 \eta_t\to 1 ηt1 来继承此前惊讶。

Objective

上述惊讶度建立在一个损失函数 ℓ ( ⋅ ; ⋅ ) \ell(\cdot;\cdot) (;) 之上,这个损失决定了记忆在测试时“学会做什么”。

本文聚焦于联想记忆,即把过去数据存成 key-value 对。对于输入 x t x_t xt,作者和 Transformer 一样,用两层线性层把它映射成 key 与 value:

k t = x t W K , v t = x t W V (11) k_t = x_t W_K,\qquad v_t = x_t W_V \tag{11} kt=xtWK,vt=xtWV(11)

其中 W K , W V ∈ R d i n × d i n W_K, W_V \in \mathbb{R}^{d_{in}\times d_{in}} WK,WVRdin×din。接着,希望记忆模块学习 key 与 value 的关联,因此定义损失:

ℓ ( M t − 1 ; x t ) = ∣ M t − 1 ( k t ) − v t ∣ 2 2 (12) \ell(M_{t-1}; x_t)=\left|M_{t-1}(k_t)-v_t\right|_2^2 \tag{12} (Mt1;xt)=Mt1(kt)vt22(12)

通过在元模型的内循环中优化该损失,模型就学会了在测试时记住 key 到 value 的映射。这里,内循环优化的是记忆模块 M M M 的参数,而外循环优化的是整体架构中的其他参数。

Forgetting Mechanism

在超长序列场景下,仅靠存储是不够的,还必须决定遗忘什么。因此作者引入了自适应遗忘机制:

M t = ( 1 − α t ) M t − 1 + S t (13) M_t = (1-\alpha_t)M_{t-1}+S_t \tag{13} Mt=(1αt)Mt1+St(13)

S t = η t S t − 1 − θ t ∇ ℓ ( M t − 1 ; x t ) (14) S_t = \eta_t S_{t-1}-\theta_t \nabla \ell(M_{t-1}; x_t) \tag{14} St=ηtSt1θt(Mt1;xt)(14)

其中, α t ∈ [ 0 , 1 ] \alpha_t \in [0,1] αt[0,1] 是门控项,用来灵活控制遗忘程度:

  • α t → 0 \alpha_t \to 0 αt0 时,几乎不遗忘;
  • α t → 1 \alpha_t \to 1 αt1 时,可以几乎清空整个记忆。

作者指出,这种权重衰减机制与现代 RNN 中的门控遗忘机制密切相关。

Memory Architecture

本文中,长时记忆模块采用简单的 MLP,层数满足 L M ≥ 1 L_M \ge 1 LM1。作者强调,本文重点不在于发明最强记忆网络结构,而在于提出长时记忆的设计思路与并入架构的方法

但作者进一步指出:如果记忆模块仅仅是向量值或矩阵值,那么它实际上是在做一种线性压缩。从元学习/在线学习视角看,矩阵记忆 M = W ∈ R d i n × d i n M=W\in\mathbb{R}^{d_{in}\times d_{in}} M=WRdin×din 对应于一个在线线性回归目标:

ℓ ( W t − 1 ; x t ) = ∣ W t − 1 k t − v t ∣ 2 2 \ell(W_{t-1};x_t)=|W_{t-1}k_t-v_t|_2^2 (Wt1;xt)=Wt1ktvt22

因此其最优解隐含地假设历史依赖是线性的。而作者认为,深层记忆模块(即 L M ≥ 2 L_M\ge 2 LM2)会更有效,因为两层及以上 MLP 严格强于线性模型。

Retrieving a Memory

在完成记忆设计与训练后,剩余问题是:如何从记忆中读出信息?

作者采用非常直接的方法:在不更新权重的前向传播下,把输入作为 query,从记忆中取回对应信息。具体地,先用线性层将输入投影为 query:

q t = x t W Q q_t = x_t W_Q qt=xtWQ

再读取:

y t = M ∗ ( q t ) (15) y_t = M^*(q_t) \tag{15} yt=M(qt)(15)

这里 M ∗ M^* M 表示只做推理、不更新权重的记忆模块。


3.2 How to Parallelize the Long-term Memory Training

长时记忆的设计本质上等价于:对联想记忆损失进行带动量、带权重衰减的梯度下降。理论上,这要求关于序列长度 N N N O ( N ) O(N) O(N) FLOPs。但在实际中,为了充分利用 GPU/TPU,必须把这一过程张量化,使其由更多 matmul 组成,从而实现并行训练。

作者将序列切成大小为 b ≥ 1 b\ge 1 b1 的 chunk,并把 mini-batch 梯度下降写成:

M t = ( 1 − α t ) M t − 1 − θ t ∇ ℓ ( M t − 1 ; x t ) = β t M 0 − ∑ i = 1 t θ i β t β i ∇ ℓ ( M t ′ ; x i (16) M_t=(1-\alpha_t)M_{t-1}-\theta_t \nabla \ell(M_{t-1};x_t) = \beta_t M_0-\sum_{i=1}^{t}\theta_i\frac{\beta_t}{\beta_i}\nabla \ell(M_{t'};x_i \tag{16} Mt=(1αt)Mt1θt(Mt1;xt)=βtM0i=1tθiβiβt(Mt;xi(16)

其中:

t ′ = t − m o d ( t , b ) , β i = ∏ j = 1 i ( 1 − α j ) t' = t-\mathrm{mod}(t,b),\qquad \beta_i=\prod_{j=1}^{i}(1-\alpha_j) t=tmod(t,b),βi=j=1i(1αj)

对于线性情形 M t = W t M_t=W_t Mt=Wt,有:

∇ ℓ ( W 0 ; x t ) = ( W 0 x t − x t ) x t ⊤ \nabla \ell(W_0;x_t)=(W_0x_t-x_t)x_t^\top (W0;xt)=(W0xtxt)xt

因此
∑ i = 1 b θ i β b β i ∇ ℓ ( W 0 ; x i ) = Θ b B b ( W 0 X − X ) X ⊤ (17) \sum_{i=1}^{b}\theta_i\frac{\beta_b}{\beta_i}\nabla \ell(W_0;x_i) = \Theta_b B_b (W_0X-X)X^\top \tag{17} i=1bθiβiβb(W0;xi)=ΘbBb(W0XX)X(17)

其中, Θ b = d i a g ( θ 1 , θ 2 , … , θ b ) \Theta_b=\mathrm{diag}(\theta_1,\theta_2,\dots,\theta_b) Θb=diag(θ1,θ2,,θb) B b B_b Bb 类似地由 β b / β i \beta_b/\beta_i βb/βi 构成。这样就能把原本的逐步更新改写成矩阵乘法与求和。

若再加入动量项,则 chunk 内的动量递推为:

S t = η t S t − 1 − θ t u t (18) S_t = \eta_t S_{t-1}-\theta_t u_t \tag{18} St=ηtSt1θtut(18)

其中:

u t = ∇ ℓ ( M t ′ ; x t ) u_t = \nabla \ell(M_{t'};x_t) ut=(Mt;xt)

由于所有 u t u_t ut 可同时计算,上式变成了以 u t u_t ut 为输入、 S t S_t St 为隐藏状态、 η t \eta_t ηt 为输入依赖转移值的线性递推,因此可以用并行 associative scan 来高效计算。

作者还讨论了一种进一步简化:把 α t , θ t , η t \alpha_t,\theta_t,\eta_t αt,θt,ηt 不再设为 token 级输入依赖,而改成 chunk 级参数。虽然表达能力会下降,但训练速度可能进一步提升。


3.3 Persistent Memory

长时记忆本质上是一种上下文记忆:其输出完全依赖于当前上下文。因此,除了长时记忆之外,作者还引入一组可学习但与输入无关的参数,作为任务相关记忆。这类记忆在文献中被称为 persistent memory 或 meta-memory。

给定 N p ≥ 1 N_p\ge 1 Np1,设持久记忆参数为:

P = [ p 1 , p 2 , … , p N p ] P=[p_1,p_2,\dots,p_{N_p}] P=[p1,p2,,pNp]

将其拼接到序列开头:

x new = [ p 1 ; p 2 ; …   ; p N p ] ; ∣ ∣ ; x (19) x_{\text{new}} = [p_1;p_2;\dots;p_{N_p}] ;||; x \tag{19} xnew=[p1;p2;;pNp];∣∣;x(19)

其中 ∣ ∣ || ∣∣ 表示拼接。

作者从三个角度解释其作用:

Memory Perspective

上下文记忆中的参数是输入依赖的,但一个有效的记忆系统还需要输入无关的参数,用于存储任务知识本身。换言之,掌握一个任务,不仅需要记住样本,还需要记住“这个任务该怎么做”。

Feedforward Network Perspective

Transformer 中注意力后的前馈层,可以看成某种输入无关参数的注意力形式。若把前馈层的 ReLU 替换为 Softmax,则有:

F F N ( x ) = W V , S o f t m a x ( W K x ) (20) FFN(x)=W_V,\mathrm{Softmax}(W_Kx) \tag{20} FFN(x)=WV,Softmax(WKx)(20)

此时, W K W_K WK W V W_V WV 的作用就类似于注意力中的 K , V K,V K,V 矩阵,但它们是输入无关的。持久记忆参数预期也能起到类似作用。

Technical Perspective

因果掩码下,注意力对序列初始 token 往往存在隐式偏置。把这些可学习参数放在序列最前面,可以更有效地重新分配注意力权重,从而缓解这一问题。


4 How to Incorporate Memory?

在完成长时神经记忆模块的设计后,接下来的问题是:如何把它高效且有效地并入深度学习架构?作者认为:

  • Transformer 中的注意力,由于对当前上下文的依赖建模精确但窗口有限,可视为短时记忆
  • 本文的神经记忆模块,由于能够持续从数据中学习并把信息存入参数,可视为长时记忆

围绕这一点,作者提出了 Titans 的三种变体。

4.1 Memory as a Context

第一种设计把记忆当作当前信息的上下文。

给定长序列 x ∈ R N × d i n x\in\mathbb{R}^{N\times d_{in}} xRN×din,先把序列分成若干固定长度片段 S ( i ) S^{(i)} S(i)。对于当前片段 S ( t ) S^{(t)} S(t),将其视为当前上下文,之前的片段视为历史信息。设 M t − 1 M_{t-1} Mt1 是处理当前片段前的长时记忆状态,则首先把当前片段作为 query,从记忆中取回相关历史:

h t = M t − 1 ∗ ( q t ) (21) h_t = M_{t-1}^*(q_t) \tag{21} ht=Mt1(qt)(21)

其中:

q t = S ( t ) W Q q_t = S^{(t)}W_Q qt=S(t)WQ

然后把取回的历史信息和持久记忆一起拼接,送入注意力模块:

S ~ ( t ) = [ p 1 ; p 2 ; …   ; p N p ] ; ∣ ∣ ; h t ; ∣ ∣ ; S ( t ) (22) \tilde{S}^{(t)} = [p_1;p_2;\dots;p_{N_p}] ;||; h_t ;||; S^{(t)} \tag{22} S~(t)=[p1;p2;;pNp];∣∣;ht;∣∣;S(t)(22)

y t = A t t n ( S ~ ( t ) ) (23) y_t = \mathrm{Attn}(\tilde{S}^{(t)}) \tag{23} yt=Attn(S~(t))(23)

接着用 y t y_t yt 更新长时记忆,并得到最终输出:

M t = M t − 1 ( y t ) (24) M_t = M_{t-1}(y_t) \tag{24} Mt=Mt1(yt)(24)

o t = y t ⊗ M t ∗ ( y t ) (25) o_t = y_t \otimes M_t^*(y_t) \tag{25} ot=ytMt(yt)(25)

这里 ⊗ \otimes 表示某种组合操作。

作者认为这一设计有三个优点:

  1. 注意力同时看到历史和当前上下文,因此能决定当前任务是否需要长时记忆;

  2. 注意力可以帮助长时记忆只存储当前片段中有用的信息,避免记忆容量被无关 token 挤占;

  3. 在测试时:

    • 持久记忆参数固定不变;
    • 注意力分支负责上下文学习;
    • 长时记忆模块继续学习、继续记忆。

4.2 Gated Memory

第二种设计中,一条分支直接用输入更新长时记忆,另一条分支使用滑动窗口注意力(SWA):

x ~ = [ p 1 ; p 2 ; …   ; p N p ] ; ∣ ∣ ; x (26) \tilde{x} = [p_1;p_2;\dots;p_{N_p}] ;||; x \tag{26} x~=[p1;p2;;pNp];∣∣;x(26)

y = S W − A t t n ∗ ( x ~ ) (27) y = \mathrm{SW -Attn}^*(\tilde{x}) \tag{27} y=SWAttn(x~)(27)

o = y ⊗ M ( x ~ ) (28) o = y \otimes M(\tilde{x}) \tag{28} o=yM(x~)(28)

这里, S W − A t t n ∗ \mathrm{SW-Attn}^* SWAttn 表示带 prefix 的滑动窗口注意力。与前一种设计不同,这里不再对输入分段。同时, M ( x ) M(x) M(x) 表示在整条序列递归完成后的记忆输出。

在这一设计中:

  • 滑动窗口注意力充当精确的短时记忆
  • 神经记忆模块充当逐渐衰减的长时记忆

作者认为,这种结构也可以理解为一种不同头结构的“多头架构”。

4.3 Memory as a Layer

第三种设计把神经记忆直接当作深度网络中的一层:

x ~ = [ p 1 ; p 2 ; …   ; p N p ] ; ∣ ∣ ; x (29) \tilde{x} = [p_1;p_2;\dots;p_{N_p}] ;||; x \tag{29} x~=[p1;p2;;pNp];∣∣;x(29)

y = M ( x ~ ) (30) y = M(\tilde{x}) \tag{30} y=M(x~)(30)

o = S W − A t t n ( y ) (31) o = \mathrm{SW-Attn}(y) \tag{31} o=SWAttn(y)(31)

这种设计在文献中更常见,因为很多混合模型本来就是把循环模型和注意力层叠加起来。

其主要缺点是:模型能力会被每一层的处理方式所限制,因此无法充分利用“注意力”和“神经记忆”之间的互补性。

Memory Without Attention

作者还提出一种更简单的变体:完全去掉注意力,只把 LMM 当作序列模型。理由是,从记忆视角看,记忆系统的每个部分都应当具备独立工作能力,因此长时记忆本身也应当在没有短时记忆时依然强大。

4.4 Architectural Details

实现上,作者省略了残差连接、线性门控和归一化等细节讨论,但实际实现中这些组件都被使用了。作者采用:

  • 残差连接;
  • SiLU 非线性;
  • 对 query 和 key 做 ℓ 2 \ell_2 2 归一化;
  • 在 query、key、value 投影后加入 1D depthwise-separable convolution;
  • 输出前再做归一化与线性门控。

Theorem 4.1

作者给出一个理论结论:与 Transformer、对角线线性循环模型和 DeltaNet 相比,Titans 能解决超出 T C 0 TC^0 TC0 的问题,因此在状态跟踪任务上,Titans 在理论上比 Transformer 和多数现代线性循环模型更具表达能力。


5 Experiments

这一节主要验证五个问题:

  1. Titans 相较于基线在下游任务中的表现如何;
  2. Titans 的实际有效上下文长度如何;
  3. Titans 随上下文长度扩展时表现如何;
  4. 记忆深度如何影响性能与效率;
  5. Titans 各个组件分别贡献了什么。

5.1 Experimental Setup

作者主要评估四类模型:

  1. Titans (MAC)
  2. Titans (MAG)
  3. Titans (MAL)
  4. 单独的神经记忆模块 LMM

模型规模包括 170M、340M、400M、760M。训练数据来自 FineWeb-Edu,训练长度为 4K token,优化器使用 AdamW,学习率为 4 × 10 − 4 4\times 10^{-4} 4×104,采用 cosine annealing,batch size 为 0.5M token,weight decay 为 0.1。

5.2 Language Modeling

在语言建模与常识推理任务中,作者发现:

  • 单独的神经记忆模块,在非混合模型中取得了最好的困惑度和准确率表现;
  • 与同样是基于梯度更新的 TTT 相比,本文方法中的权重衰减(遗忘机制)动量非常关键;
  • 与带门控的 Mamba、Mamba2、Gated DeltaNet 相比,本文的优势说明:惊讶机制 + 深层非线性记忆是有效的;
  • 在混合模型中,三种 Titans 变体都优于现有混合基线;
  • MAC 和 MAG 表现接近,但 MAC 在更长依赖场景下通常更强;
  • MAG 与 MAC 都优于 MAL,说明“如何把记忆并入架构”本身就很重要。

5.3 Needle in a Haystack

作者指出,把模型扩展到更长上下文,并不等于它在超长序列上真的有效。needle-in-a-haystack 任务正是为了测量模型的真实有效上下文长度

实验结果显示,神经记忆模块在该类任务中优于基线。作者将这一优势归因于三点:

  1. 相比 TTT,本文的动量和遗忘机制更能管理有限记忆容量;
  2. 相比 Mamba2,本文具备深层非线性记忆,因此记忆管理更强;
  3. 在序列变长时,神经记忆的性能下降更小,趋势更稳定。

5.4 / 5.5 / 5.8 / 5.9 结果总结

结合论文后续实验部分,作者给出的主要结论可以概括为:

  • Titans 可以有效扩展到超过 2M 的上下文窗口;

  • 深层记忆模块比线性记忆更有效;

  • 在效率上,神经记忆模块略慢于 Mamba2 和 Gated DeltaNet,原因在于:

    • 其记忆更深、更新更复杂;
    • Mamba2 具有高度优化的 kernel;
  • 但 Titans (MAL) 借助 FlashAttention,训练吞吐仍然很高;

  • 消融实验表明,以下组件都对性能有正向贡献:

    • 深层记忆
    • convolution
    • 动量
    • weight decay / forgetting
    • persistent memory
  • 在三种 Titans 架构中,MAC 与 MAG 在语言建模和常识推理上接近,但 MAC 在长上下文 NIAH 上显著更优;两者均强于 MAL。


6 Conclusion

本文提出了一种神经长时记忆模块,它作为一种元上下文学习器,能够在测试时学习记忆。该记忆模块本质上是循环式的,并会自适应地记忆那些更令人惊讶、或接近惊讶 token 的信息。

在此基础上,作者进一步提出了 Titans 架构,将:

  • 短时记忆(注意力)
  • 长时记忆(神经记忆)
  • 持久记忆(任务知识参数)

统一到一个整体系统中。实验表明,这种学习范式在多种任务上都优于现有现代循环模型和 Transformer 基线,并且能够扩展到极长上下文。


Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐