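【Diffusion Models Series · Part 3】The Reverse Process: Learning to Denoise, with a Full Walkthrough of the Training-Objective Derivation and the Simplified Loss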
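Author: Tech Blogger | Updated: 2026-05-20 | Reading time: about 23 minutes
Series: Diffusion Models from Zero to Practice (8 parts)
Environment: Python 3.12 + PyTorch 2.x
Tags: Diffusion Models, DDPM, Reverse Process, ELBO, Denoising, Training Objective, Parameterization, Sampling Algorithm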

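🔥 Goal of this post: the first two posts built the intuition and the math of the forward process. This post tackles the core questions: what exactly does the neural network learn? Where does the loss function come from? How is each sampling step computed? Starting from the variational lower bound (ELBO), we derive the simplified MSE loss, then the sampling formula, and finally implement complete training and sampling code. By the end, the full DDPM training-and-inference loop should be completely transparent.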
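Series Progress
| Part | Topic | Status |
|---|---|---|
| Part 1 | What Are Diffusion Models: Intuition from Noising to Denoising | ✅ Published |
| Part 2 | Math Foundations: The Forward Process and Markov Chains | ✅ Published |
| Part 3 (this post) | The Reverse Process: Learning to Denoise | This post |
| Part 4 | U-Net Architecture: Designing the Denoising Network | Coming soon |
| Part 5 | DDIM: Accelerated Sampling | Coming soon |
| Part 6 | Conditional Generation: Text Guidance and CFG | Coming soon |
| Part 7 | Stable Diffusion: Latent Diffusion Models | Coming soon |
| Part 8 | Practice and Frontiers: LoRA, DreamBooth, FLUX | Coming soon |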
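Contents
- 1. Mathematical Definition of the Reverse Process
- 2. Why the Reverse Process Is Also Gaussian
- 3. Training Objective: From the ELBO to the Simplified Loss
- 4. Three Parameterizations: Predict $\epsilon$, $x_0$, or $v$
- 5. Sampling Formula of the Reverse Process
- 6. Complete Training Loop Code
- 7. Complete Sampling (Inference) Code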
1. Mathematical Definition of the Reverse Process
1.1 What We Want
The forward process $q$ is known (we define it by hand). What we actually need is the reverse process $p$: given a noisy image $x_t$, infer the slightly cleaner image $x_{t-1}$ from one step earlier.
If we knew the full data distribution $q(x_0)$, Bayes' rule would give the true reverse conditional:
$$q(x_{t-1} \mid x_t) = \frac{q(x_t \mid x_{t-1})\, q(x_{t-1})}{q(x_t)}$$
Problem: the marginals $q(x_{t-1})$ and $q(x_t)$ require integrating over every possible $x_0$, e.g. $q(x_{t-1}) = \int q(x_{t-1} \mid x_0)\, q(x_0)\, dx_0$, and since $q(x_0)$ is unknown this is intractable.
Solution: approximate this distribution with a neural network $p_\theta$:
$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\, \mu_\theta(x_t, t),\, \Sigma_\theta(x_t, t)\right)$$
The network's job: learn the mean $\mu_\theta$ and the variance $\Sigma_\theta$ so that $p_\theta \approx q$.
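```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# The core problem of the reverse process
print("The core problem of the reverse process:")
print()
print("  Known:  xₜ (an image with t steps of noise added)")
print("  Wanted: xₜ₋₁ (the image with one step less noise)")
print()
print("  The true distribution q(xₜ₋₁|xₜ) requires integrating over all x₀: intractable")
print()
print("  → approximate it with a neural network p_θ(xₜ₋₁|xₜ)")
print("  → p_θ is assumed Gaussian (mean and variance predicted by the network)")
print()
print("  Key result (Sohl-Dickstein et al., 2015):")
print("  when βₜ is small enough, q(xₜ₋₁|xₜ) is approximately Gaussian")
print("  → approximating the reverse process with a Gaussian is justified!")
```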
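The Sohl-Dickstein observation can be checked numerically on a toy problem. The sketch below assumes a hypothetical 1-D dataset with two equally likely points $x_0 \in \{-1, +1\}$, computes the exact reverse conditional $q(x_{t-1} \mid x_t)$ by Bayes' rule on a grid, and measures its L1 gap to a moment-matched Gaussian. The helper name `reverse_gaussianity_gap` and the values $\bar{\alpha}_{t-1} = 0.9$, $x_t = 0$ are illustrative choices, not part of the original setup.

```python
import numpy as np


def gaussian_pdf(x, mean, var):
    return np.exp(-(x - mean) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)


def reverse_gaussianity_gap(beta, ab_prev=0.9, x_t=0.0):
    """L1 distance between the exact q(x_{t-1}|x_t) and its moment-matched Gaussian."""
    grid = np.linspace(-4, 4, 20001)
    dx = grid[1] - grid[0]
    # Toy data x0 in {-1, +1}: q(x_{t-1}) is a two-component Gaussian mixture
    m, v = np.sqrt(ab_prev), 1 - ab_prev
    prior = 0.5 * gaussian_pdf(grid, m, v) + 0.5 * gaussian_pdf(grid, -m, v)
    # One forward step q(x_t | x_{t-1}) = N(sqrt(1 - beta) x_{t-1}, beta), viewed as a function of x_{t-1}
    likelihood = gaussian_pdf(x_t, np.sqrt(1 - beta) * grid, beta)
    # Bayes' rule, normalized numerically on the grid
    post = prior * likelihood
    post /= post.sum() * dx
    # Gaussian with the same mean and variance as the exact posterior
    mu = (grid * post).sum() * dx
    var = ((grid - mu) ** 2 * post).sum() * dx
    fit = gaussian_pdf(grid, mu, var)
    return np.abs(post - fit).sum() * dx


small = reverse_gaussianity_gap(beta=0.001)
large = reverse_gaussianity_gap(beta=0.5)
print(f"L1 gap, beta=0.001: {small:.4f}")
print(f"L1 gap, beta=0.5:   {large:.4f}")
```

With a small step the exact reverse conditional is nearly Gaussian, while a large step leaves it bimodal, which is why DDPM uses many small $\beta_t$.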
2. Why the Reverse Process Is Also Gaussian
2.1 Closed Form of the Posterior
Although $q(x_{t-1} \mid x_t)$ is hard to compute, the posterior conditioned on $x_0$ has a closed form. Once $x_0$ is fixed, both $q(x_t \mid x_{t-1})$ and $q(x_{t-1} \mid x_0)$ are Gaussian in $x_{t-1}$, so their normalized product (Bayes' rule) is Gaussian as well:
$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\, \tilde{\mu}_t(x_t, x_0),\, \tilde{\beta}_t I\right)$$
where:
$$\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\,x_t$$
$$\tilde{\beta}_t = \frac{(1 - \bar{\alpha}_{t-1})\,\beta_t}{1 - \bar{\alpha}_t}$$
```python
def compute_posterior(
    x0: torch.Tensor,
    xt: torch.Tensor,
    t: torch.Tensor,
    betas: torch.Tensor,
    alphas_cumprod: torch.Tensor,
) -> tuple:
    """
    Compute the mean and variance of the posterior q(xₜ₋₁ | xₜ, x₀).
    This is the "ground-truth target" of the reverse process: the distribution the network learns.
    """
    B = x0.shape[0]

    def extract(arr, t, shape):
        # Pick the per-sample coefficient arr[t] and reshape to (B, 1, 1, ...) for broadcasting
        vals = arr[t]
        return vals.reshape(B, *([1] * (len(shape) - 1)))

    # Extract the key coefficients
    beta_t = extract(betas, t, x0.shape)
    ab_t = extract(alphas_cumprod, t, x0.shape)
    # ᾱₜ₋₁ (with the convention ᾱ₋₁ = 1 at t = 0)
    alphas_cumprod_prev = torch.cat([
        torch.ones(1, dtype=alphas_cumprod.dtype, device=alphas_cumprod.device),
        alphas_cumprod[:-1],
    ])
    ab_t_prev = extract(alphas_cumprod_prev, t, x0.shape)
    alpha_t = extract(1 - betas, t, x0.shape)
    # Posterior variance: β̃ₜ = (1 - ᾱₜ₋₁) / (1 - ᾱₜ) · βₜ
    posterior_variance = (1 - ab_t_prev) / (1 - ab_t) * beta_t
    posterior_variance = posterior_variance.clamp(min=1e-20)
    # The two coefficients of the posterior mean
    coef_x0 = ab_t_prev.sqrt() * beta_t / (1 - ab_t)
    coef_xt = alpha_t.sqrt() * (1 - ab_t_prev) / (1 - ab_t)
    # Posterior mean
    posterior_mean = coef_x0 * x0 + coef_xt * xt
    return posterior_mean, posterior_variance


# Check: the posterior is consistent with the forward process
torch.manual_seed(42)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, 0)
sqrt_ab = alphas_cumprod.sqrt()
sqrt_1mab = (1 - alphas_cumprod).sqrt()
x0 = torch.randn(4, 3, 8, 8)
t = torch.tensor([500, 500, 500, 500])
# Forward noising
eps = torch.randn_like(x0)
xt = sqrt_ab[t].reshape(-1, 1, 1, 1) * x0 + sqrt_1mab[t].reshape(-1, 1, 1, 1) * eps
# Posterior
mean, var = compute_posterior(x0, xt, t, betas, alphas_cumprod)
print("Posterior q(xₜ₋₁ | xₜ, x₀) check (t=500):")
print(f"  posterior mean shape:     {mean.shape}")
print(f"  posterior variance shape: {var.shape}")
print(f"  posterior variance value: {var[0,0,0,0].item():.6f} (should be positive) ✓")
print()
# A sample from the posterior should be closer to x0 than xt is
sample = mean + var.sqrt() * torch.randn_like(mean)
dist_xt_to_x0 = (xt - x0).abs().mean().item()
dist_prev_to_x0 = (sample - x0).abs().mean().item()
print(f"  mean distance from xₜ to x₀:   {dist_xt_to_x0:.4f}")
print(f"  mean distance from xₜ₋₁ to x₀: {dist_prev_to_x0:.4f}")
print(f"  xₜ₋₁ is closer to x₀: {dist_prev_to_x0 < dist_xt_to_x0}")
```
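A stronger, exact check: if $x_t$ is drawn from $q(x_t \mid x_0)$ and then $x_{t-1}$ from the posterior, $x_{t-1}$ must follow the forward marginal $q(x_{t-1} \mid x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_{t-1}}\,x_0,\ (1-\bar{\alpha}_{t-1})I)$. Propagating the mean and variance through the posterior coefficients gives two identities, $\text{coef}_{x_0} + \text{coef}_{x_t}\sqrt{\bar{\alpha}_t} = \sqrt{\bar{\alpha}_{t-1}}$ and $\text{coef}_{x_t}^2(1-\bar{\alpha}_t) + \tilde{\beta}_t = 1 - \bar{\alpha}_{t-1}$. The sketch below (the helper name `posterior_consistency` is hypothetical) checks both for every $t$ of the linear schedule in float64.

```python
import torch


def posterior_consistency(T: int = 1000) -> tuple[float, float]:
    betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
    alphas = 1 - betas
    ab = torch.cumprod(alphas, 0)
    ab_prev = torch.cat([torch.ones(1, dtype=torch.float64), ab[:-1]])
    # Posterior coefficients, same formulas as compute_posterior
    coef_x0 = ab_prev.sqrt() * betas / (1 - ab)
    coef_xt = alphas.sqrt() * (1 - ab_prev) / (1 - ab)
    post_var = betas * (1 - ab_prev) / (1 - ab)
    # Push x_t ~ N(sqrt(ab) x0, 1 - ab) through the posterior
    implied_mean_coef = coef_x0 + coef_xt * ab.sqrt()
    implied_var = coef_xt ** 2 * (1 - ab) + post_var
    # Compare with the forward marginal q(x_{t-1} | x0)
    mean_err = (implied_mean_coef - ab_prev.sqrt()).abs().max().item()
    var_err = (implied_var - (1 - ab_prev)).abs().max().item()
    return mean_err, var_err


mean_err, var_err = posterior_consistency()
print(f"max mean-coefficient error: {mean_err:.2e}")
print(f"max variance error:         {var_err:.2e}")
```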
3. Training Objective: From the ELBO to the Simplified Loss
3.1 The Maximum-Likelihood Objective
A generative model aims to maximize the data log-likelihood $\log p_\theta(x_0)$. Because optimizing it directly is hard, DDPM optimizes the evidence lower bound (ELBO):
$$\log p_\theta(x_0) \geq \mathbb{E}_q\!\left[\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] := -\mathcal{L}_{\text{ELBO}}$$
Here $\mathcal{L}_{\text{ELBO}}$ denotes the negative ELBO, which training minimizes. Expanded, it decomposes into three kinds of terms:
$$\mathcal{L}_{\text{ELBO}} = \underbrace{\mathcal{L}_T}_{\text{prior matching}} + \sum_{t=2}^{T}\underbrace{\mathcal{L}_{t-1}}_{\text{denoising matching}} + \underbrace{\mathcal{L}_0}_{\text{reconstruction}}$$
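```python
# The three components of the ELBO (intuition)
print("The three components of the ELBO:")
print()
print("① L_T (prior matching term)")
print("   KL( q(x_T|x_0) || p(x_T) )")
print("   = KL divergence between the distribution of x_T and N(0, I)")
print("   → no learnable parameters; usually ignored (small for a well-chosen schedule)")
print()
print("② L_{t-1} (denoising matching terms, t=2,...,T) ← the main training objective")
print("   KL( q(xₜ₋₁|xₜ,x₀) || p_θ(xₜ₋₁|xₜ) )")
print("   = KL divergence between the learned reverse step and the true posterior")
print("   → the core objective the network learns!")
print()
print("③ L_0 (reconstruction term)")
print("   -log p_θ(x₀|x₁)")
print("   = negative log-likelihood of reconstructing the original image from x₁")
print("   → usually handled with a discretized Gaussian decoder")
```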
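The claim that $\mathcal{L}_T$ is small for a well-chosen schedule can be quantified. In one dimension, with $q(x_T \mid x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_T}\,x_0,\ 1-\bar{\alpha}_T)$ and prior $\mathcal{N}(0, 1)$, the Gaussian KL is $\tfrac{1}{2}\left(\sigma^2 + \mu^2 - 1 - \log\sigma^2\right)$. The sketch below (hypothetical helper `prior_matching_kl`) evaluates it for the linear schedule with 1000 steps and, for contrast, with only 100 steps over the same $\beta$ range.

```python
import math
import torch


def prior_matching_kl(x0_value: float, T: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> float:
    """Per-dimension L_T = KL(q(x_T | x_0) || N(0, 1)) for a linear beta schedule."""
    betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float64)
    ab_T = torch.cumprod(1 - betas, 0)[-1].item()
    mean, var = math.sqrt(ab_T) * x0_value, 1 - ab_T
    return 0.5 * (var + mean ** 2 - 1 - math.log(var))


kl_1000 = prior_matching_kl(1.0, T=1000)
kl_100 = prior_matching_kl(1.0, T=100)
print(f"L_T per dimension, T=1000: {kl_1000:.2e}")
print(f"L_T per dimension, T=100:  {kl_100:.4f}")
```

The 100-step variant leaves $\bar{\alpha}_T$ far from zero, so $x_T$ still carries signal and starting the sampler from pure noise no longer matches the end of the forward process.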
3.2 Deriving the Denoising Matching Term
The core denoising matching term $\mathcal{L}_{t-1}$ is a KL divergence between two Gaussians. With the model variance fixed to $\Sigma_\theta = \sigma_t^2 I$, it reduces to a squared distance between means:
$$\mathcal{L}_{t-1} = \mathbb{E}_q\!\left[\frac{1}{2\sigma_t^2}\left\|\tilde{\mu}_t(x_t, x_0) - \mu_\theta(x_t, t)\right\|^2\right] + C$$
where $C$ does not depend on $\theta$.
Key step: eliminate $x_0$ from $\tilde{\mu}_t$.
From the reparameterized forward process $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$, we can solve for:
$$x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}}$$
Substituting into the posterior mean $\tilde{\mu}_t$:
$$\tilde{\mu}_t(x_t, \epsilon) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right)$$
Therefore, if the network predicts $\epsilon_\theta(x_t, t)$, the corresponding mean parameterization is:
$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right)$$
Substituting both means into $\mathcal{L}_{t-1}$, the $x_t$ terms cancel and only the noise error remains, scaled by a $t$-dependent weight:
$$\mathcal{L}_{t-1} - C = \mathbb{E}_{x_0, \epsilon}\!\left[\frac{\beta_t^2}{2\sigma_t^2\,\alpha_t\,(1-\bar{\alpha}_t)}\left\|\epsilon - \epsilon_\theta(x_t, t)\right\|^2\right]$$
Dropping this weight and sampling $t$ uniformly gives the simplified loss; Ho et al. (2020) found that this unweighted form gives better sample quality than optimizing the weighted bound directly:
$$\boxed{\mathcal{L}_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\!\left[\left\|\epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\; t\right)\right\|^2\right]}$$
```python
# Numerically checking the step from the KL divergence to the simplified MSE loss
def compute_elbo_loss_term(
    x0: torch.Tensor,
    xt: torch.Tensor,
    t: torch.Tensor,
    predicted_mean: torch.Tensor,  # the network's predicted μ_θ(xₜ, t)
    betas: torch.Tensor,
    alphas_cumprod: torch.Tensor,
) -> torch.Tensor:
    """
    The full L_{t-1} in KL-divergence form
    KL between two Gaussians: KL(q(xₜ₋₁|xₜ,x₀) || p_θ(xₜ₋₁|xₜ))
    = (1/2σ²) ||μ̃_t - μ_θ||²  (when both share the variance σ² = β̃ₜ)
    Returned as a mean over elements rather than a sum over dimensions.
    """
    true_mean, posterior_var = compute_posterior(x0, xt, t, betas, alphas_cumprod)
    # KL between two equal-variance Gaussians = squared error / (2 · variance)
    kl = 0.5 * ((predicted_mean - true_mean) ** 2) / posterior_var
    return kl.mean()


def verify_simplified_loss():
    """
    Check: the full KL loss and the simplified MSE loss point in the same optimization direction
    """
    T = 1000
    betas = torch.linspace(1e-4, 0.02, T)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, 0)
    sqrt_ab = alphas_cumprod.sqrt()
    sqrt_1mab = (1 - alphas_cumprod).sqrt()
    torch.manual_seed(0)
    x0 = torch.randn(8, 3, 16, 16)
    t = torch.randint(1, T, (8,))
    eps = torch.randn_like(x0)
    # Forward noising
    xt = sqrt_ab[t].reshape(-1, 1, 1, 1) * x0 + sqrt_1mab[t].reshape(-1, 1, 1, 1) * eps
    # Assume the network predicts epsilon "perfectly" (use the true eps as the prediction)
    eps_pred_perfect = eps
    # Recover the mean μ_θ from the predicted epsilon
    alpha_t = (1 - betas[t]).reshape(-1, 1, 1, 1)
    beta_t = betas[t].reshape(-1, 1, 1, 1)
    sqrt_1mab_t = sqrt_1mab[t].reshape(-1, 1, 1, 1)
    mu_theta = (xt - beta_t / sqrt_1mab_t * eps_pred_perfect) / alpha_t.sqrt()
    # Full KL loss
    kl_loss = compute_elbo_loss_term(x0, xt, t, mu_theta, betas, alphas_cumprod)
    # Simplified MSE loss
    mse_loss = F.mse_loss(eps_pred_perfect, eps)
    print("Full KL loss vs simplified MSE loss (perfect prediction):")
    print(f"  KL loss:  {kl_loss.item():.6f}")
    print(f"  MSE loss: {mse_loss.item():.6f}")
    print("  With a perfect prediction (eps_pred = eps), MSE → 0 and KL → 0 ✓")
    print()
    # A random (bad) prediction
    eps_pred_bad = torch.randn_like(eps)
    mu_theta_bad = (xt - beta_t / sqrt_1mab_t * eps_pred_bad) / alpha_t.sqrt()
    kl_loss_bad = compute_elbo_loss_term(x0, xt, t, mu_theta_bad, betas, alphas_cumprod)
    mse_loss_bad = F.mse_loss(eps_pred_bad, eps)
    print("  With a random (bad) prediction:")
    print(f"  KL loss:  {kl_loss_bad.item():.6f} (larger)")
    print(f"  MSE loss: {mse_loss_bad.item():.6f} (larger)")
    print("  → both move in the same direction; the KL is the ε-MSE times a t-dependent weight ✓")


verify_simplified_loss()
```
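The comparison above only shows that both losses move in the same direction. The exact relation is $\mathcal{L}_{t-1} - C = \lambda_t\,\lVert\epsilon - \epsilon_\theta\rVert^2$ with $\lambda_t = \frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)}$. The sketch below (hypothetical helper `check_kl_equals_weighted_mse`, assuming $\sigma_t^2 = \tilde{\beta}_t$ and using random tensors as a stand-in for the network output) computes the KL from the $x_0$-form posterior mean and compares it with $\lambda_t$ times the summed squared ε-error, in float64.

```python
import torch


def check_kl_equals_weighted_mse(T: int = 1000, batch: int = 16, dim: int = 32, seed: int = 0):
    g = torch.Generator().manual_seed(seed)
    betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
    alphas = 1 - betas
    ab = torch.cumprod(alphas, 0)
    ab_prev = torch.cat([torch.ones(1, dtype=torch.float64), ab[:-1]])
    sigma2 = betas * (1 - ab_prev) / (1 - ab)              # fixed variance σ_t² = β̃_t
    t = torch.randint(1, T, (batch,), generator=g)         # skip t = 0, where β̃_0 = 0
    x0 = torch.randn(batch, dim, generator=g, dtype=torch.float64)
    eps = torch.randn(batch, dim, generator=g, dtype=torch.float64)
    eps_pred = torch.randn(batch, dim, generator=g, dtype=torch.float64)  # stand-in for ε_θ

    def col(v):
        # Per-sample coefficient as a (batch, 1) column for broadcasting
        return v[t].unsqueeze(1)

    xt = col(ab).sqrt() * x0 + (1 - col(ab)).sqrt() * eps
    # Posterior mean in its x0 form, model mean in its ε form
    mu_tilde = (col(ab_prev).sqrt() * col(betas) * x0
                + col(alphas).sqrt() * (1 - col(ab_prev)) * xt) / (1 - col(ab))
    mu_theta = (xt - col(betas) / (1 - col(ab)).sqrt() * eps_pred) / col(alphas).sqrt()
    kl = ((mu_tilde - mu_theta) ** 2).sum(1) / (2 * sigma2[t])
    lam = betas[t] ** 2 / (2 * sigma2[t] * alphas[t] * (1 - ab[t]))
    weighted_mse = lam * ((eps - eps_pred) ** 2).sum(1)
    rel_err = ((kl - weighted_mse).abs() / weighted_mse).max().item()
    return rel_err, lam


rel_err, lam = check_kl_equals_weighted_mse()
print(f"max relative error between KL and λ_t · ||ε - ε_θ||²: {rel_err:.2e}")
```

Because $\lambda_t$ varies with $t$, $\mathcal{L}_{\text{simple}}$ is a reweighted bound rather than a rescaled one; Ho et al. (2020) note that it down-weights the small-$t$ terms relative to the true bound.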
4. Three Parameterizations
4.1 Predict $\epsilon$, $x_0$, or $v$
```python
# The network can predict different quantities; they convert into one another
def eps_to_x0(eps_pred, xt, t, sqrt_ab, sqrt_1mab):
    """Recover x₀ from a predicted ε"""
    sab = sqrt_ab[t].reshape(-1, 1, 1, 1)
    smab = sqrt_1mab[t].reshape(-1, 1, 1, 1)
    return (xt - smab * eps_pred) / sab


def x0_to_eps(x0_pred, xt, t, sqrt_ab, sqrt_1mab):
    """Recover ε from a predicted x₀"""
    sab = sqrt_ab[t].reshape(-1, 1, 1, 1)
    smab = sqrt_1mab[t].reshape(-1, 1, 1, 1)
    return (xt - sab * x0_pred) / smab


def compute_v_target(x0, eps, t, sqrt_ab, sqrt_1mab):
    """
    v-prediction (Salimans & Ho, 2022)
    v = √ᾱₜ · ε - √(1-ᾱₜ) · x₀
    At low noise (ᾱₜ → 1) v ≈ ε; at high noise (ᾱₜ → 0) v ≈ -x₀,
    so the target stays well scaled across all t (numerically more stable).
    Used by Imagen and Stable Diffusion v2.
    """
    sab = sqrt_ab[t].reshape(-1, 1, 1, 1)
    smab = sqrt_1mab[t].reshape(-1, 1, 1, 1)
    return sab * eps - smab * x0


def v_to_x0(v_pred, xt, t, sqrt_ab, sqrt_1mab):
    """Recover x₀ from a predicted v"""
    sab = sqrt_ab[t].reshape(-1, 1, 1, 1)
    smab = sqrt_1mab[t].reshape(-1, 1, 1, 1)
    return sab * xt - smab * v_pred


# Comparing the three parameterizations
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1 - betas, 0)
sqrt_ab = alphas_cumprod.sqrt()
sqrt_1mab = (1 - alphas_cumprod).sqrt()
torch.manual_seed(42)
x0 = torch.randn(4, 3, 8, 8)
eps = torch.randn_like(x0)
t = torch.tensor([1, 250, 500, 750])
xt = sqrt_ab[t].reshape(-1, 1, 1, 1) * x0 + sqrt_1mab[t].reshape(-1, 1, 1, 1) * eps
# Targets for each parameterization (ground truth, used as training labels)
v_target = compute_v_target(x0, eps, t, sqrt_ab, sqrt_1mab)
# Check that the conversions are mutually consistent
x0_from_eps = eps_to_x0(eps, xt, t, sqrt_ab, sqrt_1mab)
x0_from_v = v_to_x0(v_target, xt, t, sqrt_ab, sqrt_1mab)
print("Equivalence of the three parameterizations:")
print(f"  error recovering x₀ from ε: {(x0_from_eps - x0).abs().mean():.2e}")
print(f"  error recovering x₀ from v: {(x0_from_v - x0).abs().mean():.2e}")
print("  → the three are equivalent; errors are at floating-point precision ✓")
print()
# Numerical range of each target at different timesteps
print("Numerical range of each target at different timesteps:")
print(f"{'t':^6} {'|ε| (target)':^20} {'|x₀| (target)':^20} {'|v| (target)':^20}")
print("─" * 68)
for i, t_val in enumerate(t.tolist()):
    eps_norm = eps[i].abs().mean().item()
    x0_norm = x0[i].abs().mean().item()
    v_norm = v_target[i].abs().mean().item()
    print(f" {t_val:^6} {eps_norm:^20.4f} {x0_norm:^20.4f} {v_norm:^20.4f}")
print()
print("When to use each parameterization:")
comparisons = [
    ("Predict ε (epsilon)", "original DDPM, the default for most models; stable at high noise"),
    ("Predict x₀", "intuitive; the target is highly uncertain at high noise; Improved DDPM"),
    ("Predict v", "numerically stable at all timesteps; Imagen/SDv2/SVD; recommended for production"),
]
for name, desc in comparisons:
    print(f"  [{name}]: {desc}")
```
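Salimans & Ho motivate $v$ with an angle: writing $\sqrt{\bar{\alpha}_t} = \cos\phi$ and $\sqrt{1-\bar{\alpha}_t} = \sin\phi$ puts $x_t = \cos\phi\,x_0 + \sin\phi\,\epsilon$ on a circle, and $v = \cos\phi\,\epsilon - \sin\phi\,x_0$ is exactly $dx_t/d\phi$. The sketch below checks this with a central finite difference and also verifies the conversion $\epsilon = \sqrt{1-\bar{\alpha}_t}\,x_t + \sqrt{\bar{\alpha}_t}\,v$, which the code above does not include. The helper names and the angle $\phi = 0.7$ are illustrative.

```python
import math
import torch


def xt_at_angle(x0, eps, phi):
    # cos(φ) = sqrt(ᾱ_t), sin(φ) = sqrt(1 - ᾱ_t)
    return math.cos(phi) * x0 + math.sin(phi) * eps


def v_at_angle(x0, eps, phi):
    # v = sqrt(ᾱ_t) · ε - sqrt(1 - ᾱ_t) · x0
    return math.cos(phi) * eps - math.sin(phi) * x0


def v_to_eps(v, xt, phi):
    # ε = sqrt(1 - ᾱ_t) · x_t + sqrt(ᾱ_t) · v
    return math.sin(phi) * xt + math.cos(phi) * v


torch.manual_seed(0)
x0 = torch.randn(3, 8, 8, dtype=torch.float64)
eps = torch.randn_like(x0)
phi, h = 0.7, 1e-6
# Central difference approximation of d x_t / d φ
fd = (xt_at_angle(x0, eps, phi + h) - xt_at_angle(x0, eps, phi - h)) / (2 * h)
v = v_at_angle(x0, eps, phi)
deriv_err = (fd - v).abs().max().item()
eps_err = (v_to_eps(v, xt_at_angle(x0, eps, phi), phi) - eps).abs().max().item()
print(f"|dx_t/dφ - v| max: {deriv_err:.2e}")
print(f"|ε recovered from v - ε| max: {eps_err:.2e}")
```

Since $\lVert x_t\rVert$ and $\lVert v\rVert$ stay on the same scale for every $\phi$, the $v$ target never blows up at either end of the schedule, which is the intuition behind "numerically stable at all timesteps".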
5. Sampling Formula of the Reverse Process
5.1 DDPM Sampling (Stochastic, 1000 Steps)
```python
def ddpm_sample_step(
    xt: torch.Tensor,
    t: int,
    eps_pred: torch.Tensor,  # the network's predicted noise ε_θ(xₜ, t)
    betas: torch.Tensor,
    alphas_cumprod: torch.Tensor,
) -> torch.Tensor:
    """
    One DDPM reverse step: from xₜ to xₜ₋₁
    xₜ₋₁ = (1/√αₜ) · (xₜ - βₜ/√(1-ᾱₜ) · ε_θ) + σₜ · z
    Note: t is 0-indexed here (T-1, ..., 0); no random noise is added at the final step t = 0
    """
    alpha_t = 1.0 - betas[t].item()
    beta_t = betas[t].item()
    ab_t = alphas_cumprod[t].item()
    # Variance σₜ²: the posterior variance β̃ₜ is used here (βₜ itself is also a valid choice)
    if t > 0:
        ab_t_prev = alphas_cumprod[t - 1].item()
        # Posterior variance (Improved DDPM suggests learning an interpolation coefficient instead)
        posterior_var = beta_t * (1 - ab_t_prev) / (1 - ab_t)
        sigma_t = posterior_var ** 0.5
    else:
        sigma_t = 0.0
    # Denoised mean: μ_θ(xₜ, t)
    mu = (xt - beta_t / (1 - ab_t) ** 0.5 * eps_pred) / alpha_t ** 0.5
    # Add random noise (stochastic sampling keeps samples diverse)
    if sigma_t > 0:
        z = torch.randn_like(xt)
        xt_prev = mu + sigma_t * z
    else:
        xt_prev = mu  # no noise at t = 0 (the final step)
    return xt_prev


def ddpm_full_sample(
    denoiser,                          # denoising network ε_θ(xₜ, t)
    shape: tuple,                      # shape of the generated images
    T: int = 1000,
    betas: torch.Tensor | None = None,
    device: str = "cpu",
    verbose: bool = False,
) -> torch.Tensor:
    """
    Full DDPM sampling (Algorithm 2 in Ho et al., 2020)
    Start from pure Gaussian noise and denoise for T steps
    """
    if betas is None:
        betas = torch.linspace(1e-4, 0.02, T).to(device)
    alphas_cumprod = torch.cumprod(1 - betas, 0)
    # Start from pure noise: x_T ~ N(0, I)
    xt = torch.randn(*shape, device=device)
    denoiser.eval()
    with torch.no_grad():
        for t in reversed(range(T)):  # T-1, T-2, ..., 1, 0
            # Timestep tensor (network input)
            t_tensor = torch.full((shape[0],), t, dtype=torch.long, device=device)
            # Predict the noise
            eps_pred = denoiser(xt, t_tensor)
            # One denoising step
            xt = ddpm_sample_step(xt, t, eps_pred, betas, alphas_cumprod)
            if verbose and t % 100 == 0:
                print(f"  t={t:4d}: mean={xt.mean().item():.4f}, std={xt.std().item():.4f}")
    return xt


# Per-step parameters of the sampling schedule (no network involved; framework demo only)
print("DDPM sampling schedule statistics (framework demo, no network):")
print()
print("Key per-step parameters:")
T = 10  # 10 steps for the demo
betas = torch.linspace(0.01, 0.5, T)
alphas_cumprod = torch.cumprod(1 - betas, 0)
print(f"{'t':^6} {'αₜ':^10} {'βₜ':^10} {'posterior var σ²':^16} {'noise scale σ':^14}")
print("─" * 60)
for t in range(T - 1, -1, -1):
    alpha_t = (1 - betas[t]).item()
    beta_t = betas[t].item()
    ab_t = alphas_cumprod[t].item()
    if t > 0:
        ab_prev = alphas_cumprod[t - 1].item()
        pvar = beta_t * (1 - ab_prev) / (1 - ab_t)
        sigma = pvar ** 0.5
    else:
        pvar, sigma = 0.0, 0.0
    print(f" {t:^6} {alpha_t:^10.4f} {beta_t:^10.4f} {pvar:^16.6f} {sigma:^14.6f}")
```
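The sampling formula can be verified end to end without training by using an oracle. For a degenerate dataset containing a single point $c$, the optimal noise prediction is known exactly, $\hat{\epsilon} = (x_t - \sqrt{\bar{\alpha}_t}\,c)/\sqrt{1-\bar{\alpha}_t}$, so every DDPM step samples from the true posterior and the chain should end exactly at $c$. This toy setup and the helper name `ddpm_sample_with_oracle` are illustrative assumptions, not part of DDPM itself.

```python
import torch


def ddpm_sample_with_oracle(c: float = 0.5, n: int = 1000, T: int = 1000, seed: int = 0) -> torch.Tensor:
    torch.manual_seed(seed)
    betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
    alphas = 1 - betas
    ab = torch.cumprod(alphas, 0)
    ab_prev = torch.cat([torch.ones(1, dtype=torch.float64), ab[:-1]])
    x = torch.randn(n, dtype=torch.float64)                 # x_T ~ N(0, I)
    for t in reversed(range(T)):
        # Oracle ε for a dataset that contains only the point c
        eps_hat = (x - ab[t].sqrt() * c) / (1 - ab[t]).sqrt()
        # Same mean formula as ddpm_sample_step
        mean = (x - betas[t] / (1 - ab[t]).sqrt() * eps_hat) / alphas[t].sqrt()
        if t > 0:
            sigma = (betas[t] * (1 - ab_prev[t]) / (1 - ab[t])).sqrt()
            x = mean + sigma * torch.randn_like(x)
        else:
            x = mean                                          # final step: no noise
    return x


samples = ddpm_sample_with_oracle(c=0.5)
print(f"sample mean={samples.mean().item():.6f}, max |x - c|={(samples - 0.5).abs().max().item():.2e}")
```

With a learned $\epsilon_\theta$ in place of the oracle the samples spread over the data distribution instead of collapsing to one point, but the update rule is identical.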
5.2 Choosing the Variance: Fixed vs Learned
```python
# DDPM has two natural choices for the reverse-process variance:
# β̃ₜ (the posterior variance, a lower bound) and βₜ (an upper bound)
# The original DDPM fixes σₜ² to one of them
# Improved DDPM (Nichol & Dhariwal 2021) learns an interpolation coefficient between them
def learned_variance_parameterization():
    """
    Improved DDPM: the network additionally outputs an interpolation coefficient v ∈ [0, 1]
    Σ_θ(xₜ, t) = exp(v · log βₜ + (1-v) · log β̃ₜ)
    """
    T = 1000
    betas = torch.linspace(1e-4, 0.02, T)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, 0)
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
    # Upper bound (βₜ) and lower bound (β̃ₜ)
    beta_upper = betas
    beta_lower = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
    beta_lower = beta_lower.clamp(min=1e-20)
    # How the two fixed variances differ across timesteps
    print("The two fixed-variance choices compared (log scale):")
    print(f"{'timestep t':^10} {'log β̃ₜ (lower)':^18} {'log βₜ (upper)':^18} {'gap':^10}")
    print("─" * 58)
    for t in [1, 100, 250, 500, 750, 999]:
        lb = math.log(beta_lower[t].item())
        ub = math.log(beta_upper[t].item())
        print(f" {t:^10} {lb:^18.4f} {ub:^18.4f} {ub-lb:^10.4f}")
    print()
    print("→ at small t the gap is large (learning the variance is most valuable here)")
    print("→ at large t the two nearly coincide (close to pure noise; the choice matters little)")
    print("→ Improved DDPM learns the interpolation coefficient so the network can choose adaptively")


learned_variance_parameterization()
```
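The interpolation in the docstring can be sketched directly. Assuming the network already emits a coefficient `frac` in $[0, 1]$ (how a raw network output is mapped into that range is an implementation choice not covered here), the variance is a geometric interpolation between $\tilde{\beta}_t$ and $\beta_t$. The helper name `interpolated_variance` is hypothetical.

```python
import torch


def interpolated_variance(frac: torch.Tensor, t: torch.Tensor, T: int = 1000):
    """Σ_θ = exp(frac · log βₜ + (1 - frac) · log β̃ₜ) for a linear schedule (sketch)."""
    betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
    ab = torch.cumprod(1 - betas, 0)
    ab_prev = torch.cat([torch.ones(1, dtype=torch.float64), ab[:-1]])
    beta_tilde = (betas * (1 - ab_prev) / (1 - ab)).clamp(min=1e-20)
    # Interpolate in log space, then exponentiate back to a variance
    log_var = frac * betas[t].log() + (1 - frac) * beta_tilde[t].log()
    return log_var.exp(), beta_tilde[t], betas[t]


t = torch.tensor([1, 10, 500])
var_lo, beta_tilde_t, beta_t = interpolated_variance(torch.zeros(3, dtype=torch.float64), t)
var_hi, _, _ = interpolated_variance(torch.ones(3, dtype=torch.float64), t)
var_mid, _, _ = interpolated_variance(torch.full((3,), 0.5, dtype=torch.float64), t)
print("frac=0:", var_lo.tolist())
print("frac=1:", var_hi.tolist())
```

With `frac = 0.5` the result is the geometric mean $\sqrt{\tilde{\beta}_t\beta_t}$, which always lies between the two bounds because $\tilde{\beta}_t \le \beta_t$.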
6. Complete Training Loop Code
```python
class DDPMTrainer:
    """
    Complete DDPM trainer
    Bundles the forward process, the loss computation, and the optimization step
    """

    def __init__(
        self,
        denoiser: nn.Module,
        T: int = 1000,
        schedule_type: str = "linear",
        predict_target: str = "epsilon",  # "epsilon" | "x0" | "v"
        loss_type: str = "simple",        # "simple" | "weighted"
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        device: str = "cpu",
    ):
        self.denoiser = denoiser.to(device)
        self.T = T
        self.predict_target = predict_target
        self.loss_type = loss_type
        self.device = device
        # ── Noise schedule ──────────────────────────────────────────
        if schedule_type == "linear":
            betas = torch.linspace(beta_start, beta_end, T)
        elif schedule_type == "cosine":
            # Cosine schedule (Nichol & Dhariwal 2021)
            s = 0.008
            t_arr = torch.linspace(0, T, T + 1)
            f = torch.cos((t_arr / T + s) / (1 + s) * math.pi / 2) ** 2
            ab = f / f[0]
            betas = (1 - ab[1:] / ab[:-1]).clamp(1e-4, 0.9999)
        else:
            raise ValueError(f"Unknown schedule: {schedule_type}")
        self.register_buffers(betas, device)

    def register_buffers(self, betas: torch.Tensor, device: str):
        """Precompute every schedule statistic we need"""
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, 0)
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
        self.betas = betas.to(device)
        self.alphas = alphas.to(device)
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.alphas_cumprod_prev = alphas_cumprod_prev.to(device)
        self.sqrt_ab = alphas_cumprod.sqrt().to(device)
        self.sqrt_1mab = (1 - alphas_cumprod).sqrt().to(device)
        self.sqrt_recip_a = (1 / alphas).sqrt().to(device)
        self.posterior_variance = (
            betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        ).clamp(min=1e-20).to(device)

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, noise=None):
        """Forward noising: x₀ → xₜ (returns xₜ and the noise used)"""
        if noise is None:
            noise = torch.randn_like(x0)
        sab = self.sqrt_ab[t].reshape(-1, 1, 1, 1)
        smab = self.sqrt_1mab[t].reshape(-1, 1, 1, 1)
        return sab * x0 + smab * noise, noise

    def get_target(self, x0: torch.Tensor, noise: torch.Tensor,
                   t: torch.Tensor) -> torch.Tensor:
        """Training target for the chosen parameterization"""
        if self.predict_target == "epsilon":
            return noise
        elif self.predict_target == "x0":
            return x0
        elif self.predict_target == "v":
            sab = self.sqrt_ab[t].reshape(-1, 1, 1, 1)
            smab = self.sqrt_1mab[t].reshape(-1, 1, 1, 1)
            return sab * noise - smab * x0
        else:
            raise ValueError(f"Unknown predict_target: {self.predict_target}")

    def compute_loss(
        self,
        x0: torch.Tensor,
        noise: torch.Tensor | None = None,
    ) -> dict:
        """
        Compute the training loss
        Returns a dict with the loss and some debugging info
        """
        B = x0.shape[0]
        # Random timesteps
        t = torch.randint(0, self.T, (B,), device=self.device)
        # Forward noising
        if noise is None:
            noise = torch.randn_like(x0)
        xt, noise = self.q_sample(x0, t, noise)
        # Network prediction
        model_output = self.denoiser(xt, t)
        # Target
        target = self.get_target(x0, noise, t)
        if self.loss_type == "simple":
            # Simplified MSE loss
            loss = F.mse_loss(model_output, target)
        elif self.loss_type == "weighted":
            # SNR-weighted MSE: weight = min(SNR, 5) / SNR (min-SNR style, intended for ε-prediction)
            snr = self.alphas_cumprod[t] / (1 - self.alphas_cumprod[t])
            weights = torch.minimum(snr, torch.full_like(snr, 5.0)) / snr
            weights = weights.reshape(-1, 1, 1, 1)
            loss = (weights * (model_output - target) ** 2).mean()
        else:
            raise ValueError(f"Unknown loss_type: {self.loss_type}")
        return {
            "loss": loss,
            "t": t.float().mean().item(),  # mean timestep (for debugging)
            "pred_norm": model_output.norm().item(),
            "target_norm": target.norm().item(),
        }

    def train_step(
        self,
        x0: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ) -> dict:
        """Run one training step"""
        self.denoiser.train()
        optimizer.zero_grad()
        metrics = self.compute_loss(x0)
        metrics["loss"].backward()
        # Gradient clipping (stabilizes training)
        torch.nn.utils.clip_grad_norm_(self.denoiser.parameters(), max_norm=1.0)
        optimizer.step()
        return {k: v.item() if isinstance(v, torch.Tensor) else v
                for k, v in metrics.items()}


# ── Test the trainer ──────────────────────────────────────────────────
# A minimal denoiser (only for exercising the framework; real models use a U-Net, see Part 4)
class TinyDenoiser(nn.Module):
    """A minimal denoising network (for testing)"""

    def __init__(self, channels: int = 3, hidden: int = 64):
        super().__init__()
        # Timestep embedding
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )
        # Image path (minimal version; in practice a U-Net)
        self.conv_in = nn.Conv2d(channels, hidden, 3, padding=1)
        self.net = nn.Sequential(
            nn.GroupNorm(8, hidden),
            nn.SiLU(),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.GroupNorm(8, hidden),
            nn.SiLU(),
            nn.Conv2d(hidden, channels, 3, padding=1),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Timestep embedding
        t_emb = self.time_embed(t.float().unsqueeze(1) / 1000)  # (B, hidden)
        # Simplification: add the time embedding to every spatial position of the
        # feature map (real models inject it inside each ResBlock)
        h = self.conv_in(x) + t_emb[:, :, None, None]
        return self.net(h)


# Initialize and test
denoiser = TinyDenoiser(channels=3, hidden=32)
trainer = DDPMTrainer(denoiser, T=1000, predict_target="epsilon")
optimizer = torch.optim.Adam(denoiser.parameters(), lr=2e-4)
# A simulated batch of training data (random tensors; use real images in practice)
x0_batch = torch.randn(8, 3, 32, 32)  # 8 images of 32×32
print("DDPM trainer test:")
print()
for step in range(5):
    metrics = trainer.train_step(x0_batch, optimizer)
    print(f"  step {step+1}: loss={metrics['loss']:.4f}, "
          f"mean t={metrics['t']:.1f}, "
          f"pred norm={metrics['pred_norm']:.4f}")
```
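The `weighted` branch of `compute_loss` uses $w_t = \min(\text{SNR}_t, \gamma)/\text{SNR}_t$ with $\text{SNR}_t = \bar{\alpha}_t/(1-\bar{\alpha}_t)$ and $\gamma = 5$, which matches the Min-SNR-$\gamma$ weighting for ε-prediction (Hang et al., 2023). The sketch below (hypothetical helper `min_snr_weights`) tabulates it for the linear schedule: noisy steps keep weight 1, while nearly clean steps, whose SNR is large, are down-weighted.

```python
import torch


def min_snr_weights(T: int = 1000, gamma: float = 5.0) -> tuple[torch.Tensor, torch.Tensor]:
    betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
    ab = torch.cumprod(1 - betas, 0)
    snr = ab / (1 - ab)                          # signal-to-noise ratio per timestep
    weights = torch.clamp(snr, max=gamma) / snr  # = min(SNR, γ) / SNR
    return snr, weights


snr, w = min_snr_weights()
for t in [0, 10, 100, 500, 999]:
    print(f"t={t:4d}  SNR={snr[t].item():12.4f}  weight={w[t].item():.4f}")
```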
7. Complete Sampling (Inference) Code
```python
class DDPMSampler:
    """
    Complete DDPM sampler
    Supports the ε / x₀ / v parameterizations and logging of intermediate frames
    """

    def __init__(
        self,
        denoiser: nn.Module,
        T: int = 1000,
        predict_target: str = "epsilon",
        device: str = "cpu",
    ):
        self.denoiser = denoiser.to(device)
        self.T = T
        self.predict_target = predict_target
        self.device = device
        # Noise schedule (must match the one used in training!)
        betas = torch.linspace(1e-4, 0.02, T).to(device)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, 0)
        alphas_cumprod_prev = torch.cat([
            torch.ones(1, device=device),
            alphas_cumprod[:-1],
        ])
        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev
        self.sqrt_ab = alphas_cumprod.sqrt()
        self.sqrt_1mab = (1 - alphas_cumprod).sqrt()
        self.posterior_variance = (
            betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)
        ).clamp(min=1e-20)

    def predict_x0_from_output(
        self,
        model_output: torch.Tensor,
        xt: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """Recover x̂₀ (the estimated clean image) from the network output"""
        sab = self.sqrt_ab[t].reshape(-1, 1, 1, 1)
        smab = self.sqrt_1mab[t].reshape(-1, 1, 1, 1)
        if self.predict_target == "epsilon":
            # x̂₀ = (xₜ - √(1-ᾱₜ) · ε̂) / √ᾱₜ
            x0_hat = (xt - smab * model_output) / sab
        elif self.predict_target == "x0":
            x0_hat = model_output
        elif self.predict_target == "v":
            # x̂₀ = √ᾱₜ · xₜ - √(1-ᾱₜ) · v̂
            x0_hat = sab * xt - smab * model_output
        else:
            raise ValueError(f"Unknown predict_target: {self.predict_target}")
        # Clip to the data range (images are usually normalized to [-1, 1]) to avoid extreme values
        return x0_hat.clamp(-1, 1)

    def p_mean_variance(
        self,
        xt: torch.Tensor,
        t: torch.Tensor,
    ) -> tuple:
        """Compute the mean and variance of p_θ(xₜ₋₁|xₜ)"""
        model_output = self.denoiser(xt, t)
        x0_hat = self.predict_x0_from_output(model_output, xt, t)
        # Posterior mean, with the estimate x̂₀ in place of the true x₀
        ab_prev = self.alphas_cumprod_prev[t].reshape(-1, 1, 1, 1)
        ab = self.alphas_cumprod[t].reshape(-1, 1, 1, 1)
        b = self.betas[t].reshape(-1, 1, 1, 1)
        a = self.alphas[t].reshape(-1, 1, 1, 1)
        mean = (ab_prev.sqrt() * b / (1 - ab) * x0_hat
                + a.sqrt() * (1 - ab_prev) / (1 - ab) * xt)
        var = self.posterior_variance[t].reshape(-1, 1, 1, 1)
        log_var = var.log()
        return mean, var, log_var, x0_hat

    @torch.no_grad()
    def p_sample(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Single step: xₜ → xₜ₋₁"""
        mean, var, _, _ = self.p_mean_variance(xt, t)
        # Add random noise when t > 0, none at t = 0
        mask = (t > 0).float().reshape(-1, 1, 1, 1)
        noise = mask * torch.randn_like(xt)
        return mean + var.sqrt() * noise

    @torch.no_grad()
    def sample(
        self,
        shape: tuple,
        return_intermediates: bool = False,
        log_every: int = 100,
    ) -> tuple:
        """
        Full sampling: generate images starting from x_T ~ N(0, I)
        Returns (images in [0, 1], list of intermediate frames)
        """
        self.denoiser.eval()
        B = shape[0]
        device = self.device
        # Start from pure noise
        xt = torch.randn(*shape, device=device)
        intermediates = [xt.cpu()] if return_intermediates else []
        for t_val in reversed(range(self.T)):
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            xt = self.p_sample(xt, t)
            if return_intermediates and (t_val % log_every == 0):
                intermediates.append(xt.cpu())
        # Map from [-1, 1] back to [0, 1]
        image = (xt.clamp(-1, 1) + 1) / 2
        return image, intermediates

    @torch.no_grad()
    def sample_progress_stats(self, shape: tuple) -> None:
        """Print how statistics evolve during sampling (for visual debugging)"""
        self.denoiser.eval()
        B = shape[0]
        device = self.device
        xt = torch.randn(*shape, device=device)
        print("How statistics evolve during DDPM sampling:")
        print(f"{'timestep t':^10} {'xₜ mean':^12} {'xₜ var':^12} {'x̂₀ mean':^12} {'x̂₀ var':^12}")
        print("─" * 62)
        log_steps = list(range(self.T - 1, -1, -1))[::max(self.T // 10, 1)]
        for t_val in reversed(range(self.T)):
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            _, _, _, x0_hat = self.p_mean_variance(xt, t)
            xt = self.p_sample(xt, t)
            if t_val in log_steps:
                print(f" {t_val:^10} {xt.mean().item():^12.4f} "
                      f"{xt.var().item():^12.4f} "
                      f"{x0_hat.mean().item():^12.4f} "
                      f"{x0_hat.var().item():^12.4f}")


# ── Test the sampler ──────────────────────────────────────────────────
# 100 steps for the demo (the full schedule uses 1000). The untrained TinyDenoiser only
# exercises the code path; in practice T and the schedule must match training.
sampler = DDPMSampler(denoiser, T=100)
sampler.sample_progress_stats(shape=(2, 3, 8, 8))
print()
# Full sampling (demo)
image, intermediates = sampler.sample(
    shape=(4, 3, 8, 8),
    return_intermediates=True,
    log_every=20,
)
print("\nSampling finished:")
print(f"  generated image shape: {image.shape}")
print(f"  pixel range: [{image.min():.4f}, {image.max():.4f}] (should lie in [0, 1])")
print(f"  number of intermediate frames: {len(intermediates)}")
print()
print("Full pipeline summary:")
print("  Training:  random t → add noise → predict noise → MSE loss → backprop")
print("  Inference: x_T ~ N(0,I) → 1000 denoising steps → generated image x̂₀")
```
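`DDPMSampler` computes the mean through $\hat{x}_0$ and the posterior formula, while `ddpm_sample_step` in Section 5 uses the direct ε form. Without the clamp to $[-1, 1]$ the two routes are algebraically identical; the sketch below (hypothetical helper `max_mean_gap`, with random tensors standing in for $x_t$ and $\hat{\epsilon}$) confirms this for every $t \geq 1$ in float64. With the clamp, the routes differ whenever $\hat{x}_0$ leaves $[-1, 1]$, which typically happens at high noise levels, where dividing by a small $\sqrt{\bar{\alpha}_t}$ inflates $\hat{x}_0$.

```python
import torch


def max_mean_gap(T: int = 1000, seed: int = 0) -> float:
    torch.manual_seed(seed)
    betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
    alphas = 1 - betas
    ab = torch.cumprod(alphas, 0)
    ab_prev = torch.cat([torch.ones(1, dtype=torch.float64), ab[:-1]])
    t = torch.arange(1, T)                                 # every step except t = 0
    xt = torch.randn(T - 1, dtype=torch.float64)
    eps_hat = torch.randn(T - 1, dtype=torch.float64)      # stand-in for the network output
    # Route 1 (Section 5): mean directly from ε̂
    mean_eps = (xt - betas[t] / (1 - ab[t]).sqrt() * eps_hat) / alphas[t].sqrt()
    # Route 2 (DDPMSampler without the clamp): x̂₀ from ε̂, then the posterior mean
    x0_hat = (xt - (1 - ab[t]).sqrt() * eps_hat) / ab[t].sqrt()
    mean_x0 = (ab_prev[t].sqrt() * betas[t] * x0_hat
               + alphas[t].sqrt() * (1 - ab_prev[t]) * xt) / (1 - ab[t])
    return (mean_eps - mean_x0).abs().max().item()


print(f"max |mean_eps - mean_x0| over all t: {max_mean_gap():.2e}")
```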
Summary
| Concept | Formula | Meaning |
|---|---|---|
| Reverse process | $p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(\mu_\theta, \Sigma_\theta)$ | A neural network approximates the true reverse distribution |
| Posterior | $q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(\tilde{\mu}_t, \tilde{\beta}_t I)$ | Has a closed form; it is the target the network learns |
| Training objective | $\mathcal{L} = \lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2$ | MSE obtained by simplifying the ELBO |
| Mean parameterization | $\mu_\theta = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta\right)$ | Predicting $\epsilon$ is equivalent to predicting the mean |
| Sampling formula | $x_{t-1} = \mu_\theta + \sigma_t z$ | Denoise step by step; the added randomness keeps samples diverse |
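Choosing among the three parameterizations:
- Predict $\epsilon$: original DDPM, the default for most models
- Predict $x_0$: intuitive, but the target is highly uncertain (high variance) at high noise levels
- Predict $v$: numerically stable across all timesteps; the choice of Stable Diffusion v2 / Imagen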
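Next up: U-Net architecture, the design of the denoising network. Why a U-Net? What do skip connections do, how is the timestep $t$ injected into the network, where do the self-attention modules go, and how is a Residual Block implemented in detail?
💬 Which step of the derivation from the ELBO to the simplified MSE did you find most crucial? Share it in the comments!
🙏 If this post helped you, like and bookmark it; the series continues!
This is an original technical post. The code was verified under Python 3.12 + PyTorch 2.x. Last updated: 2026-05-20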