【DDPM】——扩散模型启程
现在讲DDPM的教程太多了,这里就不重点说了,作为一个回顾,先过一遍。
文章目录
1. 前向过程 (Forward Process / Diffusion Process)
前向过程是一个固定的马尔可夫链,它通过逐步向原始数据 x 0 x_0 x0 添加高斯噪声,最终将其破坏为一个纯高斯噪声 x T x_T xT。
-
单步状态转移:在每一步 t t t,我们添加由方差表 (Variance Schedule) β t \beta_t βt 控制的微小噪声:
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t I) q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)
(注: β 1 < β 2 < ⋯ < β T \beta_1 < \beta_2 < \dots < \beta_T β1<β2<⋯<βT,通常通过线性或余弦调度来设定。)
-
任意步状态的直接采样 (重参数化技巧):
为了避免在训练时一步步迭代,DDPM 利用独立同分布高斯变量相加的性质,实现了从 x 0 x_0 x0 直接采样出任意时刻的 x t x_t xt。
令 α t = 1 − β t \alpha_t = 1 - \beta_t αt=1−βt, α ˉ t = ∏ i = 1 t α i \bar{\alpha}_t = \prod_{i=1}^t \alpha_i αˉt=∏i=1tαi,则有:
q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}x_0, (1-\bar{\alpha}_t)I) q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)
写成等式即为:
x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon xt=αˉtx0+1−αˉtϵ
其中 ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵ∼N(0,I)。这个公式是后续模型训练的基石。
import torch
# 假设总步数 T 为 1000
T = 1000
beta_start = 1e-4
beta_end = 0.02
# 1. 定义 betas (\beta_t)
betas = torch.linspace(beta_start, beta_end, T)
# 2. 定义 alphas (\alpha_t = 1 - \beta_t)
alphas = 1.0 - betas
# 3. 定义 alphas_cumprod (\bar{\alpha}_t)
alphas_cumprod = torch.cumprod(alphas, dim=0)
# 辅助函数:由于输入是 Batch,我们需要把一维的参数广播到图片的维度 (B, C, H, W)
def extract(a, t, x_shape):
batch_size = t.shape[0]
out = a.gather(-1, t.cpu()) # 取出对应时间步的参数
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
def q_sample(x_0, t, noise=None):
"""
x_0: 真实图片数据, shape (B, C, H, W)
t: 随机采样的时间步, shape (B,)
noise: 采样的标准高斯噪声 \epsilon
"""
if noise is None:
noise = torch.randn_like(x_0)
# 提取 \sqrt{\bar{\alpha}_t}
sqrt_alphas_cumprod_t = extract(torch.sqrt(alphas_cumprod), t, x_0.shape)
# 提取 \sqrt{1 - \bar{\alpha}_t}
sqrt_one_minus_alphas_cumprod_t = extract(
torch.sqrt(1. - alphas_cumprod), t, x_0.shape
)
# 核心等式:一步到位生成 x_t
x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
return x_t
2. 模型与训练
抛开一切推导不谈的话,其实看最后的结论,非常简单
前向过程加了什么噪声 ϵ \epsilon ϵ,网络 ϵ θ \epsilon_\theta ϵθ 就去预测什么噪声
整个去噪的模型就是一个U-net,现在看来有些过时了,以后单独开一页来写,这里跳过算了。
import torch.nn.functional as F
def p_losses(denoise_model, x_0, t, noise=None):
if noise is None:
noise = torch.randn_like(x_0) # 真实的 \epsilon,这也是我们的 Ground Truth!
# 1. 制造 x_t
x_t = q_sample(x_0=x_0, t=t, noise=noise)
# 2. 扔给 U-Net 去预测噪声
# U-Net 必须同时接收噪声图 x_t 和时间步 t (以利用 Time Embedding)
predicted_noise = denoise_model(x_t, t)
# 3. 计算 MSE 损失
loss = F.mse_loss(noise, predicted_noise)
return loss
训练
from torch.optim import Adam
device = "cuda" if torch.cuda.is_available() else "cpu"
# 假设你的 U-Net 已经搭好,并且支持接受时间步 t 作为输入
model = UNet(...).to(device)
optimizer = Adam(model.parameters(), lr=1e-4)
epochs = 100
for epoch in range(epochs):
for batch in dataloader:
# dataloader 返回的 batch 就是干净的真实图片 x_0
x_0 = batch.to(device)
batch_size = x_0.shape[0]
optimizer.zero_grad()
# 核心:为 Batch 中的每一张图片独立随机采样一个时间步 t ~ Uniform(0, T-1)
t = torch.randint(0, T, (batch_size,), device=device).long()
# 计算 Loss
loss = p_losses(model, x_0, t)
# 反向传播和更新
loss.backward()
optimizer.step()
3. 逆向过程 (Reverse Process / Denoising Process)
逆向过程的目标是从标准高斯分布 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xT∼N(0,I) 开始,逐步去噪还原出真实数据 x 0 x_0 x0。
-
逆向转移概率:由于真实的逆向分布 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt−1∣xt) 无法直接计算,DDPM 使用一个神经网络 p θ p_\theta pθ 来近似它。当 β t \beta_t βt 足够小时,逆向过程也可以近似为高斯分布:
p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))
-
后验分布的解析解:虽然 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt−1∣xt) 不可解,但如果我们知道 x 0 x_0 x0,其条件后验分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) q(xt−1∣xt,x0) 是可以通过贝叶斯定理求出解析解的:
q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ t ( x t , x 0 ) , β ~ t I ) q(x_{t-1}|x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I) q(xt−1∣xt,x0)=N(xt−1;μ~t(xt,x0),β~tI)
其中,均值 μ ~ t ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \tilde{\mu}t(x_t, x_0) = \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t + \frac{\sqrt{\bar{\alpha}{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 μ~t(xt,x0)=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1βtx0。
反向采样的目标是从 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xT∼N(0,I) 开始,逐步推导 x t − 1 x_{t-1} xt−1,直到 x 0 x_0 x0。在这个过程中,我们的 U-Net ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t) 会预测出当前图像中的噪声,然后我们用数学公式把它“减”掉。
根据 Ho et al. 的论文,每一步的具体推导公式为:
x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z xt−1=αt1(xt−1−αˉt1−αtϵθ(xt,t))+σtz
其中:
- z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) z∼N(0,I)。注意:除了最后一步 t=1 时 z=0 外,其他时候我们都要加回这个随机噪声。
- σ t \sigma_t σt 是反向过程的方差,论文中通常取 σ t = β t \sigma_t = \sqrt{\beta_t} σt=βt 或者真实后验方差 β ~ t \tilde{\beta}_t β~t。
我们将这个过程拆分为两个函数:一个是执行单步去噪的 p_sample,另一个是控制整个循环的 p_sample_loop。这里会用到我们之前写的 extract 函数来对齐维度。
import torch
# 假设我们在外部已经定义好了 T, betas, alphas, alphas_cumprod 等全局常量
# 为了公式计算方便,我们需要预计算几个额外的系数:
sqrt_recip_alphas = torch.sqrt(1.0 / alphas) # 公式里的 1 / \sqrt{\alpha_t}
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) # 公式里的 \sqrt{1 - \bar{\alpha}_t}
@torch.no_grad()
def p_sample(model, x_t, t, t_index):
"""
单步去噪:从 x_t 推导 x_{t-1}
t_index 是一个整数,用来判断是否到了最后一步 (t=0)
"""
# 1. 提取当前时间步 t 对应的标量系数,并撑开维度以便广播计算
betas_t = extract(betas, t, x_t.shape)
sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_t.shape)
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x_t.shape)
alphas_t = extract(alphas, t, x_t.shape)
# 2. 调用 U-Net 预测当前图片 x_t 中蕴含的噪声 \epsilon
pred_noise = model(x_t, t)
# 3. 计算公式中的均值部分 (Model Mean)
# \frac{1}{\sqrt{\alpha_t}} * (x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} * pred_noise)
model_mean = sqrt_recip_alphas_t * (x_t - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t)
# 4. 如果是最后一步 (t==0),就不加噪声了,直接返回均值;否则加上方差 \sigma_t * z
if t_index == 0:
return model_mean
else:
# 采样随机噪声 z ~ N(0, I)
z = torch.randn_like(x_t)
# 这里的方差通常用 \beta_t
posterior_variance_t = extract(betas, t, x_t.shape)
return model_mean + torch.sqrt(posterior_variance_t) * z
@torch.no_grad()
def p_sample_loop(model, shape):
"""
完整的反向采样循环
shape: 想要生成的图像形状,例如 (Batch_Size, 3, 64, 64)
"""
device = next(model.parameters()).device
b = shape[0]
# 1. 从纯高斯噪声 x_T 开始
img = torch.randn(shape, device=device)
imgs = [] # 可以用来收集中间过程的图像,做成 GIF
# 2. 从 T 倒序循环到 1 (在代码索引里是 T-1 到 0)
import tqdm
for i in tqdm.tqdm(reversed(range(0, T)), desc='Sampling loop time step', total=T):
# 构造一个形状为 (B,) 的张量,里面全是当前的时间步 i
t = torch.full((b,), i, device=device, dtype=torch.long)
# 单步去噪更新 img
img = p_sample(model, img, t, i)
return img # 返回最终生成的 x_0
-
为什么要加回噪声 z?(Langevin Dynamics)
如果把最后那个
+ torch.sqrt(...) * z去掉,生成过程就变成了确定性的(Deterministic)。在 DDPM 的理论下,这会导致生成的图像极其平滑且模糊,失去高频细节,甚至所有初始噪声都会坍缩到同一个平均图像上。加回噪声相当于在解空间中注入扰动,让模型能探索真实数据分布的多样性,这在物理上对应着朗之万动力学(Langevin Dynamics)。当推导到 x_0 时,任何额外的噪声都会直接成为最终图像的噪点,破坏生成质量,因此最后一步只取期望(Mean)。
-
采样的昂贵代价:
看
p_sample_loop里的for i in reversed(range(0, T))。如果 T=1000,意味着生成一张图片,你需要让这庞大的 U-Net 执行 1000 次前向传播(Forward Pass)。这也就是为什么原始 DDPM 采样极其缓慢的工程根本原因。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)