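[Diffusion Models Series, Part 3] The Reverse Process: Learning to Denoise, Deriving the Training Objective, and the Simplified Loss Explained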

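Author: Tech blogger | Last updated: 2026-05-20 | Reading time: about 23 minutes
Series: Diffusion Models from Scratch to Practice (8 parts)
Environment: Python 3.12 + PyTorch 2.x
Tags: diffusion models, DDPM, reverse process, ELBO, denoising, training objective, parameterization, sampling algorithm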



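🔥 Goal of this part: the first two parts built the intuition and the mathematics of the forward process. This part answers the core questions: what exactly does the neural network learn, where does the loss function come from, and how is each sampling step computed? Starting from the evidence lower bound (ELBO), we derive the simplified MSE loss, then the sampling formula, and finally implement the full training and sampling code. After this part, the complete DDPM training and inference loop should hold no mysteries.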


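Series progress

Part | Topic | Status
Part 1 | What a diffusion model is: the intuition from adding noise to removing it | ✅ Published
Part 2 | Mathematical foundations: the forward process and Markov chains | ✅ Published
Part 3 (this part) | The reverse process: learning to denoise | This article
Part 4 | U-Net architecture: designing the denoising network | Coming soon
Part 5 | DDIM: accelerated sampling | Coming soon
Part 6 | Conditional generation: text guidance and CFG | Coming soon
Part 7 | Stable Diffusion: latent diffusion models | Coming soon
Part 8 | Practice and frontiers: LoRA, DreamBooth, FLUX | Coming soon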



1. Mathematical Definition of the Reverse Process

1.1 What We Want

The forward process $q$ is known (we define it by hand). What we actually need is the reverse process $p$: given a noisy image $x_t$, infer the slightly cleaner previous step $x_{t-1}$.

If we knew the full data distribution $q(x_0)$, the true reverse conditional would follow from Bayes' rule:

$$q(x_{t-1} \mid x_t) = \frac{q(x_t \mid x_{t-1})\, q(x_{t-1})}{q(x_t)}$$

Problem: the marginal $q(x_{t-1})$ (and likewise $q(x_t)$) requires integrating over every possible $x_0$, which is intractable.

Solution: approximate this distribution with a neural network $p_\theta$:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\, \mu_\theta(x_t, t),\, \Sigma_\theta(x_t, t)\right)$$

The network's job: learn the mean $\mu_\theta$ and the variance $\Sigma_\theta$ so that $p_\theta(x_{t-1} \mid x_t) \approx q(x_{t-1} \mid x_t)$.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

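# The core problem of the reverse process
print("The core problem of the reverse process:")
print()
print("  Given:  xₜ (the image after t steps of noise)")
print("  Wanted: xₜ₋₁ (the image with one step less noise)")
print()
print("  The true distribution q(xₜ₋₁|xₜ) requires integrating over all x₀, which is intractable")
print()
print("  → Approximate it with a neural network p_θ(xₜ₋₁|xₜ)")
print("  → p_θ is assumed Gaussian (mean and variance predicted by the network)")
print()
print("  Key observation (Sohl-Dickstein et al., 2015):")
print("  when βₜ is small enough, q(xₜ₋₁|xₜ) is approximately Gaussian")
print("  → so a Gaussian approximation of the reverse process is reasonable!")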

2. Why the Reverse Process Is Also Gaussian

2.1 Closed Form of the Posterior

Although $q(x_{t-1} \mid x_t)$ is hard to compute, the posterior conditioned on $x_0$ has a closed form:

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\, \tilde{\mu}_t(x_t, x_0),\, \tilde{\beta}_t I\right)$$

where:

$$\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\,x_t$$

$$\tilde{\beta}_t = \frac{(1 - \bar{\alpha}_{t-1})\,\beta_t}{1 - \bar{\alpha}_t}$$

def compute_posterior(
    x0:             torch.Tensor,
    xt:             torch.Tensor,
    t:              torch.Tensor,
    betas:          torch.Tensor,
    alphas_cumprod: torch.Tensor,
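) -> tuple:
    """
    Compute the mean and variance of the posterior q(xₜ₋₁ | xₜ, x₀).
    This is the "ground-truth target" of the reverse process: at every step
    the network is trained to match this distribution.
    """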
    B = x0.shape[0]

    def extract(arr, t, shape):
        vals = arr[t]
        return vals.reshape(B, *([1] * (len(shape) - 1)))

    # Gather the per-sample coefficients
    beta_t       = extract(betas, t, x0.shape)
    ab_t         = extract(alphas_cumprod, t, x0.shape)
    # ᾱₜ₋₁, with ᾱ₋₁ = 1 by convention at t = 0; built on the same device
    # and dtype as alphas_cumprod so the function also works on GPU
    ones = torch.ones(1, device=alphas_cumprod.device, dtype=alphas_cumprod.dtype)
    alphas_cumprod_prev = torch.cat([ones, alphas_cumprod[:-1]])
    ab_t_prev    = extract(alphas_cumprod_prev, t, x0.shape)
    alpha_t      = extract(1 - betas, t, x0.shape)

    # Posterior variance: β̃ₜ = (1 - ᾱₜ₋₁) / (1 - ᾱₜ) · βₜ
    posterior_variance = (1 - ab_t_prev) / (1 - ab_t) * beta_t
    posterior_variance = posterior_variance.clamp(min=1e-20)

    # The two coefficients of the posterior mean
    coef_x0 = ab_t_prev.sqrt() * beta_t / (1 - ab_t)
    coef_xt = alpha_t.sqrt() * (1 - ab_t_prev) / (1 - ab_t)

    # Posterior mean
    posterior_mean = coef_x0 * x0 + coef_xt * xt

    return posterior_mean, posterior_variance


# Check: the posterior is consistent with the forward process
torch.manual_seed(42)
T = 1000
betas          = torch.linspace(1e-4, 0.02, T)
alphas         = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, 0)
sqrt_ab        = alphas_cumprod.sqrt()
sqrt_1mab      = (1 - alphas_cumprod).sqrt()

x0 = torch.randn(4, 3, 8, 8)
t  = torch.tensor([500, 500, 500, 500])

# Forward noising
eps = torch.randn_like(x0)
xt  = sqrt_ab[t].reshape(-1,1,1,1) * x0 + sqrt_1mab[t].reshape(-1,1,1,1) * eps

# Posterior
mean, var = compute_posterior(x0, xt, t, betas, alphas_cumprod)

print("Checking the posterior q(xₜ₋₁ | xₜ, x₀) at t=500:")
print(f"  Shape of the posterior mean:     {mean.shape}")
print(f"  Shape of the posterior variance: {var.shape}")
print(f"  Posterior variance value: {var[0,0,0,0].item():.6f} (should be positive) ✓")
print()

# A sample from the posterior should, on average, be slightly closer to x0 than xt is.
# The gain from a single step is tiny, so for one random draw this is expected but not guaranteed.
sample = mean + var.sqrt() * torch.randn_like(mean)
dist_xt_to_x0   = (xt - x0).abs().mean().item()
dist_prev_to_x0 = (sample - x0).abs().mean().item()
print(f"  Mean distance from xₜ to x₀:    {dist_xt_to_x0:.4f}")
print(f"  Mean distance from xₜ₋₁ to x₀:  {dist_prev_to_x0:.4f}")
print(f"  xₜ₋₁ is closer to x₀: {dist_prev_to_x0 < dist_xt_to_x0}")
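
The closed-form posterior can also be sanity-checked without any tensors. The sketch below uses only the standard library and assumes the same linear schedule as above; the helper names `linear_schedule` and `posterior_coefficients` are illustrative, not part of the article's code. It checks that averaging the posterior over $x_t \sim q(x_t \mid x_0)$ gives back the forward marginal $q(x_{t-1} \mid x_0)$, whose mean scale is $\sqrt{\bar{\alpha}_{t-1}}$ and whose variance is $1 - \bar{\alpha}_{t-1}$.

```python
import math

def linear_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Return (betas, alpha_bars) for a linear beta schedule, as plain lists."""
    betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return betas, alpha_bars

def posterior_coefficients(t, betas, alpha_bars):
    """Coefficients of q(x_{t-1} | x_t, x_0); t is a 0-based index, alpha_bar = 1 before index 0."""
    beta_t, ab_t = betas[t], alpha_bars[t]
    ab_prev = alpha_bars[t - 1] if t > 0 else 1.0
    coef_x0 = math.sqrt(ab_prev) * beta_t / (1.0 - ab_t)
    coef_xt = math.sqrt(1.0 - beta_t) * (1.0 - ab_prev) / (1.0 - ab_t)
    beta_tilde = (1.0 - ab_prev) * beta_t / (1.0 - ab_t)
    return coef_x0, coef_xt, beta_tilde, ab_t, ab_prev

betas, alpha_bars = linear_schedule()
max_mean_gap = 0.0
max_var_gap = 0.0
for t in (1, 10, 100, 500, 999):
    coef_x0, coef_xt, beta_tilde, ab_t, ab_prev = posterior_coefficients(t, betas, alpha_bars)
    # Averaging the posterior mean over x_t ~ q(x_t | x_0) scales x_0 by this factor.
    mean_scale = coef_x0 + coef_xt * math.sqrt(ab_t)
    # Total variance: posterior variance plus the variance carried in through x_t.
    total_var = beta_tilde + coef_xt ** 2 * (1.0 - ab_t)
    max_mean_gap = max(max_mean_gap, abs(mean_scale - math.sqrt(ab_prev)))
    max_var_gap = max(max_var_gap, abs(total_var - (1.0 - ab_prev)))
    print(f"t={t:4d}  mean scale={mean_scale:.6f}  sqrt(alpha_bar_prev)={math.sqrt(ab_prev):.6f}")
```

Both gaps should be at the level of floating-point round-off, which is a quick way to catch a swapped coefficient or an off-by-one in the $\bar{\alpha}_{t-1}$ lookup.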

3. The Training Objective: From the ELBO to the Simplified Loss

3.1 The Maximum-Likelihood Objective

A generative model aims to maximize the data log-likelihood $\log p_\theta(x_0)$. Because optimizing it directly is intractable, DDPM optimizes the evidence lower bound (ELBO) instead:

$$\log p_\theta(x_0) \geq \mathbb{E}_q\!\left[\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] := -\mathcal{L}_{\text{ELBO}}$$

Expanding it, the ELBO splits into three terms:

$$\mathcal{L}_{\text{ELBO}} = \underbrace{\mathcal{L}_T}_{\text{prior matching}} + \sum_{t=2}^{T}\underbrace{\mathcal{L}_{t-1}}_{\text{denoising matching}} + \underbrace{\mathcal{L}_0}_{\text{reconstruction}}$$

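# The three components of the ELBO (intuition)

print("The three components of the ELBO:")
print()
print("① L_T (prior matching term)")
print("   KL( q(x_T|x_0) || p(x_T) )")
print("   = KL divergence between the distribution of x_T and N(0,I)")
print("   → no learnable parameters, so it is usually ignored (it is tiny for a well-chosen schedule)")
print()
print("② L_{t-1} (denoising matching terms, t=2,...,T) ← the main training objective")
print("   KL( q(xₜ₋₁|xₜ,x₀) || p_θ(xₜ₋₁|xₜ) )")
print("   = KL divergence between the true posterior and the learned reverse process")
print("   → the core thing the network has to learn!")
print()
print("③ L_0 (reconstruction term)")
print("   -log p_θ(x₀|x₁)")
print("   = log-likelihood of reconstructing the original image from x₁")
print("   → usually handled with a discretized Gaussian")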
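
To see why the prior matching term is usually dropped, it helps to put a number on it. The sketch below is a standard-library illustration under two assumptions: the linear schedule used throughout this article, and a single data dimension with $|x_0| = 1$ (the helper name `prior_matching_kl` is made up for this example). It evaluates the closed-form KL between $q(x_T \mid x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_T}\,x_0,\ 1-\bar{\alpha}_T)$ and $\mathcal{N}(0, 1)$.

```python
import math

def prior_matching_kl(x0_value, alpha_bar_T):
    """Per-dimension KL( N(sqrt(ab) * x0, 1 - ab) || N(0, 1) ) in nats."""
    mean = math.sqrt(alpha_bar_T) * x0_value
    var = 1.0 - alpha_bar_T
    # KL between two 1-D Gaussians, specialised to a standard-normal target.
    return 0.5 * (var + mean ** 2 - 1.0 - math.log(var))

# Linear schedule assumed: T = 1000, beta from 1e-4 to 0.02.
T = 1000
alpha_bar_T = 1.0
for i in range(T):
    beta = 1e-4 + (0.02 - 1e-4) * i / (T - 1)
    alpha_bar_T *= 1.0 - beta

kl_at_one = prior_matching_kl(1.0, alpha_bar_T)
print(f"alpha_bar_T = {alpha_bar_T:.3e}")
print(f"L_T per dimension for |x0| = 1: {kl_at_one:.3e} nats")
```

Because $\bar{\alpha}_T$ is already very close to zero for this schedule, the per-dimension value is orders of magnitude below one nat, and since it has no parameters it contributes nothing to the gradient either way.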

3.2 Deriving the Denoising Matching Term

The central denoising matching term $\mathcal{L}_{t-1}$ is a KL divergence between two Gaussians. With the variance of $p_\theta$ fixed to $\sigma_t^2 I$, it reduces to a weighted squared distance between the two means:

$$\mathcal{L}_{t-1} = \mathbb{E}_q\!\left[\frac{1}{2\sigma_t^2}\,\left\|\tilde{\mu}_t(x_t, x_0) - \mu_\theta(x_t, t)\right\|^2\right] + C$$

where $C$ does not depend on $\theta$.

Key step: eliminate $x_0$ from $\tilde{\mu}_t$.

From the forward-process reparameterization $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$, we can solve for $x_0$:

$$x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}}$$

Substituting this into the posterior mean $\tilde{\mu}_t$ gives:

$$\tilde{\mu}_t(x_t, \epsilon) = \frac{1}{\sqrt{\alpha_t}}\!\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right)$$
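
For readers who want the intermediate algebra, here is one way to carry out that substitution. It only uses the definitions already introduced ($\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \alpha_t\,\bar{\alpha}_{t-1}$); treat it as a worked sketch rather than the only route.

```latex
\begin{aligned}
\tilde{\mu}_t
&= \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}
   \cdot \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}}
 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t \\
&= \frac{\beta_t}{\sqrt{\alpha_t}\,(1-\bar{\alpha}_t)}\,x_t
 - \frac{\beta_t}{\sqrt{\alpha_t}\,\sqrt{1-\bar{\alpha}_t}}\,\epsilon
 + \frac{\alpha_t\,(1-\bar{\alpha}_{t-1})}{\sqrt{\alpha_t}\,(1-\bar{\alpha}_t)}\,x_t
 && \text{since } \sqrt{\bar{\alpha}_{t-1}}/\sqrt{\bar{\alpha}_t} = 1/\sqrt{\alpha_t} \\
&= \frac{\beta_t + \alpha_t - \bar{\alpha}_t}{\sqrt{\alpha_t}\,(1-\bar{\alpha}_t)}\,x_t
 - \frac{\beta_t}{\sqrt{\alpha_t}\,\sqrt{1-\bar{\alpha}_t}}\,\epsilon
 && \text{since } \alpha_t\,\bar{\alpha}_{t-1} = \bar{\alpha}_t \\
&= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right)
 && \text{since } \beta_t + \alpha_t = 1
\end{aligned}
```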

So if the network predicts $\epsilon_\theta(x_t, t)$, the corresponding mean parameterization is:

$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\!\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right)$$

Substituting both means into $\mathcal{L}_{t-1}$, the $x_t$ terms cancel and only the difference between the two noise terms remains:

$$\mathcal{L}_{t-1} = \mathbb{E}_{x_0, \epsilon}\!\left[\frac{\beta_t^2}{2\sigma_t^2\,\alpha_t\,(1-\bar{\alpha}_t)}\,\left\|\epsilon - \epsilon_\theta(x_t, t)\right\|^2\right] + C$$

Dropping the time-dependent weight in front (Ho et al., 2020 report that this works better in practice) and sampling $t$ uniformly gives the simplified loss:

$$\boxed{\mathcal{L}_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\!\left[\left\|\epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\; t\right)\right\|^2\right]}$$

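# Checking the step from the KL divergence to the simplified MSE loss in code

def compute_elbo_loss_term(
    x0:             torch.Tensor,
    xt:             torch.Tensor,
    t:              torch.Tensor,
    predicted_mean: torch.Tensor,  # μ_θ(xₜ, t) predicted by the network
    betas:          torch.Tensor,
    alphas_cumprod: torch.Tensor,
) -> torch.Tensor:
    """
    The full L_{t-1} in KL form.
    KL between two Gaussians: KL(q(xₜ₋₁|xₜ,x₀) || p_θ(xₜ₋₁|xₜ))
    = (1/2σ²) ||μ̃_t - μ_θ||²  (when the variances are equal)
    The result is averaged over elements instead of summed, which only rescales it.
    Use t >= 1: at t = 0 the posterior variance is zero (clamped to 1e-20).
    """
    true_mean, posterior_var = compute_posterior(x0, xt, t, betas, alphas_cumprod)

    # KL between two Gaussians with equal variance = squared error / (2 * variance)
    kl = 0.5 * ((predicted_mean - true_mean) ** 2) / posterior_var
    return kl.mean()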


def verify_simplified_loss():
    """
    Check that the full KL loss and the simplified MSE loss point in the same optimization direction
    """
    T = 1000
    betas          = torch.linspace(1e-4, 0.02, T)
    alphas         = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, 0)
    sqrt_ab        = alphas_cumprod.sqrt()
    sqrt_1mab      = (1 - alphas_cumprod).sqrt()

    torch.manual_seed(0)
    x0  = torch.randn(8, 3, 16, 16)
    t   = torch.randint(1, T, (8,))
    eps = torch.randn_like(x0)

    # Forward noising
    xt  = sqrt_ab[t].reshape(-1,1,1,1) * x0 + sqrt_1mab[t].reshape(-1,1,1,1) * eps

    # Assume the network predicts epsilon perfectly (use the true eps as the prediction)
    eps_pred_perfect = eps

    # Recover the mean μ_θ from the predicted epsilon
    alpha_t    = (1 - betas[t]).reshape(-1,1,1,1)
    beta_t     = betas[t].reshape(-1,1,1,1)
    sqrt_1mab_t = sqrt_1mab[t].reshape(-1,1,1,1)

    mu_theta = (xt - beta_t / sqrt_1mab_t * eps_pred_perfect) / alpha_t.sqrt()

    # Full KL loss
    kl_loss  = compute_elbo_loss_term(x0, xt, t, mu_theta, betas, alphas_cumprod)

    # Simplified MSE loss
    mse_loss = F.mse_loss(eps_pred_perfect, eps)

    print("Full KL loss vs. simplified MSE loss (perfect prediction):")
    print(f"  KL loss:   {kl_loss.item():.6f}")
    print(f"  MSE loss:  {mse_loss.item():.6f}")
    print(f"  With a perfect prediction (eps_pred = eps), MSE → 0 and KL → 0 ✓")
    print()

    # A random (bad) prediction
    eps_pred_bad = torch.randn_like(eps)
    mu_theta_bad = (xt - beta_t / sqrt_1mab_t * eps_pred_bad) / alpha_t.sqrt()

    kl_loss_bad  = compute_elbo_loss_term(x0, xt, t, mu_theta_bad, betas, alphas_cumprod)
    mse_loss_bad = F.mse_loss(eps_pred_bad, eps)

    print("  With a random (bad) prediction:")
    print(f"  KL loss:   {kl_loss_bad.item():.6f} (larger)")
    print(f"  MSE loss:  {mse_loss_bad.item():.6f} (larger)")
    print(f"  → Both move in the same direction; the MSE is the KL without its per-timestep weight ✓")


verify_simplified_loss()
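
What the simplified loss throws away is the per-timestep factor multiplying $\|\epsilon - \epsilon_\theta\|^2$ inside $\mathcal{L}_{t-1}$, namely $\beta_t^2 / (2\sigma_t^2\,\alpha_t\,(1-\bar{\alpha}_t))$. The standard-library sketch below tabulates that factor under two assumptions: the linear schedule from this article and the fixed choice $\sigma_t^2 = \beta_t$. The function name `elbo_weight` is hypothetical, and indices are 0-based as in the article's code.

```python
import math

def elbo_weight(t, betas, alpha_bars):
    """Weight of ||eps - eps_theta||^2 inside L_{t-1}, assuming sigma_t^2 = beta_t."""
    beta_t = betas[t]
    sigma_sq = beta_t
    return beta_t ** 2 / (2.0 * sigma_sq * (1.0 - beta_t) * (1.0 - alpha_bars[t]))

# Linear schedule assumed: T = 1000, beta from 1e-4 to 0.02.
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alpha_bars, running = [], 1.0
for beta in betas:
    running *= 1.0 - beta
    alpha_bars.append(running)

weights = [elbo_weight(t, betas, alpha_bars) for t in range(T)]
for t in (0, 10, 100, 500, 999):
    print(f"t={t:4d}  ELBO weight={weights[t]:.5f}  simple-loss weight=1")
```

The ELBO weight is much larger at the first, low-noise steps than at high-noise steps, so replacing it with a constant shifts the relative emphasis away from small $t$ and toward the harder, noisier steps.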

4. Three Parameterizations

4.1 Predicting $\epsilon$, $x_0$, or $v$?

# The network can predict different quantities; they convert into one another

def eps_to_x0(eps_pred, xt, t, sqrt_ab, sqrt_1mab):
    """Recover x₀ from a predicted ε"""
    sab  = sqrt_ab[t].reshape(-1,1,1,1)
    smab = sqrt_1mab[t].reshape(-1,1,1,1)
    return (xt - smab * eps_pred) / sab


def x0_to_eps(x0_pred, xt, t, sqrt_ab, sqrt_1mab):
    """Recover ε from a predicted x₀"""
    sab  = sqrt_ab[t].reshape(-1,1,1,1)
    smab = sqrt_1mab[t].reshape(-1,1,1,1)
    return (xt - sab * x0_pred) / smab


def compute_v_target(x0, eps, t, sqrt_ab, sqrt_1mab):
    """
    v-prediction (Salimans & Ho 2022)
    v = √ᾱₜ · ε - √(1-ᾱₜ) · x₀
    At low noise (ᾱₜ → 1) v ≈ ε; at high noise (ᾱₜ → 0) v ≈ -x₀
    The target stays well scaled at every noise level; used by Imagen Video and Stable Diffusion v2
    """
    sab  = sqrt_ab[t].reshape(-1,1,1,1)
    smab = sqrt_1mab[t].reshape(-1,1,1,1)
    return sab * eps - smab * x0


def v_to_x0(v_pred, xt, t, sqrt_ab, sqrt_1mab):
    """Recover x₀ from a predicted v"""
    sab  = sqrt_ab[t].reshape(-1,1,1,1)
    smab = sqrt_1mab[t].reshape(-1,1,1,1)
    return sab * xt - smab * v_pred


# Comparing the three parameterizations
T = 1000
betas          = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1 - betas, 0)
sqrt_ab        = alphas_cumprod.sqrt()
sqrt_1mab      = (1 - alphas_cumprod).sqrt()

torch.manual_seed(42)
x0  = torch.randn(4, 3, 8, 8)
eps = torch.randn_like(x0)
t   = torch.tensor([1, 250, 500, 750])

xt  = sqrt_ab[t].reshape(-1,1,1,1) * x0 + sqrt_1mab[t].reshape(-1,1,1,1) * eps

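# Training target for the v-parameterization (ground truth, used as the label during training)
v_target = compute_v_target(x0, eps, t, sqrt_ab, sqrt_1mab)

# Check that the conversions are consistent
x0_from_eps = eps_to_x0(eps, xt, t, sqrt_ab, sqrt_1mab)
x0_from_v   = v_to_x0(v_target, xt, t, sqrt_ab, sqrt_1mab)

print("Equivalence check for the three parameterizations:")
print(f"  Error recovering x₀ from ε: {(x0_from_eps - x0).abs().mean():.2e}")
print(f"  Error recovering x₀ from v: {(x0_from_v   - x0).abs().mean():.2e}")
print(f"  → All three are equivalent; the error is floating-point round-off ✓")
print()

# Magnitude of each target at different timesteps
print("Magnitude of each parameterization's target at different timesteps:")
print(f"{'t':^6} {'|ε| (target)':^20} {'|x₀| (target)':^20} {'|v| (target)':^20}")
print("─" * 68)
for i, t_val in enumerate(t.tolist()):
    eps_norm = eps[i].abs().mean().item()
    x0_norm  = x0[i].abs().mean().item()
    v_norm   = v_target[i].abs().mean().item()
    print(f"  {t_val:^6} {eps_norm:^20.4f} {x0_norm:^20.4f} {v_norm:^20.4f}")

print()
print("When to use each parameterization:")
comparisons = [
    ("predict ε (epsilon)", "original DDPM and the most common default; recovering x₀ divides by √ᾱₜ, so errors grow at high noise"),
    ("predict x₀",          "intuitive; recovering ε divides by √(1-ᾱₜ), so errors grow at low noise"),
    ("predict v",           "well scaled at every noise level; used by Imagen Video, SD v2 and SVD"),
]
for name, desc in comparisons:
    print(f"  [{name}]: {desc}")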
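
The conversions above cover $\epsilon \leftrightarrow x_0$ and $v \to x_0$; the remaining direction, $v \to \epsilon$, follows from the same two equations as $\epsilon = \sqrt{\bar{\alpha}_t}\,v + \sqrt{1-\bar{\alpha}_t}\,x_t$. The scalar sketch below (standard library only; the function names and the example values for $x_0$ and $\epsilon$ are made up) round-trips all three quantities and shows which one $v$ resembles at each end of the noise range.

```python
import math

def forward_noise(x0, eps, alpha_bar):
    """x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps"""
    return math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps

def v_from(x0, eps, alpha_bar):
    """v = sqrt(alpha_bar) * eps - sqrt(1 - alpha_bar) * x0"""
    return math.sqrt(alpha_bar) * eps - math.sqrt(1.0 - alpha_bar) * x0

def v_to_eps(v, xt, alpha_bar):
    """eps = sqrt(alpha_bar) * v + sqrt(1 - alpha_bar) * x_t"""
    return math.sqrt(alpha_bar) * v + math.sqrt(1.0 - alpha_bar) * xt

def v_to_x0(v, xt, alpha_bar):
    """x0 = sqrt(alpha_bar) * x_t - sqrt(1 - alpha_bar) * v"""
    return math.sqrt(alpha_bar) * xt - math.sqrt(1.0 - alpha_bar) * v

x0, eps = 0.7, -1.3          # arbitrary example values
results = {}
for alpha_bar in (0.999, 0.5, 0.001):   # low, medium, high noise
    xt = forward_noise(x0, eps, alpha_bar)
    v = v_from(x0, eps, alpha_bar)
    results[alpha_bar] = {
        "v": v,
        "eps_err": abs(v_to_eps(v, xt, alpha_bar) - eps),
        "x0_err": abs(v_to_x0(v, xt, alpha_bar) - x0),
    }
    print(f"alpha_bar={alpha_bar:5.3f}  v={v:+.4f}  eps={eps:+.4f}  -x0={-x0:+.4f}")
```

At $\bar{\alpha}_t$ close to 1 the printed $v$ sits next to $\epsilon$, and at $\bar{\alpha}_t$ close to 0 it sits next to $-x_0$, which is the sense in which v-prediction interpolates between the other two targets.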

5. The Sampling Formula of the Reverse Process

5.1 DDPM Sampling (Stochastic, 1000 Steps)

def ddpm_sample_step(
    xt:             torch.Tensor,
    t:              int,
    eps_pred:       torch.Tensor,   # noise predicted by the network, ε_θ(xₜ, t)
    betas:          torch.Tensor,
    alphas_cumprod: torch.Tensor,
) -> torch.Tensor:
    """
    One step of DDPM reverse sampling: from xₜ to xₜ₋₁
    xₜ₋₁ = (1/√αₜ) · (xₜ - βₜ/√(1-ᾱₜ) · ε_θ) + σₜ · z

    Note: t is a 0-based index into betas, so t = 0 here is the paper's t = 1.
    No random noise is added at that final step.
    """
    alpha_t   = 1.0 - betas[t].item()
    beta_t    = betas[t].item()
    ab_t      = alphas_cumprod[t].item()

    # Variance σₜ²: the posterior variance β̃ₜ is used here; βₜ itself is the other common choice
    if t > 0:
        ab_t_prev = alphas_cumprod[t - 1].item()
        # Posterior variance (Improved DDPM instead learns an interpolation between β̃ₜ and βₜ)
        posterior_var = beta_t * (1 - ab_t_prev) / (1 - ab_t)
        sigma_t = posterior_var ** 0.5
    else:
        sigma_t = 0.0

    # Denoised mean: μ_θ(xₜ, t)
    mu = (xt - beta_t / (1 - ab_t) ** 0.5 * eps_pred) / alpha_t ** 0.5

    # Add random noise (stochastic sampling, which keeps the samples diverse)
    if sigma_t > 0:
        z   = torch.randn_like(xt)
        xt_prev = mu + sigma_t * z
    else:
        xt_prev = mu   # final step (index t = 0): no noise is added

    return xt_prev


def ddpm_full_sample(
    denoiser,               # denoising network ε_θ(xₜ, t)
    shape:   tuple,         # shape of the generated images
    T:       int = 1000,
    betas:   torch.Tensor = None,
    device:  str = "cpu",
    verbose: bool = False,
) -> torch.Tensor:
    """
    Full DDPM sampling (Algorithm 2)
    Start from pure Gaussian noise and denoise step by step for T steps
    """
    if betas is None:
        betas = torch.linspace(1e-4, 0.02, T)
    betas = betas.to(device)   # also moves a caller-supplied schedule to the right device

    alphas_cumprod = torch.cumprod(1 - betas, 0)

    # Start from pure noise: x_T ~ N(0, I)
    xt = torch.randn(*shape, device=device)

    denoiser.eval()
    with torch.no_grad():
        for t in reversed(range(T)):   # T-1, T-2, ..., 1, 0
            # Timestep tensor (network input)
            t_tensor = torch.full((shape[0],), t, dtype=torch.long, device=device)

            # The network predicts the noise
            eps_pred = denoiser(xt, t_tensor)

            # One denoising step
            xt = ddpm_sample_step(xt, t, eps_pred, betas, alphas_cumprod)

            if verbose and t % 100 == 0:
                print(f"  t={t:4d}: mean={xt.mean():.4f}, std={xt.std():.4f}")

    return xt


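# Schedule statistics along the sampling trajectory (no network involved; this only illustrates the framework)
print("DDPM sampling statistics (schedule only, for illustration):")
print()
print("Key parameters at each sampling step:")
T     = 10   # 10 steps for the demo
betas = torch.linspace(0.01, 0.5, T)
alphas_cumprod = torch.cumprod(1 - betas, 0)

print(f"{'t':^6} {'αₜ':^10} {'βₜ':^10} {'posterior var σ²':^16} {'noise scale σ':^14}")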
print("─" * 60)
for t in range(T - 1, -1, -1):
    alpha_t = (1 - betas[t]).item()
    beta_t  = betas[t].item()
    ab_t    = alphas_cumprod[t].item()
    if t > 0:
        ab_prev = alphas_cumprod[t - 1].item()
        pvar    = beta_t * (1 - ab_prev) / (1 - ab_t)
        sigma   = pvar ** 0.5
    else:
        pvar, sigma = 0.0, 0.0
    print(f"  {t:^6} {alpha_t:^10.4f} {beta_t:^10.4f} {pvar:^16.6f} {sigma:^14.6f}")
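
The sampling step can be tested end to end without training anything, provided the data distribution is simple enough that the ideal noise predictor is known. The sketch below makes that assumption explicit: scalar data $x_0 \sim \mathcal{N}(2, 0.5^2)$, for which $\mathbb{E}[\epsilon \mid x_t]$ is linear in $x_t$. It uses only the standard library, the linear schedule from this article, and the posterior variance $\tilde{\beta}_t$; all function names are illustrative.

```python
import math
import random

def linear_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Return (betas, alpha_bars) for a linear beta schedule, as plain lists."""
    betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return betas, alpha_bars

def optimal_eps(xt, alpha_bar, data_mean, data_std):
    """E[eps | x_t] for scalar data x_0 ~ N(data_mean, data_std^2)."""
    marginal_var = alpha_bar * data_std ** 2 + (1.0 - alpha_bar)
    return math.sqrt(1.0 - alpha_bar) * (xt - math.sqrt(alpha_bar) * data_mean) / marginal_var

def ancestral_sample(n_samples, data_mean, data_std, rng, T=1000):
    """Run the DDPM reverse chain with the analytic noise predictor in place of a network."""
    betas, alpha_bars = linear_schedule(T)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n_samples)]        # x_T ~ N(0, 1)
    for t in reversed(range(T)):                                # 0-based: T-1, ..., 0
        beta_t, ab_t = betas[t], alpha_bars[t]
        ab_prev = alpha_bars[t - 1] if t > 0 else 1.0
        sigma_t = math.sqrt(beta_t * (1.0 - ab_prev) / (1.0 - ab_t))  # exactly 0 at t = 0
        eps_scale = beta_t / math.sqrt(1.0 - ab_t)
        inv_sqrt_alpha = 1.0 / math.sqrt(1.0 - beta_t)
        xs = [
            inv_sqrt_alpha * (x - eps_scale * optimal_eps(x, ab_t, data_mean, data_std))
            + sigma_t * rng.gauss(0.0, 1.0)
            for x in xs
        ]
    return xs

rng = random.Random(0)
samples = ancestral_sample(n_samples=1000, data_mean=2.0, data_std=0.5, rng=rng)
sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((x - sample_mean) ** 2 for x in samples) / len(samples))
print(f"sample mean = {sample_mean:.3f} (data mean 2.0)")
print(f"sample std  = {sample_std:.3f} (data std 0.5)")
```

Starting from standard normal noise, the chain should end close to the data mean and standard deviation. A small shortfall in the spread is expected, because $\tilde{\beta}_t$ is the lower of the two variance choices.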

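5.2 Choosing the Variance: Fixed vs. Learned

# DDPM has two choices for the variance:
# β̃ₜ (the posterior variance, a lower bound) and βₜ (the forward-process variance, an upper bound)
# The original DDPM fixes it to one of the two
# Improved DDPM (Nichol & Dhariwal 2021) proposes learning an interpolation coefficient

def learned_variance_parameterization():
    """
    Improved DDPM: the network additionally outputs an interpolation coefficient v ∈ [0, 1] for the variance
    Σ_θ(xₜ, t) = exp(v · log βₜ + (1-v) · log β̃ₜ)
    """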
    T = 1000
    betas          = torch.linspace(1e-4, 0.02, T)
    alphas         = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, 0)
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

    # Upper bound (βₜ) and lower bound (β̃ₜ)
    beta_upper = betas
    beta_lower = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
    beta_lower = beta_lower.clamp(min=1e-20)

    # How the two fixed choices differ across timesteps
    print("The two fixed-variance choices compared (log scale):")
    print(f"{'timestep t':^10} {'log β̃ₜ (lower)':^18} {'log βₜ (upper)':^18} {'gap':^10}")
    print("─" * 58)
    for t in [1, 100, 250, 500, 750, 999]:
        lb = math.log(beta_lower[t].item())
        ub = math.log(beta_upper[t].item())
        print(f"  {t:^10} {lb:^18.4f} {ub:^18.4f} {ub-lb:^10.4f}")

    print()
    print("→ At small t the two differ a lot (learning the variance pays off)")
    print("→ At large t they nearly coincide (close to pure noise, the choice barely matters)")
    print("→ Improved DDPM learns the interpolation coefficient so the network can adapt")


learned_variance_parameterization()
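
The interpolation formula in the docstring is easy to evaluate directly. The sketch below is a standard-library illustration with made-up values for $\beta_t$ and $\tilde{\beta}_t$ at a single timestep (the function name `interpolated_variance` is hypothetical). It shows that $v = 0$ and $v = 1$ recover the two fixed choices and that intermediate values interpolate geometrically, that is, linearly in log space.

```python
import math

def interpolated_variance(v, beta_t, beta_tilde_t):
    """Sigma = exp(v * log(beta_t) + (1 - v) * log(beta_tilde_t)) for v in [0, 1]."""
    log_var = v * math.log(beta_t) + (1.0 - v) * math.log(beta_tilde_t)
    return math.exp(log_var)

# Assumed example values for one timestep, chosen so that beta_tilde < beta.
beta_t = 2.0e-4
beta_tilde_t = 1.0e-4

coefficients = (0.0, 0.25, 0.5, 0.75, 1.0)
variances = [interpolated_variance(v, beta_t, beta_tilde_t) for v in coefficients]
for v, var in zip(coefficients, variances):
    print(f"v={v:.2f}  variance={var:.6e}")
```

Working in log space keeps the predicted variance strictly between the two bounds for any $v \in [0, 1]$, so the network cannot output a negative or wildly out-of-range variance.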

6. The Complete Training Loop

class DDPMTrainer:
    """
    A complete DDPM trainer
    Combines the forward process, the loss computation, and the optimization step
    """

    def __init__(
        self,
        denoiser:       nn.Module,
        T:              int   = 1000,
        schedule_type:  str   = "linear",
        predict_target: str   = "epsilon",   # "epsilon" | "x0" | "v"
        loss_type:      str   = "simple",    # "simple" | "weighted"
        beta_start:     float = 1e-4,
        beta_end:       float = 0.02,
        device:         str   = "cpu",
    ):
        self.denoiser       = denoiser.to(device)
        self.T              = T
        self.predict_target = predict_target
        self.loss_type      = loss_type
        self.device         = device

        # ── Noise schedule ────────────────────────────────────
        if schedule_type == "linear":
            betas = torch.linspace(beta_start, beta_end, T)
        elif schedule_type == "cosine":
            s     = 0.008
            t_arr = torch.linspace(0, T, T + 1)
            f     = torch.cos((t_arr / T + s) / (1 + s) * math.pi / 2) ** 2
            ab    = f / f[0]
            betas = (1 - ab[1:] / ab[:-1]).clamp(1e-4, 0.9999)
        else:
            raise ValueError(f"Unknown schedule: {schedule_type}")

        self.register_buffers(betas, device)

    def register_buffers(self, betas: torch.Tensor, device: str):
        """Precompute every schedule statistic that is needed"""
        alphas         = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, 0)
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

        self.betas               = betas.to(device)
        self.alphas              = alphas.to(device)
        self.alphas_cumprod      = alphas_cumprod.to(device)
        self.alphas_cumprod_prev = alphas_cumprod_prev.to(device)
        self.sqrt_ab             = alphas_cumprod.sqrt().to(device)
        self.sqrt_1mab           = (1 - alphas_cumprod).sqrt().to(device)
        self.sqrt_recip_a        = (1 / alphas).sqrt().to(device)
        self.posterior_variance  = (
            betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        ).clamp(min=1e-20).to(device)

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, noise=None):
        """Forward noising: x₀ → xₜ"""
        if noise is None:
            noise = torch.randn_like(x0)
        sab  = self.sqrt_ab[t].reshape(-1, 1, 1, 1)
        smab = self.sqrt_1mab[t].reshape(-1, 1, 1, 1)
        return sab * x0 + smab * noise, noise

    def get_target(self, x0: torch.Tensor, noise: torch.Tensor,
                   t: torch.Tensor) -> torch.Tensor:
        """Compute the training target for the chosen parameterization"""
        if self.predict_target == "epsilon":
            return noise

        elif self.predict_target == "x0":
            return x0

        elif self.predict_target == "v":
            sab  = self.sqrt_ab[t].reshape(-1, 1, 1, 1)
            smab = self.sqrt_1mab[t].reshape(-1, 1, 1, 1)
            return sab * noise - smab * x0

        else:
            raise ValueError(f"Unknown predict_target: {self.predict_target}")

    def compute_loss(
        self,
        x0:   torch.Tensor,
        noise: torch.Tensor = None,
    ) -> dict:
        """
        Compute the training loss
        Returns a dict with the loss and some debugging information
        """
        B = x0.shape[0]

        # Random timesteps
        t = torch.randint(0, self.T, (B,), device=self.device)

        # Forward noising
        if noise is None:
            noise = torch.randn_like(x0)
        xt, noise = self.q_sample(x0, t, noise)

        # Network prediction
        model_output = self.denoiser(xt, t)

        # Target
        target = self.get_target(x0, noise, t)

        # Simplified MSE loss
        if self.loss_type == "simple":
            loss = F.mse_loss(model_output, target)

        # SNR-weighted MSE: weight = min(SNR, 5) / SNR, a truncated-SNR weighting suited to ε-prediction
        elif self.loss_type == "weighted":
            snr     = self.alphas_cumprod[t] / (1 - self.alphas_cumprod[t])
            weights = torch.minimum(snr, torch.full_like(snr, 5.0)) / snr
            weights = weights.reshape(-1, 1, 1, 1)
            loss    = (weights * (model_output - target) ** 2).mean()

        # Fail loudly instead of hitting an undefined `loss` below
        else:
            raise ValueError(f"Unknown loss_type: {self.loss_type}")

        return {
            "loss":  loss,
            "t":     t.float().mean().item(),    # mean timestep (for debugging)
            "pred_norm": model_output.norm().item(),
            "target_norm": target.norm().item(),
        }

    def train_step(
        self,
        x0:        torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ) -> dict:
        """Run one training step"""
        self.denoiser.train()
        optimizer.zero_grad()

        metrics = self.compute_loss(x0)
        metrics["loss"].backward()

        # Gradient clipping (stabilizes training)
        torch.nn.utils.clip_grad_norm_(self.denoiser.parameters(), max_norm=1.0)

        optimizer.step()
        return {k: v.item() if isinstance(v, torch.Tensor) else v
                for k, v in metrics.items()}


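# ── Testing the trainer ─────────────────────────────────────────

# A minimal denoising network (only for testing the framework; real models use a U-Net, see Part 4)
class TinyDenoiser(nn.Module):
    """A minimal denoising network (for testing)"""
    def __init__(self, channels: int = 3, hidden: int = 64):
        super().__init__()
        # Timestep embedding
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )
        # Image path (minimal version; in practice this is a U-Net)
        self.net = nn.Sequential(
            nn.Conv2d(channels, hidden, 3, padding=1),
            nn.GroupNorm(8, hidden),
            nn.SiLU(),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.GroupNorm(8, hidden),
            nn.SiLU(),
            nn.Conv2d(hidden, channels, 3, padding=1),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Timestep embedding; t is scaled by 1000, the default number of steps
        t_emb = self.time_embed(t.float().unsqueeze(1) / 1000)  # (B, hidden)
        # Simplification: add the time information to the feature map right after
        # the first convolution (real models inject it through ResBlocks).
        # The original version computed t_emb but never used it.
        h = self.net[0](x) + t_emb[:, :, None, None]
        return self.net[1:](h)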

# Initialization and test
denoiser = TinyDenoiser(channels=3, hidden=32)
trainer  = DDPMTrainer(denoiser, T=1000, predict_target="epsilon")
optimizer = torch.optim.Adam(denoiser.parameters(), lr=2e-4)

# A simulated training batch (random tensors; real training would use real images)
x0_batch = torch.randn(8, 3, 32, 32)   # 8 images of size 32×32

print("DDPM trainer test:")
print()
for step in range(5):
    metrics = trainer.train_step(x0_batch, optimizer)
    print(f"  step {step+1}: loss={metrics['loss']:.4f}, "
          f"mean t={metrics['t']:.1f}, "
          f"prediction norm={metrics['pred_norm']:.4f}")
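
Stripped of the network, training with the simplified loss is just regression of the noise $\epsilon$ on the noisy input $x_t$. The sketch below makes that concrete in one dimension under stated assumptions: zero-mean Gaussian data with standard deviation 0.5, a single fixed timestep (index 500 of the linear schedule), and a one-parameter "network" $\hat{\epsilon} = w\,x_t$ fitted in closed form by least squares instead of gradient descent. It uses only the standard library, and the helper names are illustrative.

```python
import math
import random

def alpha_bar_at(t, T=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta) up to and including 0-based index t (linear schedule)."""
    alpha_bar = 1.0
    for i in range(t + 1):
        alpha_bar *= 1.0 - (beta_start + (beta_end - beta_start) * i / (T - 1))
    return alpha_bar

def fit_linear_denoiser(alpha_bar, data_std, n_samples, rng):
    """Least-squares w for eps_hat = w * x_t, i.e. the minimizer of the empirical simple loss."""
    sum_eps_xt = 0.0
    sum_xt_sq = 0.0
    for _ in range(n_samples):
        x0 = rng.gauss(0.0, data_std)       # a "training image"
        eps = rng.gauss(0.0, 1.0)           # the noise the model must predict
        xt = math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps
        sum_eps_xt += eps * xt
        sum_xt_sq += xt * xt
    return sum_eps_xt / sum_xt_sq

rng = random.Random(0)
data_std = 0.5
alpha_bar = alpha_bar_at(500)
w_fit = fit_linear_denoiser(alpha_bar, data_std, n_samples=20000, rng=rng)
# Analytic optimum for Gaussian data: Cov(eps, x_t) / Var(x_t).
w_opt = math.sqrt(1.0 - alpha_bar) / (alpha_bar * data_std ** 2 + 1.0 - alpha_bar)
print(f"fitted w = {w_fit:.4f}, analytic optimum = {w_opt:.4f}")
```

The fitted coefficient should land close to the analytic value of $\mathbb{E}[\epsilon \mid x_t] / x_t$, which is the quantity a real denoiser approximates with a far more flexible function of $x_t$ and $t$.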

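7. The Complete Sampling (Inference) Code

class DDPMSampler:
    """
    A complete DDPM sampler
    Can return intermediate states and print progress statistics for inspection
    """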

    def __init__(
        self,
        denoiser:       nn.Module,
        T:              int = 1000,
        predict_target: str = "epsilon",
        device:         str = "cpu",
    ):
        self.denoiser       = denoiser.to(device)
        self.T              = T
        self.predict_target = predict_target
        self.device         = device

        # Noise schedule (must match the one used in training!)
        betas = torch.linspace(1e-4, 0.02, T).to(device)
        alphas         = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, 0)
        alphas_cumprod_prev = torch.cat([
            torch.ones(1, device=device),
            alphas_cumprod[:-1]
        ])

        self.betas               = betas
        self.alphas              = alphas
        self.alphas_cumprod      = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev
        self.sqrt_ab             = alphas_cumprod.sqrt()
        self.sqrt_1mab           = (1 - alphas_cumprod).sqrt()
        self.posterior_variance  = (
            betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        ).clamp(1e-20)

    def predict_x0_from_output(
        self,
        model_output: torch.Tensor,
        xt:           torch.Tensor,
        t:            torch.Tensor,
    ) -> torch.Tensor:
        """Recover x̂₀ (the estimated clean image) from the network output"""
        sab  = self.sqrt_ab[t].reshape(-1, 1, 1, 1)
        smab = self.sqrt_1mab[t].reshape(-1, 1, 1, 1)

        if self.predict_target == "epsilon":
            # x̂₀ = (xₜ - √(1-ᾱₜ) · ε̂) / √ᾱₜ
            x0_hat = (xt - smab * model_output) / sab
        elif self.predict_target == "x0":
            x0_hat = model_output
        elif self.predict_target == "v":
            # x̂₀ = √ᾱₜ · xₜ - √(1-ᾱₜ) · v̂
            x0_hat = sab * xt - smab * model_output
        else:
            raise ValueError(f"Unknown predict_target: {self.predict_target}")

        # Clamp the range (guards against extreme values; images usually live in [-1, 1])
        return x0_hat.clamp(-1, 1)

    def p_mean_variance(
        self,
        xt: torch.Tensor,
        t:  torch.Tensor,
    ) -> tuple:
        """Compute the mean and variance of p_θ(xₜ₋₁|xₜ)"""
        model_output = self.denoiser(xt, t)
        x0_hat       = self.predict_x0_from_output(model_output, xt, t)

        # Posterior mean (with the estimate x̂₀ in place of the true x₀)
        ab_prev = self.alphas_cumprod_prev[t].reshape(-1, 1, 1, 1)
        ab      = self.alphas_cumprod[t].reshape(-1, 1, 1, 1)
        b       = self.betas[t].reshape(-1, 1, 1, 1)
        a       = self.alphas[t].reshape(-1, 1, 1, 1)

        mean = ab_prev.sqrt() * b / (1 - ab) * x0_hat + \
               a.sqrt() * (1 - ab_prev) / (1 - ab) * xt

        var  = self.posterior_variance[t].reshape(-1, 1, 1, 1)
        log_var = var.log()

        return mean, var, log_var, x0_hat

    @torch.no_grad()
    def p_sample(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Single sampling step: xₜ → xₜ₋₁"""
        mean, var, _, _ = self.p_mean_variance(xt, t)

        # Add random noise when t > 0 and none when t = 0
        mask  = (t > 0).float().reshape(-1, 1, 1, 1)
        noise = mask * torch.randn_like(xt)

        return mean + var.sqrt() * noise

    @torch.no_grad()
    def sample(
        self,
        shape:      tuple,
        return_intermediates: bool = False,
        log_every:  int   = 100,
    ) -> tuple:
        """
        Full sampling: generate images starting from x_T ~ N(0, I)
        """
        self.denoiser.eval()
        B = shape[0]
        device = self.device

        # Start from pure noise
        xt = torch.randn(*shape, device=device)

        intermediates = [xt.cpu()] if return_intermediates else []

        for i, t_val in enumerate(reversed(range(self.T))):
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            xt = self.p_sample(xt, t)

            if return_intermediates and (t_val % log_every == 0):
                intermediates.append(xt.cpu())

        # Map back from [-1, 1] to [0, 1]
        image = (xt.clamp(-1, 1) + 1) / 2

        if return_intermediates:
            return image, intermediates
        return image, []

    def sample_progress_stats(self, shape: tuple) -> None:
        """Print how the statistics evolve during sampling (for debugging)"""
        self.denoiser.eval()
        B      = shape[0]
        device = self.device
        xt     = torch.randn(*shape, device=device)

        print("Statistics during DDPM sampling:")
        print(f"{'timestep t':^10} {'xₜ mean':^12} {'xₜ var':^12} {'x̂₀ mean':^12} {'x̂₀ var':^12}")
        print("─" * 62)

        # Log roughly ten evenly spaced steps; max(1, ...) avoids a zero slice step when T < 10
        log_steps = list(range(self.T - 1, -1, -1))[::max(1, self.T // 10)]

        with torch.no_grad():
            for t_val in reversed(range(self.T)):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                _, _, _, x0_hat = self.p_mean_variance(xt, t)
                xt = self.p_sample(xt, t)

                if t_val in log_steps:
                    print(f"  {t_val:^10} {xt.mean():.4f}      "
                          f"{xt.var():.4f}      "
                          f"{x0_hat.mean():.4f}      "
                          f"{x0_hat.var():.4f}")


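# ── Testing the sampler ─────────────────────────────────────────
# 100 steps for the demo (a full run uses 1000). This T differs from the T=1000
# used by the trainer above, so the output only exercises the code path.
sampler = DDPMSampler(denoiser, T=100)
sampler.sample_progress_stats(shape=(2, 3, 8, 8))

print()

# Full sampling (demo)
image, intermediates = sampler.sample(
    shape=(4, 3, 8, 8),
    return_intermediates=True,
    log_every=20,
)
print(f"\nSampling finished:")
print(f"  Generated image shape: {image.shape}")
print(f"  Pixel range: [{image.min():.4f}, {image.max():.4f}] (should be within [0, 1])")
print(f"  Number of intermediate frames: {len(intermediates)}")
print()
print("Summary of the full pipeline:")
print("  Training:  random t → add noise → predict noise → MSE loss → backpropagation")
print("  Inference: x_T ~ N(0,I) → 1000 denoising steps → x̂₀, the generated image")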
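
The sampler above computes the mean through $\hat{x}_0$ and the posterior coefficients, while section 3.2 wrote it directly in terms of $\hat{\epsilon}$. Without clamping, the two routes agree exactly; clamping $\hat{x}_0$ to $[-1, 1]$ is what makes them differ. The scalar sketch below (standard library only; the schedule values and the inputs are assumed example numbers, and the function names are illustrative) shows both effects.

```python
import math

def mean_from_eps(xt, eps_hat, beta_t, alpha_bar_t):
    """mu = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_t)"""
    return (xt - beta_t / math.sqrt(1.0 - alpha_bar_t) * eps_hat) / math.sqrt(1.0 - beta_t)

def mean_from_x0(xt, x0_hat, beta_t, alpha_bar_t):
    """Posterior mean with an estimate x0_hat in place of the true x_0."""
    alpha_t = 1.0 - beta_t
    alpha_bar_prev = alpha_bar_t / alpha_t          # since alpha_bar_t = alpha_t * alpha_bar_prev
    coef_x0 = math.sqrt(alpha_bar_prev) * beta_t / (1.0 - alpha_bar_t)
    coef_xt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    return coef_x0 * x0_hat + coef_xt * xt

# Assumed example values: a fairly noisy timestep and a prediction that overshoots.
beta_t, alpha_bar_t = 0.01, 0.08
xt, eps_hat = 1.5, 0.2

x0_hat = (xt - math.sqrt(1.0 - alpha_bar_t) * eps_hat) / math.sqrt(alpha_bar_t)
x0_clamped = max(-1.0, min(1.0, x0_hat))

mu_eps = mean_from_eps(xt, eps_hat, beta_t, alpha_bar_t)
mu_x0 = mean_from_x0(xt, x0_hat, beta_t, alpha_bar_t)
mu_clamped = mean_from_x0(xt, x0_clamped, beta_t, alpha_bar_t)

print(f"x0_hat = {x0_hat:.4f}, clamped to {x0_clamped:.4f}")
print(f"mean via eps:          {mu_eps:.6f}")
print(f"mean via x0_hat:       {mu_x0:.6f}")
print(f"mean via clamped x0:   {mu_clamped:.6f}")
```

At high noise levels, dividing by a small $\sqrt{\bar{\alpha}_t}$ can push $\hat{x}_0$ far outside the valid image range, and clamping pulls the step's mean back toward plausible values. That is the practical reason the sampler goes through $\hat{x}_0$ rather than using the $\epsilon$ form directly.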

Summary

Concept | Formula | Meaning
Reverse process | $p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(\mu_\theta, \Sigma_\theta)$ | A neural network approximates the true reverse distribution
Posterior | $q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(\tilde{\mu}_t, \tilde{\beta}_t I)$ | Has a closed form and is the target the network learns to match
Training objective | $\mathcal{L}_{\text{simple}} = \mathbb{E}\left[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\right]$ | The MSE obtained by simplifying the ELBO
Mean parameterization | $\mu_\theta = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta\right)$ | Predicting $\epsilon$ is equivalent to predicting the mean
Sampling formula | $x_{t-1} = \mu_\theta + \sigma_t z$ | Denoise step by step; the added randomness keeps samples diverse

Choosing among the three parameterizations:

  • Predicting $\epsilon$: the original DDPM choice and the default in most models
  • Predicting $x_0$: intuitive, but the implied noise estimate is amplified at low noise levels
  • Predicting $v$: well scaled at every noise level; the choice of Stable Diffusion v2 and Imagen Video

Next part: U-Net architecture, the design of the denoising network. Why a U-Net? What skip connections do, how the timestep $t$ is injected into the network, where the self-attention modules go, and how a Residual Block is implemented.


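💬 Which step in the derivation from the ELBO to the simplified MSE do you find most crucial? Share your thoughts in the comments!

🙏 If this helped, like and bookmark it. The series is updated regularly!


This article is an original technical write-up. The code was verified with Python 3.12 + PyTorch 2.x. Last updated: 2026-05-20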
