7.1 大模型的 Packing 技巧:从预训练到微调的高效 Token 利用

1. 为什么需要 Packing

大模型训练通常会把每条样本整理成固定长度的序列(例如 4096 tokens),然后组成 batch 做并行计算。但真实数据的长度分布极不均匀:有大量短文本、也有少量超长文档。

如果直接把短文本用 [PAD] 补到固定长度,会产生大量无效计算:模型对 [PAD] 的注意力与前向计算几乎不贡献训练信号,却消耗同等算力与显存带宽。

可以用一个简单的“token 利用率”来描述浪费程度。设 batch 里有 BBB 条样本,第 iii 条样本真实长度为 lil_ili,固定填充后的长度为 LLL,则:

  • 实际参与计算的 token 数(非 pad)为 ∑i=1Bli\sum_{i=1}^B l_ii=1Bli
  • 计算总 token 数为 B⋅LB \cdot LBL
  • token 利用率为
    η=∑i=1BliB⋅L \eta=\frac{\sum_{i=1}^B l_i}{B\cdot L} η=BLi=1Bli

当大量 li≪Ll_i \ll LliL 时,η\etaη 会非常低。

一个直观例子

假设 B=8B=8B=8,固定长度 L=4096L=4096L=4096。batch 内真实长度分别为:

  • 7 条短文本:每条 512512512
  • 1 条中等文本:204820482048

则:

  • 非 pad token 总数:7×512+2048=3584+2048=56327\times 512 + 2048 = 3584 + 2048 = 56327×512+2048=3584+2048=5632
  • 总计算 token:8×4096=327688\times 4096 = 327688×4096=32768

利用率:
η=563232768≈0.172 \eta=\frac{5632}{32768}\approx 0.172 η=3276856320.172

意味着约 82.8%82.8\%82.8% 的 token 计算都浪费在 padding 上。


2. Packing 是什么

2.1 核心定义

Packing 的核心思想是:把多条短序列“拼接”成一条更长的序列,让每个 batch 的 token 数尽量贴近上限,从而提高 η\etaη

Packing 可以有两种常见策略目标:

  1. 按 batch 内最长样本长度对齐:把 batch 内样本打包后再按该 batch 最长长度 pad
  2. 按模型最大长度对齐:每个 pack 直接凑到最大长度 Lmax⁡L_{\max}Lmax(例如 4k / 8k / 128k),最大化吞吐

本质上,Packing 是把“padding 浪费”转移为“更高密度的有效 token 训练”。


3. 预训练阶段的 Packing:传统拼接与文档边界

预训练数据通常来自不同来源(网页、书籍、代码、论文段落等),其结构往往是“文本片段集合”。预训练阶段最常见的 packing 就是直接拼接多个文本片段,并用特殊 token 做分隔(例如 [SEP] 或类似的分隔符)。

3.1 为什么需要分隔符(文档边界)

拼接会带来一个关键问题:不同片段之间不应该互相泄漏语义(尤其是来自不同文档/不同样本时)。分隔符提供显式边界信号,使模型学会“这里是一个新片段/新文档的开始”。

在训练目标仍是 Causal LM(CLM)的情况下,模型会在序列上学习下一个 token 预测。拼接后,模型在边界处也会学习到“从 [SEP] 后续开始生成新文档”的统计规律,这通常是可接受且常见的做法。


4. 预训练阶段的难点:超长文本导致“阶段现象”

当预训练中混入少量超长文档(例如数万到十万 token),会出现一个训练组织层面的现象:阶段现象——也就是在训练的某个阶段(特别是还没进入长上下文继续预训练阶段时),超长样本会显著拖慢批处理效率,甚至破坏 batch 的均衡性与吞吐稳定性。

很多训练流程会允许在早期出现少量超长样本,但通常要求其占比不能过高,否则会带来:

  • 每 step 实际吞吐抖动(batch 内 token 数波动大)
  • 显存峰值不可控(注意力/KV 等开销与长度强相关)
  • 优化统计不稳定(不同长度分布导致梯度噪声分布变化)

4.1 一个解决思路:短文本与长文本“分桶 packing”

一个实用策略是:短文本和长文本分开 packing,并为它们设置不同的 pack 上限:

  • 短文本桶:pack 长度上限 LsL_sLs(例如 409640964096
  • 长文本桶:pack 长度上限 LℓL_\ellL(例如 131072131072131072

这样做的目的是让训练过程同时满足:

  1. 短文本 batch 保持高吞吐和稳定性
  2. 长文本 batch 在需要时单独处理,不污染短文本 batch 的效率
为什么还要“对短文本再多进行一次 pack”

当引入长文本桶之后,整体数据流可能会出现一种情况:短文本桶内部仍然会产生碎片(例如某些 batch 的最后一点空间凑不满)。这时可以再做一次更高层次的合并:把多个短文本 pack 再组合成更大粒度的“超级 pack”,尽量减少剩余空间。

可以把它理解为两级装箱(bin-packing):

  • 第一级:短文本片段 →\rightarrow pack(上限 LsL_sLs
  • 第二级:若干 pack →\rightarrow 更高层批组织(尽量减少碎片)

这会进一步提高整体 η\etaη,并减少因长度分布导致的吞吐波动。


5. 微调阶段:不做 Packing 的高效做法(多轮对话为核心)

微调(SFT)阶段的数据常见形态是多轮对话,例如:

  • system: 设定
  • user: 问题
  • assistant: 回答
  • user: 追问
  • assistant: 再回答

很多人第一反应是对每一轮对话的 response 单独构造样本并 padding,但在 CLM 的训练机制下,有一种更高效的方式:一次前向得到所有 token 的 logits,但只对 response 部分计算 loss

5.1 关键原因:CLM 的因果注意力(下三角 mask)

在 CLM 中,序列位置 ttt 的 token 只能关注到 ≤t\le tt 的历史 token,而不能看到未来 token。用注意力 mask 表达时,可以写成:

设输入序列长度为 TTT,注意力 mask M∈{0,1}T×TM\in\{0,1\}^{T\times T}M{0,1}T×T,则因果 mask 满足:

Mt,j={1,j≤t0,j>t M_{t,j}= \begin{cases} 1, & j\le t \\ 0, & j>t \end{cases} Mt,j={1,0,jtj>t

这对应一个由 1 组成的下三角矩阵。

因此,一次前向传播可以为每个位置输出一个分布 p^(xt+1∣x≤t)\hat{p}(x_{t+1}\mid x_{\le t})p^(xt+1xt),模型自然就能覆盖“多轮对话里每个 assistant token 的预测”。

5.2 只在 response 上算 loss:为什么成立

设整段对话拼接成 token 序列 x1:Tx_{1:T}x1:T。CLM 的标准负对数似然损失为:

L=−∑t=1T−1log⁡pθ(xt+1∣x≤t) \mathcal{L}=-\sum_{t=1}^{T-1}\log p_\theta(x_{t+1}\mid x_{\le t}) L=t=1T1logpθ(xt+1xt)

但在 SFT 的对话微调里,我们希望模型学会生成 assistant 的内容,而不是“复述 prompt”。因此定义一个二值选择向量 mt+1∈{0,1}m_{t+1}\in\{0,1\}mt+1{0,1},表示 token xt+1x_{t+1}xt+1 是否属于 response(assistant 部分)。于是损失变为:

Lresp=−∑t=1T−1mt+1⋅log⁡pθ(xt+1∣x≤t) \mathcal{L}_{\text{resp}}=-\sum_{t=1}^{T-1} m_{t+1}\cdot \log p_\theta(x_{t+1}\mid x_{\le t}) Lresp=t=1T1mt+1logpθ(xt+1xt)

这样做的含义是:

  • prompt token 的预测不参与训练目标(m=0m=0m=0
  • response token 的预测参与训练目标(m=1m=1m=1

在多轮对话下,mmm 会在序列中出现多个连续为 1 的片段(每一轮 assistant 回复),从而一次前向就覆盖全部轮次的学习信号。

5.3 一个具体多轮对话例子(用 token 区间说明)

设拼接后的序列结构如下(用区间表示):

  • prompt 区(system+user):位置 1∼8001\sim 8001800
  • assistant 第 1 轮回复:位置 801∼1200801\sim 12008011200
  • user 追问:位置 1201∼14001201\sim 140012011400
  • assistant 第 2 轮回复:位置 1401∼17001401\sim 170014011700

那么 mtm_tmt 的形状就是:

  • t∈[1,800]t\in[1,800]t[1,800] 对应 prompt:mt=0m_t=0mt=0
  • t∈[801,1200]t\in[801,1200]t[801,1200] 对应 assistant:mt=1m_t=1mt=1
  • t∈[1201,1400]t\in[1201,1400]t[1201,1400] 对应 user:mt=0m_t=0mt=0
  • t∈[1401,1700]t\in[1401,1700]t[1401,1700] 对应 assistant:mt=1m_t=1mt=1

一次前向输出 TTT 个位置的 logits,loss 仅在 mt=1m_t=1mt=1 的位置求和即可。

这样做的优势是:

  • 不需要把每轮对话拆成多个样本
  • 不需要在 batch 内重复计算相同 prompt 的前缀部分
  • 训练吞吐更高、显存利用更稳定

6. 4D Mask:在同一 batch 内区分不同“文档/对话”的注意力边界

当把多个样本进行 packing 或在一个 batch 内组织成“拼接长序列”时,另一个关键问题出现了:

同一个 batch 内,不同样本(或不同文档段)之间,应该互相不可见。

仅靠因果下三角 mask 还不够,因为下三角只保证“不能看未来”,但不能阻止“看见前面其实属于另一个样本的 token”。

为了解决这个问题,可以引入更细粒度的 mask,让注意力同时满足两种约束:

  1. 因果约束:只能看历史
  2. 边界约束:只能看同一“文档 id / 段落 id”的 token

6.1 2D mask 到 4D mask 的形状变化

很多实现中,外部构造的 mask 可能是二维的(例如 T×TT\times TT×TB×TB\times TB×T 的形式),但进入注意力计算时,最终会广播/扩展成四维:

mask shape=[B,H,T,T] \text{mask shape}=[B,H,T,T] mask shape=[B,H,T,T]

其中:

  • BBB 是 batch size
  • HHH 是注意力头数(num_heads)
  • TTT 是序列长度

直观理解:batch 中每条序列都需要一张 T×TT\times TT×T 的可见性矩阵;而每个 head 通常共享同一张矩阵,所以复制到 HHH 个 head。

6.2 用“文档标识符”控制跨段可见性

设同一条 packed 序列里,每个 token 都带有一个“段/文档 id” dtd_tdt(整数标签),那么我们希望位置 ttt 的注意力只能关注满足:

  • j≤tj\le tjt(因果)
  • dj=dtd_j=d_tdj=dt(同段)

于是可以定义二维可见性矩阵:

Mt,j={1,(j≤t) ∧ (dj=dt)0,otherwise M_{t,j}= \begin{cases} 1, & (j\le t)\ \land\ (d_j=d_t) \\ 0, & \text{otherwise} \end{cases} Mt,j={1,0,(jt)  (dj=dt)otherwise

然后再扩展到 batch 与 heads:

Mt,j(b,h)=Mt,j(b) M^{(b,h)}_{t,j}=M^{(b)}_{t,j} Mt,j(b,h)=Mt,j(b)

其中 bbb 表示 batch 内第 bbb 条序列,hhh 表示 head。也就是说对同一条序列,所有 heads 共享相同的可见性规则。

6.3 一个具体例子:两段拼接但互不可见

设一个 packed 序列由两段组成:

  • 段 A:位置 1∼10001\sim 100011000dt=1d_t=1dt=1
  • 段 B:位置 1001∼20001001\sim 200010012000dt=2d_t=2dt=2

那么对于段 B 的 token(例如 t=1500t=1500t=1500):

  • 因果允许看 j≤1500j\le 1500j1500
  • 但边界约束要求 dj=2d_j=2dj=2

于是它只能看 1001∼15001001\sim 150010011500,不能看 1∼10001\sim 100011000

这就避免了“跨样本泄漏”,让 packing 在训练上等价于“同 batch 多样本并行”,但计算上仍是“一个长序列”。


7. 小结:预训练与微调的两类高效路线

  1. 预训练阶段:packing 直接拼接,提高 token 利用率

    • 用分隔符标记边界
    • 对超长文本采用“长短分桶 + 不同 pack 上限”缓解吞吐波动与阶段现象
    • 必要时对短文本再进行更高层次的 pack 以减少碎片
  2. 微调阶段:多轮对话不一定要 packing,关键是一次前向、response-only loss

    • 因果 mask 让一次前向覆盖所有 token 的预测
    • 用选择向量 mtm_tmt 只在 assistant token 上累加损失
    • 若 batch 内存在拼接或多样本并行,使用基于段 id 的 4D mask 阻断跨段注意力

通过这些技巧,训练过程可以显著减少 padding 的无效计算、提升吞吐稳定性,并在不改变模型结构的前提下获得更高的算力利用率。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐