【无标题】
Phase 2 Stage 1:Teacher Forcing 自回归初始化
这一阶段解决什么问题(Why)
Phase 1 训练出了双向(非因果)模型,但最终推理目标是因果自回归——模型按块生成视频,每次只能看到已生成的历史帧,不能看到未来帧。双向模型的权重无法直接用于因果推理:注意力掩码不同,推理行为完全不同。
Stage 1 的目标是在 因果(Causal)架构上初始化一个"够用"的基础模型,作为后续蒸馏 Stage 的起点。"够用"的意思是:给定完整的历史上下文(ground truth),因果模型能够做出合理的去噪预测,质量接近 Phase 1 双向模型。
关键约束是资源效率:从头训练因果模型代价极高,且容易遗忘 Phase 1 习得的视觉先验。Stage 1 的解法是从 Phase 1 的双向模型 热启动(warm start),用 Teacher Forcing 策略快速收敛。
算法直觉(TL;DR)
训练时,把已知的干净历史帧(或带微小噪声的历史帧)直接作为上下文注入因果模型。模型在"被告知正确历史"的条件下学习预测当前块的去噪速度场。这消除了因果推理中最难的部分(历史帧的质量不确定性),让模型能集中学习"给定好的历史,如何去噪当前块"。
代价是 Training/Inference Mismatch(Exposure Bias):训练时历史是 ground truth,推理时历史是模型自己的生成结果,质量存在差异。Noise Augmentation(给历史帧加微小噪声)是缓解这一 mismatch 的工程手段。
算法详解
模型:因果 CausalWanModel(is_causal=True),Block-diagonal 注意力掩码——第 i i i 块内全局注意力,但不能 attend 到第 i + 1 i+1 i+1 块及以后的内容。
Timestep 采样:uniform_timestep=False,不同块可以独立采样各自的时间步 t i t_i ti,每块对应各自的噪声标准差 σ t i \sigma_{t_i} σti。
Teacher Forcing 上下文:clean_x(干净历史帧 latent)注入 CausalWanModel,aug_t(noise augmentation 的噪声水平)控制向历史帧加多少噪声。
训练目标:与 Phase 1 相同的 Flow Matching 损失,仅作用于当前去噪的目标块:
L Stage1 = E x 0 , ε , t [ w ( σ t ) ∥ v θ ( x t ( cur ) , t , c , x aug ( hist ) ) − ( ε − x 0 ( cur ) ) ∥ 2 ] \mathcal{L}_\text{Stage1} = \mathbb{E}_{x_0, \varepsilon, t}\!\left[w(\sigma_t)\,\bigl\|\,v_\theta(x_t^{(\text{cur})}, t, c,\, x^{(\text{hist})}_{\text{aug}}) - (\varepsilon - x_0^{(\text{cur})})\bigr\|^2\right] LStage1=Ex0,ε,t[w(σt) vθ(xt(cur),t,c,xaug(hist))−(ε−x0(cur)) 2]
其中 x aug ( hist ) = ( 1 − σ τ ) x 0 ( hist ) + σ τ ε ′ x^{(\text{hist})}_{\text{aug}} = (1-\sigma_{\tau})x^{(\text{hist})}_0 + \sigma_\tau \varepsilon' xaug(hist)=(1−στ)x0(hist)+στε′ 是加了噪声的历史帧, τ \tau τ 为 aug_t 对应的低噪声水平。
数学推导:因果掩码、Teacher Forcing 与 Noise Augmentation因果 Block-Diagonal 注意力掩码
将视频 N N N 帧按块大小 B B B 分成 K = N / B K = N/B K=N/B 个块。CausalWanModel 使用 FlexAttention 实现如下掩码:
M i j = { 1 如果 ⌊ i / B ⌋ ≥ ⌊ j / B ⌋ 0 否则(未来块) M_{ij} = \begin{cases} 1 & \text{如果 } \lfloor i/B \rfloor \geq \lfloor j/B \rfloor \\ 0 & \text{否则(未来块)} \end{cases} Mij={10如果 ⌊i/B⌋≥⌊j/B⌋否则(未来块)
即 token i i i 可以 attend 到所有不晚于它所在块的 token。同一块内所有 token 两两可见(block-internal full attention)。
这与 GPT 风格的逐 token 因果掩码不同——这里的基本单元是"块"而非"帧",一个块内的所有帧同时去噪。这是因为视频帧在 latent 空间的局部时序相关性强,块内全局注意力有助于块内时序一致性。
Teacher Forcing 与 Autoregressive Inference 的差异
设历史块为 x 0 ( 1 ) , … , x 0 ( k − 1 ) x_0^{(1)}, \ldots, x_0^{(k-1)} x0(1),…,x0(k−1),当前块为 x 0 ( k ) x_0^{(k)} x0(k)。
- 训练(Teacher Forcing):模型输入历史为 ground truth x 0 ( 1 ) , … , x 0 ( k − 1 ) x_0^{(1)}, \ldots, x_0^{(k-1)} x0(1),…,x0(k−1)
- 推理(Autoregressive):模型输入历史为模型自身已生成的 x ^ 0 ( 1 ) , … , x ^ 0 ( k − 1 ) \hat{x}_0^{(1)}, \ldots, \hat{x}_0^{(k-1)} x^0(1),…,x^0(k−1)
两者之间的差距称为 Exposure Bias(暴露偏差)。如果训练时始终用 ground truth,模型在推理时遇到自己生成的(不完美的)历史就会"不适应",误差逐步累积,长视频生成质量下降。
Noise Augmentation 缓解 Exposure Bias
向历史帧注入小量噪声:
x aug ( hist ) = 1 − σ τ 2 x 0 ( hist ) + σ τ ε ′ , σ τ ≪ 1 x_{\text{aug}}^{(\text{hist})} = \sqrt{1-\sigma_\tau^2}\,x_0^{(\text{hist})} + \sigma_\tau\,\varepsilon', \quad \sigma_\tau \ll 1 xaug(hist)=1−στ2x0(hist)+στε′,στ≪1
直觉:给 ground truth 加一点噪声 ≈ 模拟"稍微不完美的历史帧",让模型对历史帧的轻微瑕疵具备鲁棒性,从而部分缓解 Training/Inference Mismatch。
σ τ \sigma_\tau στ 是一个超参数(通常很小,如 0.02 0.02 0.02– 0.1 0.1 0.1),过大会破坏历史信息,过小则缓解效果有限。
为什么 uniform_timestep=False?
因果推理时,不同块处于完全不同的去噪阶段是可能出现的情形(例如,在多步采样时同时处理多个块)。允许每块独立采样 t i t_i ti 使训练分布与这种推理场景更匹配,也给 Stage 2/3 的蒸馏提供更灵活的时序覆盖。
代码走读
wan_trainer/ar_diffusion.py : Trainer.train_one_step()
│
├─ 准备数据
│ ├─ clean_latent = vae.encode(video_frames) # (B, C, T, H, W) 全段干净 latent
│ └─ image_latent = clean_latent[:, :, :4, ...] # 第一块作为 initial_latent(可选)
│
└─ model.generator_loss(
clean_latent=clean_latent,
initial_latent=image_latent,
...)
└─ CausalDiffusion.generator_loss() [Wan21/model/diffusion.py : 49-130]
│
├─ _get_timestep(uniform_timestep=False)
│ └─ 为每个块独立采样 t_i,再 broadcast 到块内所有帧
│
├─ noise_augmentation(若 aug_t_cfg 已配置)
│ └─ clean_latent_aug = scheduler.add_noise(
│ clean_latent, noise, aug_t)
│ └─ aug_t 采样自低噪声区间
│
├─ scheduler.add_noise(x0_target, noise, t)
│ └─ noisy_latents = (1-σ_t)·x0 + σ_t·ε # 对当前块加噪
│
├─ generator(
│ noisy_latents,
│ conditional_dict,
│ timestep,
│ clean_x=clean_latent_aug, # ← Teacher Forcing 注入
│ aug_t=timestep_clean_aug) # ← 历史帧的噪声水平
│ └─ CausalWanModel 前向,block-diagonal 掩码生效
│
└─ flow_matching_loss(v_pred, target, weight)
└─ target = ε - x0_target(仅计算目标块的损失)
关键文件:
Wan21/model/diffusion.py:CausalDiffusion.generator_loss()(第 49–130 行):完整的 Teacher Forcing 训练逻辑Wan21/wan/modules/causal_model.py:CausalWanSelfAttention:block-diagonal FlexAttention 掩码实现Wan21/wan_utils/wan_wrapper.py:WanDiffusionWrapper:统一封装层,根据is_causal路由到CausalWanModel或WanModel
与 Phase 1 对比,最直观的代码级区别:
_get_timestep()的uniform_timestep参数从True变为Falsegenerator()调用增加了clean_x=clean_latent_aug和aug_t=timestep_clean_aug参数
Wan21 / HY15 实现差异
| 维度 | Wan21 | HY15 |
|---|---|---|
| 训练入口 | wan_trainer/ar_diffusion.py |
HY15/trainer/pipelines/ar_hunyuan_training_entry.py |
| 因果 Backbone | CausalWanModel(FlexAttention block-diagonal mask) |
CausalHunyuanVideoTransformer(自定义 causal mask) |
| Teacher Forcing 注入 | clean_x 参数传入 CausalWanModel.forward() |
clean_x 参数传入 HunyuanVideoTransformer.forward() |
| 注意力掩码实现 | FlexAttention(PyTorch 原生,支持任意 mask 模式) | torch.nn.functional.scaled_dot_product_attention + 手写 mask |
| 块大小(chunk size) | 配置文件中指定(默认 4 帧/块) | 相同默认配置 |
| Phase 1 → Stage 1 权重迁移 | WanModel → CausalWanModel(加载非因果权重,因果 mask 新增) |
HunyuanVideoTransformer → CausalHunyuanVideoTransformer(同理) |
两者的 Teacher Forcing 策略和 Noise Augmentation 逻辑在原理上完全一致,差异仅在 backbone 内部实现。
参考论文
-
[必读] DMD2 Section 4: Teacher Forcing AR Initialization (Yin et al., 2024):https://arxiv.org/abs/2405.14867
Stage 1 的直接方法来源,描述了如何用双向 teacher 初始化因果模型。 -
[必读] Teacher Forcing (Williams & Zipser, 1989):A Learning Algorithm for Continually Running Fully Recurrent Neural Networks
Teacher Forcing 这一训练技巧的原始文献,理解"用 ground truth 历史替代模型自身生成历史"的动机。 -
[延伸] Scheduled Sampling (Bengio et al., 2015):https://arxiv.org/abs/1506.03099
缓解 Exposure Bias 的经典方法(逐渐减少 Teacher Forcing 比例),与 Noise Augmentation 是同一问题的不同解法。 -
[延伸] Flow Matching for Generative Modeling (Lipman et al., 2022):https://arxiv.org/abs/2210.02747
Stage 1 沿用 Phase 1 的 Flow Matching 损失,需要理解此损失的含义。 -
[延伸] FlexAttention (PyTorch 2.x):理解 block-diagonal mask 的高效实现方式,相关内容见
Wan21/wan/modules/causal_model.py。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)