Phase 2 Stage 1:Teacher Forcing 自回归初始化

这一阶段解决什么问题(Why)

Phase 1 训练出了双向(非因果)模型,但最终推理目标是因果自回归——模型按块生成视频,每次只能看到已生成的历史帧,不能看到未来帧。双向模型的权重无法直接用于因果推理:注意力掩码不同,推理行为完全不同。

Stage 1 的目标是在 因果(Causal)架构上初始化一个"够用"的基础模型,作为后续蒸馏 Stage 的起点。"够用"的意思是:给定完整的历史上下文(ground truth),因果模型能够做出合理的去噪预测,质量接近 Phase 1 双向模型。

关键约束是资源效率:从头训练因果模型代价极高,且容易遗忘 Phase 1 习得的视觉先验。Stage 1 的解法是从 Phase 1 的双向模型 热启动(warm start),用 Teacher Forcing 策略快速收敛。

算法直觉(TL;DR)

训练时,把已知的干净历史帧(或带微小噪声的历史帧)直接作为上下文注入因果模型。模型在"被告知正确历史"的条件下学习预测当前块的去噪速度场。这消除了因果推理中最难的部分(历史帧的质量不确定性),让模型能集中学习"给定好的历史,如何去噪当前块"。

代价是 Training/Inference Mismatch(Exposure Bias):训练时历史是 ground truth,推理时历史是模型自己的生成结果,质量存在差异。Noise Augmentation(给历史帧加微小噪声)是缓解这一 mismatch 的工程手段。

算法详解

模型:因果 CausalWanModelis_causal=True),Block-diagonal 注意力掩码——第 i i i 块内全局注意力,但不能 attend 到第 i + 1 i+1 i+1 块及以后的内容。

Timestep 采样uniform_timestep=False,不同块可以独立采样各自的时间步 t i t_i ti,每块对应各自的噪声标准差 σ t i \sigma_{t_i} σti

Teacher Forcing 上下文clean_x(干净历史帧 latent)注入 CausalWanModelaug_t(noise augmentation 的噪声水平)控制向历史帧加多少噪声。

训练目标:与 Phase 1 相同的 Flow Matching 损失,仅作用于当前去噪的目标块:

L Stage1 = E x 0 , ε , t  ⁣ [ w ( σ t )   ∥   v θ ( x t ( cur ) , t , c ,   x aug ( hist ) ) − ( ε − x 0 ( cur ) ) ∥ 2 ] \mathcal{L}_\text{Stage1} = \mathbb{E}_{x_0, \varepsilon, t}\!\left[w(\sigma_t)\,\bigl\|\,v_\theta(x_t^{(\text{cur})}, t, c,\, x^{(\text{hist})}_{\text{aug}}) - (\varepsilon - x_0^{(\text{cur})})\bigr\|^2\right] LStage1=Ex0,ε,t[w(σt) vθ(xt(cur),t,c,xaug(hist))(εx0(cur)) 2]

其中 x aug ( hist ) = ( 1 − σ τ ) x 0 ( hist ) + σ τ ε ′ x^{(\text{hist})}_{\text{aug}} = (1-\sigma_{\tau})x^{(\text{hist})}_0 + \sigma_\tau \varepsilon' xaug(hist)=(1στ)x0(hist)+στε 是加了噪声的历史帧, τ \tau τ 为 aug_t 对应的低噪声水平。

数学推导:因果掩码、Teacher Forcing 与 Noise Augmentation

因果 Block-Diagonal 注意力掩码

将视频 N N N 帧按块大小 B B B 分成 K = N / B K = N/B K=N/B 个块。CausalWanModel 使用 FlexAttention 实现如下掩码:

M i j = { 1 如果  ⌊ i / B ⌋ ≥ ⌊ j / B ⌋ 0 否则(未来块) M_{ij} = \begin{cases} 1 & \text{如果 } \lfloor i/B \rfloor \geq \lfloor j/B \rfloor \\ 0 & \text{否则(未来块)} \end{cases} Mij={10如果 i/Bj/B否则(未来块)

即 token i i i 可以 attend 到所有不晚于它所在块的 token。同一块内所有 token 两两可见(block-internal full attention)。

这与 GPT 风格的逐 token 因果掩码不同——这里的基本单元是"块"而非"帧",一个块内的所有帧同时去噪。这是因为视频帧在 latent 空间的局部时序相关性强,块内全局注意力有助于块内时序一致性。

Teacher Forcing 与 Autoregressive Inference 的差异

设历史块为 x 0 ( 1 ) , … , x 0 ( k − 1 ) x_0^{(1)}, \ldots, x_0^{(k-1)} x0(1),,x0(k1),当前块为 x 0 ( k ) x_0^{(k)} x0(k)

  • 训练(Teacher Forcing):模型输入历史为 ground truth x 0 ( 1 ) , … , x 0 ( k − 1 ) x_0^{(1)}, \ldots, x_0^{(k-1)} x0(1),,x0(k1)
  • 推理(Autoregressive):模型输入历史为模型自身已生成的 x ^ 0 ( 1 ) , … , x ^ 0 ( k − 1 ) \hat{x}_0^{(1)}, \ldots, \hat{x}_0^{(k-1)} x^0(1),,x^0(k1)

两者之间的差距称为 Exposure Bias(暴露偏差)。如果训练时始终用 ground truth,模型在推理时遇到自己生成的(不完美的)历史就会"不适应",误差逐步累积,长视频生成质量下降。

Noise Augmentation 缓解 Exposure Bias

向历史帧注入小量噪声:

x aug ( hist ) = 1 − σ τ 2   x 0 ( hist ) + σ τ   ε ′ , σ τ ≪ 1 x_{\text{aug}}^{(\text{hist})} = \sqrt{1-\sigma_\tau^2}\,x_0^{(\text{hist})} + \sigma_\tau\,\varepsilon', \quad \sigma_\tau \ll 1 xaug(hist)=1στ2 x0(hist)+στε,στ1

直觉:给 ground truth 加一点噪声 ≈ 模拟"稍微不完美的历史帧",让模型对历史帧的轻微瑕疵具备鲁棒性,从而部分缓解 Training/Inference Mismatch。

σ τ \sigma_\tau στ 是一个超参数(通常很小,如 0.02 0.02 0.02 0.1 0.1 0.1),过大会破坏历史信息,过小则缓解效果有限。

为什么 uniform_timestep=False

因果推理时,不同块处于完全不同的去噪阶段是可能出现的情形(例如,在多步采样时同时处理多个块)。允许每块独立采样 t i t_i ti 使训练分布与这种推理场景更匹配,也给 Stage 2/3 的蒸馏提供更灵活的时序覆盖。

代码走读

wan_trainer/ar_diffusion.py : Trainer.train_one_step()
  │
  ├─ 准备数据
  │    ├─ clean_latent = vae.encode(video_frames)   # (B, C, T, H, W) 全段干净 latent
  │    └─ image_latent = clean_latent[:, :, :4, ...]  # 第一块作为 initial_latent(可选)
  │
  └─ model.generator_loss(
         clean_latent=clean_latent,
         initial_latent=image_latent,
         ...)
       └─ CausalDiffusion.generator_loss()   [Wan21/model/diffusion.py : 49-130]
            │
            ├─ _get_timestep(uniform_timestep=False)
            │    └─ 为每个块独立采样 t_i,再 broadcast 到块内所有帧
            │
            ├─ noise_augmentation(若 aug_t_cfg 已配置)
            │    └─ clean_latent_aug = scheduler.add_noise(
            │            clean_latent, noise, aug_t)
            │         └─ aug_t 采样自低噪声区间
            │
            ├─ scheduler.add_noise(x0_target, noise, t)
            │    └─ noisy_latents = (1-σ_t)·x0 + σ_t·ε   # 对当前块加噪
            │
            ├─ generator(
            │    noisy_latents,
            │    conditional_dict,
            │    timestep,
            │    clean_x=clean_latent_aug,   # ← Teacher Forcing 注入
            │    aug_t=timestep_clean_aug)   # ← 历史帧的噪声水平
            │    └─ CausalWanModel 前向,block-diagonal 掩码生效
            │
            └─ flow_matching_loss(v_pred, target, weight)
                 └─ target = ε - x0_target(仅计算目标块的损失)

关键文件:

  • Wan21/model/diffusion.py:CausalDiffusion.generator_loss()(第 49–130 行):完整的 Teacher Forcing 训练逻辑
  • Wan21/wan/modules/causal_model.py:CausalWanSelfAttention:block-diagonal FlexAttention 掩码实现
  • Wan21/wan_utils/wan_wrapper.py:WanDiffusionWrapper:统一封装层,根据 is_causal 路由到 CausalWanModelWanModel

与 Phase 1 对比,最直观的代码级区别:

  • _get_timestep()uniform_timestep 参数从 True 变为 False
  • generator() 调用增加了 clean_x=clean_latent_augaug_t=timestep_clean_aug 参数

Wan21 / HY15 实现差异

维度 Wan21 HY15
训练入口 wan_trainer/ar_diffusion.py HY15/trainer/pipelines/ar_hunyuan_training_entry.py
因果 Backbone CausalWanModel(FlexAttention block-diagonal mask) CausalHunyuanVideoTransformer(自定义 causal mask)
Teacher Forcing 注入 clean_x 参数传入 CausalWanModel.forward() clean_x 参数传入 HunyuanVideoTransformer.forward()
注意力掩码实现 FlexAttention(PyTorch 原生,支持任意 mask 模式) torch.nn.functional.scaled_dot_product_attention + 手写 mask
块大小(chunk size) 配置文件中指定(默认 4 帧/块) 相同默认配置
Phase 1 → Stage 1 权重迁移 WanModelCausalWanModel(加载非因果权重,因果 mask 新增) HunyuanVideoTransformerCausalHunyuanVideoTransformer(同理)

两者的 Teacher Forcing 策略和 Noise Augmentation 逻辑在原理上完全一致,差异仅在 backbone 内部实现。

参考论文

  • [必读] DMD2 Section 4: Teacher Forcing AR Initialization (Yin et al., 2024):https://arxiv.org/abs/2405.14867
    Stage 1 的直接方法来源,描述了如何用双向 teacher 初始化因果模型。

  • [必读] Teacher Forcing (Williams & Zipser, 1989):A Learning Algorithm for Continually Running Fully Recurrent Neural Networks
    Teacher Forcing 这一训练技巧的原始文献,理解"用 ground truth 历史替代模型自身生成历史"的动机。

  • [延伸] Scheduled Sampling (Bengio et al., 2015):https://arxiv.org/abs/1506.03099
    缓解 Exposure Bias 的经典方法(逐渐减少 Teacher Forcing 比例),与 Noise Augmentation 是同一问题的不同解法。

  • [延伸] Flow Matching for Generative Modeling (Lipman et al., 2022):https://arxiv.org/abs/2210.02747
    Stage 1 沿用 Phase 1 的 Flow Matching 损失,需要理解此损失的含义。

  • [延伸] FlexAttention (PyTorch 2.x):理解 block-diagonal mask 的高效实现方式,相关内容见 Wan21/wan/modules/causal_model.py

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐