abstract

基于强化学习(RL)的大语言模型(LLM)后训练【post-training】计算成本很高,因为它会生成大量 rollout 序列,而这些序列经常可能共享很长的 token 前缀。现有的 RL 框架通常会独立处理这些序列,在策略模型训练的前向传播和反向传播过程【forward and backward passes】中反复重新计算相同的前缀,从而造成大量计算和内存使用上的低效。尽管前缀共享【prefix sharing】天然会在 rollout 之间形成一种树结构,但以往基于 tree attention 的解决方案依赖完全物化的 attention mask,因此在 RL 场景下扩展性很差。

在本文中,我们提出 AREAL-DTA,用于在 RL 训练中高效利用前缀共享。AREAL-DTA 采用一种基于深度优先搜索(depth-first-search,DFS)的执行策略,在前向计算和反向计算过程中动态遍历 rollout 前缀树,每次只物化【materializing】一条从根节点到叶子节点的路径。为了进一步提升可扩展性,AREAL-DTA 还引入了一种负载均衡的分布式 batching 机制,可以在多张 GPU 上动态构建并处理前缀树。在流行的 RL 后训练工作负载中,AREAL-DTA 在 τ²-bench 上实现了最高 8.31 倍的训练吞吐提升。

1 Introduction

使用强化学习(reinforcement learning,RL)对大语言模型(LLM)进行后训练通常计算开销很大。RL 后训练流程通常需要针对每个 prompt 或环境状态生成许多相似的 rollout 序列,以探索不同结果或收集足够的奖励信号。关键的是,这些序列经常共享很长的前缀——例如,多个候选回复可能都以相同的用户 prompt 或初始对话轮次开头。在当前实践中,每条序列都会被独立处理,这导致在策略模型训练过程中,不同 rollout 之间会对相同前缀进行冗余计算。

这种重复会带来很高的计算成本和内存开销,因为策略模型训练会在每次前向传播和反向传播中不必要地重新计算相同的 token 前缀。这会形成一个主要瓶颈,降低训练吞吐量,并增加 GPU 显存使用。在本文中,我们希望探索如何有效利用这种前缀共享范式来提升 RL 训练效率。

前缀共享【Prefix sharing】在现代 LLM 的 RL 训练流程中普遍存在。例如,用于高级推理 LLM 的 RL 训练以及多轮 agent 训练,经常会针对每个上下文采样多个 continuation,而这些 continuation 都从相同的 prompt 或思维链推理步骤开始。这些分支式轨迹天然形成一种前缀树结构:它们在分化成不同结果之前,会共享一个共同的主干【a prefix tree structure】,也就是前缀。

然而,普通的 RL 训练流水线【vanilla RL training】无法利用这种结构——它们会把每个分支都当作一条独立序列来处理。结果是,重叠的前缀计算会被执行多次,从而浪费计算预算和内存消耗。如果我们能够利用前缀共享,让轨迹中的共享部分只被计算一次,就应该能够高效复用前缀计算。这将消除大量重复工作,按比例节省运行时间,并且能够在不耗尽显存的情况下支持更大的 batch size 或更多采样。

然而,由于 Transformer 架构中的 attention 机制以及前缀树分支计算的复杂性,在策略模型训练中利用前缀共享并不容易。一种朴素方案,例如 tree attention,可能会尝试使用一个合并后的 attention mask,把所有 rollout 轨迹合并到一次巨大的前向传播中【merge all rollout trajectories into one giant forward pass】。这个合并的 attention mask 用来确保每个 token 只能关注到它对应的正确前缀上下文。不幸的是,这类基于树的 attention 机制需要为整棵前缀树完全物化【materializing】大型 attention mask。

因此,在 RL 后训练流程中,这种方法的扩展性很差:其内存和计算开销会随着所有合并序列的总 token 数呈二次增长。这种完全物化的前缀树可能会超过 GPU 显存限制,并且由于需要管理巨大的 attention mask 矩阵,会显著拖慢计算。在实践中,先前沿着这一路线的方法并不能有效扩展——这类标准 tree attention 方法会产生过高的额外开销,往往抵消了前缀复用带来的收益。

在本文中,我们提出 AREAL-DTA,也就是在 AReaL RL 训练框架之上的动态树注意力方法。这是一种新的解决方案,能够利用前缀共享,同时克服已有方法继承下来的扩展性挑战。其核心思想是在基于 Transformer 的策略模型训练中,针对前向传播和反向传播采用一种基于深度优先搜索(depth-first search,DFS)的动态计算策略。AREAL-DTA 还包含一种负载均衡的分布式 batching 策略,可以有效扩展 RL 训练计算。具体来说,我们的贡献如下。

Contribution 1 贡献 1:我们设计并实现了一种创新的 token 前缀树 DFS 遍历方法。AREAL-DTA 并不构造完全物化的 attention mask,而是将一组 rollout 看作一棵前缀树,并使用 DFS 对其进行动态遍历。在 DFS 中压入前缀树节点时,AREAL-DTA 每次只探索前缀树的一条分支。它会复用所有共享前缀的计算,并且只在分支发生分叉时才分配新的计算资源。在任意时刻,活跃上下文都被限制在一条从根节点,也就是公共前缀,到某个叶子节点的单一路径上。这会显著降低内存占用;换句话说,AREAL-DTA 从不需要同时把整棵 token 树的 activation 都保存在显存中。在 DFS 中弹出前缀树节点时,AREAL-DTA 会在访问下一条分支之前,立即沿着当前分支反向传播对应的梯度。因此,中间 activation 不需要为所有分支同时保存。同时,来自共享前缀的梯度贡献可以被正确累积,而不需要重复计算。这种 DFS 遍历保证了每个前缀树节点的计算只执行一次,并被其所有后续 continuation 复用。类似地,每个前缀对应的梯度也会被正确聚合。整个过程的内存使用量与最长序列的深度成正比,而不是与前缀树中的总 token 数成正比。

Contribution 2 贡献 2:为了进一步扩展 RL 训练规模,我们为 AREAL-DTA 开发了一种负载均衡的分布式 batching 策略。具体来说,AREAL-DTA 会在异步 rollout 生成阶段,对生成出的 rollout 进行动态 batching,从而构建多棵前缀树。然后,它会把这些计算分发到多个训练 GPU worker 上。这样可以在各个训练 GPU 之间保持计算负载均衡。这种机制能够最小化 GPU 空闲时间。因此,即使每次策略模型训练迭代中包含大量序列,并且每条轨迹都很长,它也能形成一种可扩展的训练机制,高效构建并遍历 RL rollout 的前缀树。

Contribution 3 贡献 3:我们在广泛的 RL 训练任务上实现了显著的性能提升。在实验中,我们证明 AREAL-DTA 在具有挑战性的 RL 微调基准上,在速度和内存效率方面都带来了显著收益。我们在多个流行的 RL 任务上评估了 AREAL-DTA,并发现 AREAL-DTA 始终优于标准 RL baseline。例如,AREAL-DTA 实现了显著的吞吐提升:单个训练 worker 最高提升 8.31 倍,整个训练集群最高提升 6.20 倍,完整端到端流水线最高提升 2.28 倍。这种加速可以归因于其内存高效的设计。与普通方法相比,我们观察到显存使用量大幅下降,峰值 GPU 显存通常减少超过 50%。因此,AREAL-DTA 不再需要依赖额外的内存优化技术,例如 activation checkpointing 或 gradient accumulation。否则,这些额外技术本身会引入不可忽视的计算开销。

2 Preliminaries and Related Work

2.1 RL System for LLM Post-training

强化学习(RL)已经被广泛用于提升大语言模型(LLM)的推理能力。已有研究表明,基于 RL 的后训练可以显著提升模型在大量推理密集型任务上的表现,包括数学推理、程序合成和多跳问答。从系统角度看,LLM 的 RL 训练对资源需求极高,并且通常包含三个不同阶段。

第一阶段是 rollout 生成,它会在 GPU 上执行推理,为每个 prompt 生成多个候选回复,也就是 rollout;这一阶段通常受限于 HBM I/O。

第二阶段是奖励估计,它可能依赖密集的 CPU 资源,例如用于代码评测的沙箱执行,或者用于数学任务的基于规则的求解器。当使用基于 LLM 的奖励模型或价值模型时,奖励估计阶段也可能需要额外的 GPU 资源。

第三阶段是模型训练,它通过随机梯度优化对策略模型,以及可选的价值模型,执行计算密集型的 GPU 更新;有时还会引入参考模型来稳定训练。

现有的 LLM RL 训练流水线通常可以分为同步范式【synchronous】和异步范式【asynchronous】。在同步 RL 训练中,rollout 生成和模型优化会以交替迭代的方式执行。也就是说,系统首先使用当前策略模型参数生成推理轨迹,然后用得到的 rollout 来更新模型。相比之下,异步 RL 训练允许这些阶段并发进行。在这种情况下,rollout 生成会使用可能已经过时的参数持续产生轨迹,同时训练进程并行地更新模型。在这些异步系统中,AReaL 进一步通过一种完全异步的架构,将流式生成与训练过程解耦。同时,AReaL 引入了陈旧度感知优化和解耦 RL 目标等算法技术,从而为 LLM 推理工作流实现高效且稳定的 RL 训练。

2.2 Tree Attention for LLM Inference and Training

将多条序列组织成前缀树,最早被探索用于加速 LLM 推理。其方式是并行化投机解码,并在多个候选输出之间复用共享前缀。例如,Specinfer 首先提出将候选 token 组织成一棵 token tree,并通过一次模型前向传播并行验证这些 token。这样可以在每次迭代中验证更多 token。基于类似视角,Medusa 为原始 LLM 增加额外的解码头,用于在一步内预测多个后续 token。同时,Medusa 使用树结构的 attention mask,在每一步同时构建并验证多个 continuation 分支。除了这些 draft–verify 框架之外,近期研究也开始优化 tree decoding graph 本身,以获得更高效率。例如,Sequoia 使用一种基于搜索的策略,在树的深度和宽度之间分配固定的 token budget。它的目标是在给定成本约束下最大化前缀复用。

Yggdrasil 则将动态投机与静态运行时优化连接起来。它会为每个 query 动态选择树的宽度,也就是并行分支数,以及树的深度。同时,它使用一种“equal-growth”的树结构和分阶段调度来维持较高的硬件利用率。互补方向的优化还关注 tree-based attention 的内存和计算开销。例如,FastTree 引入了专门的 attention kernel。这些 kernel 会对共享公共前缀的分支进行计算打包和分块,从而减少冗余的 key/value 加载和内存访问。总体来看,共享前缀状态的复用、用于探索多个 token continuation 的并行解码头或并行分支,再结合通过自定义 attention mask、kernel 和调度实现的优化执行,可以最小化运行时间和内存开销。

与 tree-based attention 在 LLM 推理中的广泛使用相比,它在 LLM 训练中的应用相对缺乏探索。我们发现,唯一相关的尝试是 Tree Training,它同样针对 LLM 的 RL 微调,通过在分支轨迹之间复用前缀计算来提升效率。具体来说,Tree Training 通过一种专门的轨迹 tree-packing 机制和一种梯度校正机制来实现。然而,由于它在内存占用、计算效率和可扩展性方面存在限制,因此不足以支持实际的大规模 RL。Tree Training 会把每个打包后的轨迹树存储在 GPU 显存中,因此显著增加内存使用量。此外,它的静态前缀打包方案依赖自定义 kernel,如果没有有效的并行训练支持,并不容易应用。相比之下,AREAL-DTA 使用动态 DFS 遍历和负载均衡的分布式调度,来解决 RL 训练中的这些核心挑战。

3 Dynamic Tree Attention

Problem formulation

首先把 RL 中的 policy model training 形式化。假设当前一次策略模型训练迭代中有 N 条 rollout 序列,记为 s_1, s_2, \ldots, s_N。每条序列 s_i都对应一个训练损失 L(s_i),这个损失可以是负对数似然,也可以是 RL policy gradient 中的损失信号。总训练目标就是把所有 rollout 序列的损失加起来:

L=\sum_{i=1}^{N}L(s_i)

也就是说,AREAL-DTA 并没有改变 RL 的训练目标;它优化的是这个目标的计算方式。为了利用 rollout 之间共享的前缀,论文把所有序列压缩表示成一棵前缀树T。前缀树中的每个节点表示一段 token segment,这段 token 被某些 rollout 序列共同共享。

从根节点到某个叶子节点的一条路径,就对应一条完整 rollout 序列 s_i。通过遍历这棵前缀树,模型可以对公共前缀只计算一次,然后让多个 rollout 分支复用这部分计算。这里的核心要求有两个:第一,多个序列共同拥有的 prefix 在 forward propagation 中应该只处理一次;第二,这个 prefix 来自所有后代序列的梯度应该被正确累积。

难点在于,既要做到计算复用,又不能引入巨大的内存开销,也不能破坏梯度计算的正确性。AREAL-DTA 的解决方式是:用深度优先搜索动态遍历前缀树,并把 forward 和 backward 交织执行。

3.1 前缀树的深度优先搜索遍历

为了实现 DFS 遍历,AREAL-DTA 维护一个 stack。这个 stack 表示当前正在访问的路径,也就是从 prefix tree 根节点到当前节点的 prefix。在任意时刻,这个 stack 里保存两类东西:第一,当前 prefix 中的 token 序列;第二,当前策略模型为这些 token 生成的中间状态,也就是 Transformer 的 KV cache。给定前缀树 T,训练计算就按照 DFS 的方式推进,完成 policy gradient 所需的 forward 和 backward。

1. Push prefix tree intermediate nodes 压入前缀树的中间节点。第一步是沿着前缀树的一条分支向下走,把中间节点依次压入 stack。压入一个中间节点,意味着当前 prefix 被扩展了;模型只需要对这个新节点中的 token segment 做 forward。由于父节点 prefix 的 KV cache 已经保存在 stack 里,所以子节点 forward 时可以直接接着父 prefix 的 KV 状态继续算。具体来说,从一个 prefix node 移动到它的 child node 时,AREAL-DTA 会把 child node 的 token 输入 policy model,并使用 prefix 已缓存的 KV state。这个过程会计算新 token 的 log-probabilities,并更新扩展后 prefix 的 KV cache。然后,新 token 对应的 KV cache 会被追加到 stack 里。这样,共享 prefix 的计算就被复用了,不需要每条 rollout 都重新计算相同 prefix。

1. KV cache 到底缓存了什么?

假设当前 prefix 是:

prompt + A + B

这个 prefix 经过 policy model forward 后,每一层 attention 都会产生对应的:

K_prompt, V_prompt
K_A, V_A
K_B, V_B

这些就是 KV cache。

它的作用是:后面继续生成/计算新 token 时,不用重新计算 prompt、A、B 的 K/V。

但是注意,KV cache 只代表:

“过去 token 已经算好的 K/V”

它不包含未来 child token 的 K/V,也不包含 child token 的 log-probability。


2. 为什么 child token 还要输入 policy model?

假设现在从父节点走到 child node,child node 是:

C D

完整序列变成:

prompt + A + B + C + D

你训练 policy model 时,需要知道模型对 CD 的概率:

log π(C | prompt, A, B)
log π(D | prompt, A, B, C)

这些 log-probability 不可能只从父节点的 KV cache 里直接读出来。

因为父节点 KV cache 只告诉你:

prompt, A, B 的历史信息已经编码好了

但模型还必须拿 CD 作为输入,经过 embedding、Transformer layers、attention、MLP、lm head,才能算出:

C、D 对应位置的 hidden states
C、D 的 logits
C、D 的 log-probs
C、D 自己产生的新 K/V cache

所以这句话:

AREAL-DTA 会把 child node 的 token 输入 policy model,并使用 prefix 已缓存的 KV state

意思就是:

不是从头输入 prompt+A+B+C+D
而是只输入 C+D,
同时告诉模型:前面 prompt+A+B 的 KV cache 已经在这里了,你直接接着算。

3. 一个具体 forward 例子

普通训练会这样算:

model(prompt + A + B + C + D)

这会重复计算:

prompt, A, B, C, D

而 AREAL-DTA 是:

prefix_kv = model(prompt + A + B) 得到并缓存
child_output, child_kv = model(C + D, past_kv=prefix_kv)

也就是:

已有 prefix KV cache
       ↓
只 forward child tokens
       ↓
得到 child tokens 的 logits / logprobs / 新 KV cache

这里的 past_kv=prefix_kv 就是关键。

2. Visit prefix tree leaf nodes 访问前缀树的叶子节点。当 DFS 到达一个叶子节点时,它对应一条完整 rollout 序列 s_i​。此时,stack 中已经包含完整的 token 序列 s_i,以及这条序列 forward pass 得到的结果,例如 log-probabilities 和 entropy。于是系统可以根据 stack 里的信息计算这条序列的 loss L(s_i)。

这个 loss 可以来自正确 token 的 negative log-likelihood,也可以来自 RL trajectory reward。计算完这条完整序列的 loss 后,AREAL-DTA 立即从这条序列的输出处注入 loss gradient,开始对当前 branch 做 backward。关键点是,它不会等所有序列都 forward 完之后再统一 backward,而是每完成一条 leaf path 就立刻反传。这样做的好处是:一旦这条序列的梯度已经反传完,它对应的计算图就不必继续留在显存里。

3. Pop prefix tree intermediate nodes 弹出前缀树的中间节点。处理完一个叶子节点后,DFS 会沿着当前分支向上回退,也就是 pop stack。在 pop 的过程中,系统会把刚才 L(s_i)产生的梯度沿着当前 branch 的 policy model 计算传播回去。如果某个节点是 branching node,也就是它对应的 token segment 被多条序列共享,那么来自这些序列 loss 的梯度都会累积到这个 prefix node 上。

因为共享 prefix 参与了多条 rollout 的 loss 计算,所以它对模型参数的更新影响,应该等于“所有使用了这个 prefix 的 rollout 的 loss 贡献之和”。

1. 先看普通训练里发生了什么

假设有两条 rollout:

s1 = prompt + A + B + C
s2 = prompt + A + B + D

它们共享前缀:

prompt + A + B

普通训练会把它们当成两条独立 sequence:

loss = L(s1) + L(s2)

也就是:

总损失 = 第一条 rollout 的损失 + 第二条 rollout 的损失

那么对模型参数 θ\thetaθ 求梯度就是:

∇θ loss = ∇θ L(s1) + ∇θ L(s2)

这就是梯度累计。

它不是可选项,而是数学上必然如此。


2. 为什么 shared prefix 也会收到来自两条 loss 的梯度?

因为 prompt + A + B 的 hidden states / KV / activations 会影响后面的 token。

比如第一条:

prompt + A + B -> C -> L(s1)

第二条:

prompt + A + B -> D -> L(s2)

共享前缀 prompt + A + B 参与了两条路径的计算。

所以:

L(s1) 会反向影响 prompt+A+B 对应的计算
L(s2) 也会反向影响 prompt+A+B 对应的计算

因此共享 prefix 的梯度应该是:

来自 C 分支的梯度 + 来自 D 分支的梯度

否则就相当于你只用其中一条 rollout 更新了共享 prefix,另一条 rollout 的训练信号被丢了。


3. 一个更直观的类比

把 shared prefix 想成一个公共岔路口:

          C 分支 -> loss1
         /
prompt-A-B
         \
          D 分支 -> loss2

prompt-A-B 是公共路段。

如果 C 分支走错了,loss1 会告诉模型:

以后遇到 prompt-A-B 这种状态时,往 C 方向的概率要调整。

如果 D 分支也有奖励/惩罚,loss2 也会告诉模型:

以后遇到 prompt-A-B 这种状态时,往 D 方向的概率也要调整。

所以公共路段对应的模型计算,必须同时吸收两个方向传回来的信号。

这就是“梯度累计”。

同时,这些梯度也会累积到模型参数上。AREAL-DTA 通过按照 DFS 顺序对每条 branch 依次 backward 来正确处理这个问题。共享 prefix node 会多次收到梯度贡献,每个后代 leaf 对应一次贡献;这些贡献会在 DFS 遍历过程中被求和。当当前 leaf node 对应计算的梯度已经反传完成后,AREAL-DTA 会把该 leaf 对应的 token 和 activation 从 stack 中弹出。这会把系统状态恢复到父 prefix。被弹出的 token 对应的节点已经不再需要保留在计算图中,因为 DFS 已经完成了该 prefix 下所有相关分支的处理。并且这些节点所需的梯度也已经全部注入和传播过了。因此,它们的 activation 和临时梯度可以安全释放。随后,DFS 会继续访问下一个 sibling branch,并复用仍然保留在 stack 中的共享 prefix state。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐