第 06 章:分块因果 DiT 先验 —— 在隐空间里做 Flow Matching

论文Continuous Latent Diffusion Language Model
项目地址ByteDance-Seed/Cola-DLM
源码modeling_cola_dit.py

核心困惑:DiT 是怎么学习隐空间先验 p ψ ( z 0 ) p_\psi(z_0) pψ(z0) 的?分块因果注意力怎么实现?CFG 怎么工作?


一、DiT 在图像领域的成功

DiT(Diffusion Transformer)由 Scalable Diffusion Models with Transformers(Peebles & Xie, 2023)提出,用 Transformer 替代 UNet 作为扩散模型的骨干网络。核心思想:

  • 把图像切成 patch → 线性投影 → Transformer blocks → 线性投影 → unpatch
  • 用 AdaLN(Adaptive Layer Norm)注入时间步信息

Cola DLM 把这个思路从 2D 图像迁移到 1D 文本隐序列。


二、模型架构

2.1 整体结构

代码位置:modeling_cola_dit.py:536-690

输入: txt (L_q_total, in_channels)
        │
        ▼
┌──────────────┐
│  PatchIn1D   │  patchify + 线性投影
└──────┬───────┘
       │ (L_q_total/patch_size, txt_dim)
       ▼
┌──────────────┐
│ TimestepEmb  │  sinusoidal → MLP → emb_dim
└──────┬───────┘
       │
       ▼
┌──────────────┐
│ DiTBlock ×24 │  AdaLN + Attention + FFN
│ (分块因果)    │  per-sample KV cache
└──────┬───────┘
       │
       ▼
┌──────────────┐
│ PatchOut1D   │  线性投影 + unpatch
└──────┬───────┘
       │
       ▼
输出: txt_sample (L_q_total, out_channels)

2.2 配置参数

参数 默认值 含义
txt_in_channels 16 输入隐空间维度(= VAE 的 latent_dim
txt_out_channels 16 输出隐空间维度
txt_dim 2048 Transformer 隐藏维度
emb_dim 2048 AdaLN 条件维度
heads 16 注意力头数
head_dim 128 每头维度
expand_ratio 4 FFN 扩展比
num_layers 24 Transformer 层数
patch_size 1 patchify 因子
rope_dim 96 RoPE 作用的通道数(< head_dim=128)
block_size 4 分块大小

总参数量:约 1.8B(24 层 × 16 头 × 128 head_dim = 2048 hidden dim)。


三、关键组件

3.1 PatchIn1D / PatchOut1D

代码位置:modeling_cola_dit.py:166-205

class PatchIn1D(nn.Module):
    def __init__(self, in_channels, patch_size, dim):
        self.proj = nn.Linear(in_channels * patch_size, dim)

    def forward(self, txt, txt_shape):
        txt_shape_before_patchify = txt_shape
        if self.patch_size != 1:  # patch_size=1 时跳过,直接投影
            batch_list = _unflatten(txt, txt_shape)
            for i in range(len(batch_list)):
                batch_list[i] = rearrange(batch_list[i], "(T t) c -> T (t c)", t=self.patch_size)
            txt, txt_shape = _flatten(batch_list)
        txt = self.proj(txt)
        return txt, txt_shape, txt_shape_before_patchify

默认 patch_size=1,所以 rearrange 被完全跳过,只有线性投影生效。

3.2 TimestepEmbedding(AdaLN 条件化)

代码位置:modeling_cola_dit.py:135-158

class TimestepEmbedding(nn.Module):
    def __init__(self, sinusoidal_dim, hidden_dim, output_dim):
        self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
        self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, timestep, device, dtype):
        emb = _get_sinusoidal_embedding(timestep, self.sinusoidal_dim)
        emb = self.act(self.proj_in(emb))
        emb = self.act(self.proj_hid(emb))
        emb = self.proj_out(emb)
        return emb

时间步 t t t 通过 sinusoidal embedding + MLP 注入到每个 Transformer block 的 AdaLN 中。

3.3 AdaLN(Adaptive Layer Norm)

代码位置:modeling_cola_dit.py:299-336

class AdaLN(nn.Module):
    def forward(self, hid, emb, layer, mode, norm_layer=None, residual=None, **kwargs):
        emb = getattr(self, f"{layer}_{mode}")(emb)  # 线性投影
        if mode == "in":
            shift, scale = emb.chunk(2, dim=-1)
            return norm_layer(hid) * (1 + scale) + shift  # scale + shift
        if mode == "out":
            return hid * emb + residual  # gate + residual

每个 block 有两处 AdaLN:

  • mode="in":在 attention/FFN 之前,做 scale + shift
  • mode="out":在 attention/FFN 之后,做 gate + residual

3.4 MLP

代码位置:modeling_cola_dit.py:344-352

class MLP(nn.Module):
    def __init__(self, dim, expand_ratio):
        self.proj_in = nn.Linear(dim, dim * expand_ratio)
        self.act = nn.GELU("tanh")  # 注意:不是 SwiGLU!
        self.proj_out = nn.Linear(dim * expand_ratio, dim)

注意:DiT 用 GELU tanh,而 VAE 用 SwiGLU。这是一个设计选择差异。


四、分块因果注意力

4.1 ColaDiTAttention

代码位置:modeling_cola_dit.py:360-463

class ColaDiTAttention(nn.Module):
    def __init__(self, txt_dim, heads, head_dim, qk_bias, qk_norm_eps, rope_dim):
        self.proj_qkv = nn.Linear(txt_dim, inner_dim * 3)
        self.norm_q = nn.LayerNorm(head_dim)  # QK-norm
        self.norm_k = nn.LayerNorm(head_dim)
        self.rope = TextRotaryEmbedding(dim=rope_dim)  # rope_dim=96 < head_dim=128

RoPE 只作用于部分通道rope_dim=96,而 head_dim=128。这意味着 128 个通道中有 96 个有位置编码,32 个没有。

4.2 KV Cache 管理

代码位置:modeling_cola_dit.py:420-442

# per-sample KV cache
self._k_cache: Optional[list[torch.Tensor]] = None
self._v_cache: Optional[list[torch.Tensor]] = None

# forward 中的逻辑:
if update_kv:  # 提交新 block 到 cache
    self._k_cache = [torch.cat([c, n], dim=0) for c, n in zip(self._k_cache, new_ks)]
    full_k = torch.cat(self._k_cache, dim=0)
elif use_kv_cache and self._k_cache is not None:  # 读 cache
    full_k = torch.cat([torch.cat([c, n], dim=0) for c, n in zip(self._k_cache, new_ks)], dim=0)
else:  # 无 cache
    full_k = txt_k

4.3 注意力计算

代码位置:modeling_cola_dit.py:381-397

def slow_attn(self, query, key, value, attn_mask=None):
    d_head = query.shape[-1]
    device_type = "cuda" if query.is_cuda else query.device.type
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        scale = 1.0 / (d_head ** 0.5)
        attn = query.mul(scale) @ key.transpose(-2, -1)  # 显式构造完整矩阵
        if attn_mask is not None:
            attn = attn + attn_mask.to(attn.dtype)
        attn_weight = attn.softmax(dim=-1)
        attn_out = attn_weight @ value
    return attn_out

注意:没有使用 Flash Attention,而是显式构造完整的 ( L q , L k ) (L_q, L_k) (Lq,Lk) 注意力矩阵。这是当前实现的一个重要限制。


五、CFG(Classifier-Free Guidance)

5.1 原理

CFG 是扩散模型的标准技巧:同时做条件生成和无条件生成,然后按比例混合。

v ^ = v uncond + s ⋅ ( v cond − v uncond ) \hat{v} = v_{\text{uncond}} + s \cdot (v_{\text{cond}} - v_{\text{uncond}}) v^=vuncond+s(vcondvuncond)

其中 s s s 是 guidance scale(默认 7.0)。

5.2 代码实现

代码位置:inference.py:621-648

# 条件前向:用 KV cache(能看到 prompt 和历史 block)
drift_cond = dit(txt=txt_bf16, txt_shape=txt_shape_cum,
                 txt_q_shape=txt_q_shape, timestep=ts_bf16,
                 update_kv=False, use_kv_cache=True).txt_sample

# 无条件前向:不用 cache(只能看到当前 block)
drift_uncond = dit(txt=txt_bf16, txt_shape=txt_q_shape,
                   txt_q_shape=txt_q_shape, timestep=ts_bf16,
                   update_kv=False, use_kv_cache=False).txt_sample

# CFG 融合
s = cfg_scale_first_block if step == 0 else guidance_scale
drift = s * (drift_cond - drift_uncond) + drift_uncond

5.3 短 prompt 的 CFG 退化

代码位置:inference.py:531-554

当 prompt 短于 block_size 时,第一个生成 block 的前缀 KV cache 为空,条件和无条件前向数学上相同。此时 CFG 会放大 bf16 噪声:

cfg_scale_first_block = torch.tensor(
    [guidance_scale if pl > 0 else 1.0 for pl in prefix_lens],
    device=device, dtype=torch.bfloat16,
).repeat_interleave(block_size).unsqueeze(-1)

空 prefix 的样本自动将 guidance_scale 降为 1.0。


六、ColaDiTBlock

代码位置:modeling_cola_dit.py:471-519

每个 block 的前向流程:

def forward(self, txt, *, txt_shape, txt_q_shape, emb, ...):
    # 1. AdaLN + Attention
    txt_msa = self.ada(txt, emb=emb, layer="msa", mode="in", norm_layer=self.msa_norm)
    txt_msa = self.msa(txt_msa, txt_shape=txt_shape, txt_q_shape=txt_q_shape, ...)
    txt = self.ada(txt_msa, emb=emb, layer="msa", mode="out", residual=txt)

    # 2. AdaLN + FFN
    txt_mlp = self.ada(txt, emb=emb, layer="mlp", mode="in", norm_layer=self.mlp_norm)
    txt_mlp = self.mlp(txt_mlp)
    txt = self.ada(txt_mlp, emb=emb, layer="mlp", mode="out", residual=txt)
    return txt

七、Stage 2 联合训练目标

论文式 2.2.3 给出 Stage 2 的损失:

L stage2 = λ VAE ⋅ L VAAE + λ FM ⋅ L FM + λ ref ⋅ E [ KL ( q ϕ ( z 0 ∣ x ) ∥ q ϕ ref ( z 0 ∣ x ) ) ] \mathcal{L}_{\text{stage2}} = \lambda_{\text{VAE}} \cdot \mathcal{L}_{\text{VAAE}} + \lambda_{\text{FM}} \cdot \mathcal{L}_{\text{FM}} + \lambda_{\text{ref}} \cdot \mathbb{E}[\text{KL}(q_\phi(z_0|x) \| q_{\phi_{\text{ref}}}(z_0|x))] Lstage2=λVAELVAAE+λFMLFM+λrefE[KL(qϕ(z0x)qϕref(z0x))]

作用
L VAE \mathcal{L}_{\text{VAE}} LVAE 保持 VAE 的重构能力
L FM \mathcal{L}_{\text{FM}} LFM 训练 DiT 先验(Flow Matching)
reference KL 防止 VAE 的隐空间漂移(对齐冻结的参考编码器)

八、面试追问清单

基础(⭐)

  1. DiT 和 UNet 作为扩散模型骨干的区别是什么?
  2. AdaLN 是如何注入时间步信息的?
  3. CFG 的 guidance scale 对生成质量有什么影响?

进阶(⭐⭐)

  1. 为什么 DiT 的 RoPE 只作用于 96/128 个通道?
  2. per-sample KV cache 和标准 KV cache 有什么区别?
  3. Stage 2 的 reference-encoder KL 正则为什么能防止隐空间漂移?

专家(⭐⭐⭐)

  1. DiT 用 GELU tanh 而 VAE 用 SwiGLU,这个差异会影响什么?
  2. rope_theta=10000(DiT)vs rope_theta=500000(VAE)的位置编码频率差异意味着什么?
  3. 如果把 block_size 从 4 改为 16,DiT 的注意力模式会怎么变化?

九、下期预告

下一章我们将逐行拆解推理流水线——从 prompt 输入到文本输出的完整过程,包括 tokenization、前缀编码、分块先验传输、条件解码和采样策略。


系列导航

第 01 章 · 第 02 章 · 第 03 章 · 第 04 章 · 第 05 章

第 06 章:分块因果 DiT 先验 ← 你在这里

第 07 章 · 第 08 章 · 第 09 章 · 第 10 章


作者Yunzenn

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐