第 04 章:Cola DLM 架构全景 —— 三层解耦的设计哲学

论文Continuous Latent Diffusion Language Model

项目地址ByteDance-Seed/Cola-DLM

核心困惑:Cola DLM 为什么要分成 VAE + DiT + Decoder 三个组件?能不能端到端训练一个模型?


一、用类比建立直觉

回忆我们的"跨语言写作工作室"类比:

  • VAE Encoder(翻译官 q ϕ q_\phi qϕ:把人类语言(离散 token)翻译成"世界语"(连续隐向量)
  • DiT 先验(世界语作家 p ψ p_\psi pψ:在世界语空间里创作——先写大纲(第一个 block),再逐段填充
  • VAE Decoder(翻译官的逆 p θ p_\theta pθ:把世界语作品翻译回人类语言

为什么要这么麻烦?为什么不直接用一个模型端到端完成?

答案是:解耦让每个组件专注于自己最擅长的事


二、层次化联合分布

Cola DLM 的数学框架是一个层次化隐变量模型(论文式 2.1.1):

p ( x , z 0 ) = p θ ( x ∣ z 0 ) ⋅ p ψ ( z 0 ) p(x, z_0) = p_\theta(x \mid z_0) \cdot p_\psi(z_0) p(x,z0)=pθ(xz0)pψ(z0)

边缘化隐变量得到文本的概率:

p ( x ) = ∫ p θ ( x ∣ z 0 ) ⋅ p ψ ( z 0 )   d z 0 p(x) = \int p_\theta(x \mid z_0) \cdot p_\psi(z_0) \, dz_0 p(x)=pθ(xz0)pψ(z0)dz0

三个组件的职责:

组件 符号 职责 是否生成模型的一部分
VAE Encoder q ϕ ( z 0 ∣ x ) q_\phi(z_0 \mid x) qϕ(z0x) 把文本映射到隐空间 (只用于推断/前缀编码)
DiT 先验 p ψ ( z 0 ) p_\psi(z_0) pψ(z0) 在隐空间生成有意义的向量 (核心生成组件)
VAE Decoder p θ ( x ∣ z 0 ) p_\theta(x \mid z_0) pθ(xz0) 把隐向量还原为文本

关键理解:VAE encoder q ϕ q_\phi qϕ 不属于生成模型本身。它只在两个场景下使用:

  1. 训练时:作为变分推断的 inference model
  2. 推理时:把 prompt 编码为隐变量(前缀编码)

三、为什么不能端到端?

3.1 端到端的诱惑

如果把三个组件合并成一个端到端的模型,理论上可以:

  • 减少组件间的接口开销
  • 避免 VAE 的信息瓶颈
  • 训练更简单(一个损失函数)

3.2 解耦的好处

但解耦带来了几个关键优势:

优势 1:各组件独立训练/评估

VAE 可以单独训练(Stage 1),不需要等 DiT 准备好。DiT 可以在固定的隐空间上训练(Stage 2),不需要关心文本的离散性。

Stage 1: 训练 VAE(重构 + KL + BERT mask)
Stage 2: 训练 DiT(Flow Matching)+ 微调 VAE(+ reference KL)

优势 2:隐空间可操作

连续隐空间可以做很多离散 token 空间做不到的操作:

  • 插值:两个隐向量之间可以线性插值
  • 组合:多个隐向量可以加权求和
  • 条件生成:通过修改隐向量的某些维度控制生成

优势 3:先验可替换

VAE 的隐空间定义好之后,可以换不同的先验模型:

  • 用 DiT(当前做法)
  • 用 GAN
  • 用 normalizing flow
  • 甚至用 AR 模型

优势 4:多模态扩展

不同模态可以有各自的 VAE(文本 VAE、图像 VAE、音频 VAE),但共享同一个先验模型。这就是论文 Discussion 部分展示的统一文本-图像实验的思路。


四、分块因果分解

4.1 序列的分块

隐变量 z 0 z_0 z0 沿序列维分解为 B B B 个 block:

z 0 = ( z 0 ( 1 ) , z 0 ( 2 ) , … , z 0 ( B ) ) z_0 = (z_0^{(1)}, z_0^{(2)}, \ldots, z_0^{(B)}) z0=(z0(1),z0(2),,z0(B))

每个 block 的大小由 block_size 参数控制(默认值 4,见 configuration_cola_dit.py:69)。

4.2 先验的因式分解

先验分布按 block 因式分解(论文式 2.1.4):

p ψ ( z 0 ) = p ψ ( z 0 ( 1 ) ) ⋅ ∏ b ≥ 2 p ψ ( z 0 ( b ) ∣ z 0 ( < b ) ) p_\psi(z_0) = p_\psi(z_0^{(1)}) \cdot \prod_{b \geq 2} p_\psi(z_0^{(b)} \mid z_0^{(<b)}) pψ(z0)=pψ(z0(1))b2pψ(z0(b)z0(<b))

  • 第一个 block z 0 ( 1 ) z_0^{(1)} z0(1):无条件生成
  • 后续 block z 0 ( b ) z_0^{(b)} z0(b):条件生成,条件是前面所有 block z 0 ( < b ) z_0^{(<b)} z0(<b)

4.3 注意力的可见性约束

这个因式分解通过注意力 mask 来实现。在 attention_utils.py:88-158create_na_block_causal_mask 中:

q_block = q_local.unsqueeze(1) // block_size  # Q 属于哪个 block
k_block = k_local.unsqueeze(0) // block_size  # K 属于哪个 block
same_sample = q_sample.unsqueeze(1) == k_sample.unsqueeze(0)  # 同一样本内
block_causal = q_block >= k_block  # block 间因果
allowed = same_sample & block_causal  # 最终可见性

三个约束:

  1. 同一样本内:不同样本之间完全隔离(same_sample
  2. block 间因果:Q 只能看到 K 所在 block ≤ Q 所在 block(q_block >= k_block
  3. block 内双向:同一个 block 内的所有位置互相可见(这是 q_block >= k_block 的自然结果)

4.4 与标准 causal attention 的区别

维度 标准 causal attention 分块因果注意力
粒度 逐 token 逐 block(block 内双向)
第 5 个 token 能看到第 3 个? 是(如果在同一序列) 是(如果在同一 block)
第 5 个 token 能看到第 6 个? 是(如果在同一 block)
全局视野 block 内是,block 间否

五、代码中的三层架构

5.1 模块划分

cola_dlm/
├── modeling_cola_vae.py    # Text VAE(encoder + decoder)
├── modeling_cola_dit.py    # DiT 先验
├── attention_utils.py      # 分块因果 mask
├── inference.py            # 推理流水线(串联三层)
├── configuration_cola_vae.py  # VAE 配置
└── configuration_cola_dit.py  # DiT 配置

5.2 数据流

输入 token: "The capital of France is"
        │
        ▼
┌─────────────────────┐
│  VAE Encoder (q_φ)  │  modeling_cola_vae.py:580-640
│  token → 隐向量      │  z_pre = encode(input_ids)
└─────────┬───────────┘
          │ z_pre: (n_i, latent_dim)
          ▼
┌─────────────────────┐
│  DiT 先验 (p_ψ)     │  modeling_cola_dit.py:594-685
│  噪声 → 干净隐向量    │  z_0 = Phi_psi(epsilon; z_pre)
│  分块因果 + CFG      │  block-by-block prior transport
└─────────┬───────────┘
          │ z_0: (B*block_size, latent_dim)
          ▼
┌─────────────────────┐
│  VAE Decoder (p_θ)  │  modeling_cola_vae.py:646-691
│  隐向量 → token logits│  logits = decode(z_0)
└─────────┬───────────┘
          │ logits: (B*block_size*patch_size, vocab)
          ▼
      采样 → "Paris"

5.3 代码入口

inference.py:285-738generate_task_repaint_inference 中,三层的调用顺序是:

# Step 2: VAE encode(前缀编码)
enc = vae.encode(input_ids_list)
latents_list = [((lat - shift) * scale).float() for lat in enc.latents_list]

# Step 5: 分块先验传输(DiT)
for t_curr, t_next in zip(timesteps[:-1], timesteps[1:]):
    drift_cond = dit(txt=txt_bf16, ..., use_kv_cache=True).txt_sample   # 条件
    drift_uncond = dit(txt=txt_bf16, ..., use_kv_cache=False).txt_sample # 无条件
    drift = s * (drift_cond - drift_uncond) + drift_uncond               # CFG
    txt = txt - drift * dt                                               # Euler 更新

# Step 5 续: VAE decode(条件解码)
decoded = vae.decode(z=txt, ..., update_kv=True)

六、与 Stable Diffusion 的类比

Cola DLM 和 Stable Diffusion 的架构高度同构:

组件 Stable Diffusion(图像) Cola DLM(文本)
编码器 图像 VAE encoder Text VAE encoder q ϕ q_\phi qϕ
隐空间 ( 64 , 64 , 4 ) (64, 64, 4) (64,64,4) 的图像隐向量 ( n i , 16 ) (n_i, 16) (ni,16) 的文本隐向量
先验模型 UNet / DiT 分块因果 DiT
解码器 图像 VAE decoder Text VAE decoder p θ p_\theta pθ
条件机制 CLIP text embedding 前缀隐向量 z p r e z^{pre} zpre
引导机制 Classifier-Free Guidance Classifier-Free Guidance

核心思想完全一样:在低维隐空间做扩散,而不是在高维原始空间

区别在于:

  • 图像的隐空间是 2D 空间(有空间结构)
  • 文本的隐空间是 1D 序列(有时间/因果结构)

这个区别导致 Cola DLM 需要"分块因果"的设计,而 Stable Diffusion 的 UNet/DiT 使用标准的 2D 注意力。


七、2026 年的批判性视角

7.1 解耦是必要的吗?

从理论上讲,端到端训练可能更好——组件间的接口不会丢失信息。但实践中,解耦有几个实际好处:

  1. 训练稳定性:VAE 先稳定,DiT 再训练,避免端到端的模式崩塌
  2. 调试方便:可以单独评估 VAE 的重构质量和 DiT 的先验质量
  3. 复用性:VAE 的隐空间可以用于其他任务(分类、检索等)

7.2 分块因果是最佳选择吗?

block_size 的选择是一个 trade-off:

  • block_size=1:退化为逐 token AR,丧失全局视野
  • block_size=序列长度:退化为全双向,丧失自回归能力
  • block_size=4(当前选择):在全局视野和自回归能力之间折中

最优 block_size 可能取决于任务:短文本生成可能需要更大的 block_size,长文本生成可能需要更小的 block_size。当前的固定 block_size 设计缺乏灵活性。

7.3 VAE 的信息瓶颈

VAE 把一个 ( L , v o c a b _ s i z e ) (L, vocab\_size) (L,vocab_size) 的 one-hot 序列压缩为 ( n i , 16 ) (n_i, 16) (ni,16) 的连续向量( n i = L / p a t c h _ s i z e n_i = L / patch\_size ni=L/patch_size)。这个压缩比非常高,必然丢失信息。论文中 l a t e n t _ d i m = 16 latent\_dim = 16 latent_dim=16 是一个保守的选择,增大 latent_dim 可能会提升质量,但也会增加 DiT 的计算开销。


八、面试追问清单

基础(⭐)

  1. Cola DLM 的三个组件分别是什么?各自负责什么?
  2. 为什么 VAE encoder 不是生成模型的一部分?
  3. 什么是分块因果注意力?

进阶(⭐⭐)

  1. Cola DLM 和 Stable Diffusion 的架构有什么异同?
  2. block_size 的大小如何影响生成质量?
  3. 为什么 Cola DLM 选择解耦而不是端到端训练?

专家(⭐⭐⭐)

  1. 如果把 VAE encoder 和 DiT 合并成一个端到端模型,需要什么技术来保证训练稳定性?
  2. 分块因果分解和 AR 的链式分解在数学上有什么关系?
  3. VAE 的 latent_dim=16 是否太小?增大 latent_dim 的 trade-off 是什么?

九、下期预告

下一章我们将深入 Text VAE 的内部——看看它是如何把离散 token 变成连续向量的,隐空间的几何结构是什么样的,以及 Stage 1 的训练目标如何保证隐空间的质量。


系列导航

第 01 章:语言生成的三次范式之争

第 02 章:扩散模型 10 分钟速通

第 03 章:离散扩散的困境

第 04 章:Cola DLM 架构全景 ← 你在这里

第 05 章:Text VAE 深度解剖

第 06 章:分块因果 DiT 先验

第 07 章:推理流水线逐行拆解

第 08 章:工程实现评析

第 09 章:评测复现与结果深度分析

第 10 章:从文本到多模态


作者Yunzenn

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐