现在讲DDPM的教程太多了,这里就不重点说了,作为一个回顾,先过一遍。


1. 前向过程 (Forward Process / Diffusion Process)

前向过程是一个固定的马尔可夫链,它通过逐步向原始数据 x 0 x_0 x0 添加高斯噪声,最终将其破坏为一个纯高斯噪声 x T x_T xT

  • 单步状态转移:在每一步 t t t,我们添加由方差表 (Variance Schedule) β t \beta_t βt 控制的微小噪声:

    q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t I) q(xtxt1)=N(xt;1βt xt1,βtI)

    (注: β 1 < β 2 < ⋯ < β T \beta_1 < \beta_2 < \dots < \beta_T β1<β2<<βT,通常通过线性或余弦调度来设定。)

  • 任意步状态的直接采样 (重参数化技巧)

    为了避免在训练时一步步迭代,DDPM 利用独立同分布高斯变量相加的性质,实现了从 x 0 x_0 x0 直接采样出任意时刻的 x t x_t xt

    α t = 1 − β t \alpha_t = 1 - \beta_t αt=1βt α ˉ t = ∏ i = 1 t α i \bar{\alpha}_t = \prod_{i=1}^t \alpha_i αˉt=i=1tαi,则有:

    q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t}x_0, (1-\bar{\alpha}_t)I) q(xtx0)=N(xt;αˉt x0,(1αˉt)I)

    写成等式即为:

    x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t = \sqrt{\bar{\alpha}_t}x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon xt=αˉt x0+1αˉt ϵ

    其中 ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵN(0,I)。这个公式是后续模型训练的基石。

import torch

# 假设总步数 T 为 1000
T = 1000
beta_start = 1e-4
beta_end = 0.02

# 1. 定义 betas (\beta_t)
betas = torch.linspace(beta_start, beta_end, T)

# 2. 定义 alphas (\alpha_t = 1 - \beta_t)
alphas = 1.0 - betas

# 3. 定义 alphas_cumprod (\bar{\alpha}_t)
alphas_cumprod = torch.cumprod(alphas, dim=0)

# 辅助函数:由于输入是 Batch,我们需要把一维的参数广播到图片的维度 (B, C, H, W)
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu()) # 取出对应时间步的参数
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
def q_sample(x_0, t, noise=None):
    """
    x_0: 真实图片数据, shape (B, C, H, W)
    t: 随机采样的时间步, shape (B,)
    noise: 采样的标准高斯噪声 \epsilon
    """
    if noise is None:
        noise = torch.randn_like(x_0)

    # 提取 \sqrt{\bar{\alpha}_t}
    sqrt_alphas_cumprod_t = extract(torch.sqrt(alphas_cumprod), t, x_0.shape)
    
    # 提取 \sqrt{1 - \bar{\alpha}_t}
    sqrt_one_minus_alphas_cumprod_t = extract(
        torch.sqrt(1. - alphas_cumprod), t, x_0.shape
    )

    # 核心等式:一步到位生成 x_t
    x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
    
    return x_t

2. 模型与训练

抛开一切推导不谈的话,其实看最后的结论,非常简单

前向过程加了什么噪声 ϵ \epsilon ϵ,网络 ϵ θ \epsilon_\theta ϵθ 就去预测什么噪声

整个去噪的模型就是一个U-net,现在看来有些过时了,以后单独开一页来写,这里跳过算了。

import torch.nn.functional as F

def p_losses(denoise_model, x_0, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_0) # 真实的 \epsilon,这也是我们的 Ground Truth!

    # 1. 制造 x_t
    x_t = q_sample(x_0=x_0, t=t, noise=noise)

    # 2. 扔给 U-Net 去预测噪声
    # U-Net 必须同时接收噪声图 x_t 和时间步 t (以利用 Time Embedding)
    predicted_noise = denoise_model(x_t, t)

    # 3. 计算 MSE 损失
    loss = F.mse_loss(noise, predicted_noise)
    
    return loss

训练

from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

# 假设你的 U-Net 已经搭好,并且支持接受时间步 t 作为输入
model = UNet(...).to(device) 
optimizer = Adam(model.parameters(), lr=1e-4)

epochs = 100
for epoch in range(epochs):
    for batch in dataloader: 
        # dataloader 返回的 batch 就是干净的真实图片 x_0
        x_0 = batch.to(device)
        batch_size = x_0.shape[0]

        optimizer.zero_grad()

        # 核心:为 Batch 中的每一张图片独立随机采样一个时间步 t ~ Uniform(0, T-1)
        t = torch.randint(0, T, (batch_size,), device=device).long()

        # 计算 Loss
        loss = p_losses(model, x_0, t)

        # 反向传播和更新
        loss.backward()
        optimizer.step()

3. 逆向过程 (Reverse Process / Denoising Process)

逆向过程的目标是从标准高斯分布 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xTN(0,I) 开始,逐步去噪还原出真实数据 x 0 x_0 x0

  • 逆向转移概率:由于真实的逆向分布 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt) 无法直接计算,DDPM 使用一个神经网络 p θ p_\theta pθ 来近似它。当 β t \beta_t βt 足够小时,逆向过程也可以近似为高斯分布:

    p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))

  • 后验分布的解析解:虽然 q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q(xt1xt) 不可解,但如果我们知道 x 0 x_0 x0,其条件后验分布 q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1}|x_t, x_0) q(xt1xt,x0) 是可以通过贝叶斯定理求出解析解的:

    q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ ~ t ( x t , x 0 ) , β ~ t I ) q(x_{t-1}|x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I) q(xt1xt,x0)=N(xt1;μ~t(xt,x0),β~tI)

    其中,均值 μ ~ t ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \tilde{\mu}t(x_t, x_0) = \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t + \frac{\sqrt{\bar{\alpha}{t-1}}\beta_t}{1-\bar{\alpha}_t}x_0 μ~t(xt,x0)=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx0

反向采样的目标是从 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, I) xTN(0,I) 开始,逐步推导 x t − 1 x_{t-1} xt1,直到 x 0 x_0 x0。在这个过程中,我们的 U-Net ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t) 会预测出当前图像中的噪声,然后我们用数学公式把它“减”掉。

根据 Ho et al. 的论文,每一步的具体推导公式为:

x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))+σtz

其中:

  • z ∼ N ( 0 , I ) z \sim \mathcal{N}(0, I) zN(0,I)注意:除了最后一步 t=1 时 z=0 外,其他时候我们都要加回这个随机噪声。
  • σ t \sigma_t σt 是反向过程的方差,论文中通常取 σ t = β t \sigma_t = \sqrt{\beta_t} σt=βt 或者真实后验方差 β ~ t \tilde{\beta}_t β~t

我们将这个过程拆分为两个函数:一个是执行单步去噪的 p_sample,另一个是控制整个循环的 p_sample_loop。这里会用到我们之前写的 extract 函数来对齐维度。

import torch

# 假设我们在外部已经定义好了 T, betas, alphas, alphas_cumprod 等全局常量
# 为了公式计算方便,我们需要预计算几个额外的系数:
sqrt_recip_alphas = torch.sqrt(1.0 / alphas) # 公式里的 1 / \sqrt{\alpha_t}
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) # 公式里的 \sqrt{1 - \bar{\alpha}_t}

@torch.no_grad()
def p_sample(model, x_t, t, t_index):
    """
    单步去噪:从 x_t 推导 x_{t-1}
    t_index 是一个整数,用来判断是否到了最后一步 (t=0)
    """
    # 1. 提取当前时间步 t 对应的标量系数,并撑开维度以便广播计算
    betas_t = extract(betas, t, x_t.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_t.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x_t.shape)
    alphas_t = extract(alphas, t, x_t.shape)

    # 2. 调用 U-Net 预测当前图片 x_t 中蕴含的噪声 \epsilon
    pred_noise = model(x_t, t)
    
    # 3. 计算公式中的均值部分 (Model Mean)
    # \frac{1}{\sqrt{\alpha_t}} * (x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} * pred_noise)
    model_mean = sqrt_recip_alphas_t * (x_t - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t)

    # 4. 如果是最后一步 (t==0),就不加噪声了,直接返回均值;否则加上方差 \sigma_t * z
    if t_index == 0:
        return model_mean
    else:
        # 采样随机噪声 z ~ N(0, I)
        z = torch.randn_like(x_t)
        # 这里的方差通常用 \beta_t
        posterior_variance_t = extract(betas, t, x_t.shape) 
        return model_mean + torch.sqrt(posterior_variance_t) * z

@torch.no_grad()
def p_sample_loop(model, shape):
    """
    完整的反向采样循环
    shape: 想要生成的图像形状,例如 (Batch_Size, 3, 64, 64)
    """
    device = next(model.parameters()).device
    b = shape[0]

    # 1. 从纯高斯噪声 x_T 开始
    img = torch.randn(shape, device=device)
    imgs = [] # 可以用来收集中间过程的图像,做成 GIF

    # 2. 从 T 倒序循环到 1 (在代码索引里是 T-1 到 0)
    import tqdm
    for i in tqdm.tqdm(reversed(range(0, T)), desc='Sampling loop time step', total=T):
        # 构造一个形状为 (B,) 的张量,里面全是当前的时间步 i
        t = torch.full((b,), i, device=device, dtype=torch.long)
        
        # 单步去噪更新 img
        img = p_sample(model, img, t, i)

    return img # 返回最终生成的 x_0

  • 为什么要加回噪声 z?(Langevin Dynamics)

    如果把最后那个 + torch.sqrt(...) * z 去掉,生成过程就变成了确定性的(Deterministic)。在 DDPM 的理论下,这会导致生成的图像极其平滑且模糊,失去高频细节,甚至所有初始噪声都会坍缩到同一个平均图像上。加回噪声相当于在解空间中注入扰动,让模型能探索真实数据分布的多样性,这在物理上对应着朗之万动力学(Langevin Dynamics)。

    当推导到 x_0 时,任何额外的噪声都会直接成为最终图像的噪点,破坏生成质量,因此最后一步只取期望(Mean)。

  • 采样的昂贵代价

    p_sample_loop 里的 for i in reversed(range(0, T))。如果 T=1000,意味着生成一张图片,你需要让这庞大的 U-Net 执行 1000 次前向传播(Forward Pass)。这也就是为什么原始 DDPM 采样极其缓慢的工程根本原因。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐