环境声明

  • Python 版本Python 3.10+ (建议使用 3.10 以上版本)
  • 深度学习框架PyTorch 2.0+
  • GPU 支持:CUDA 11.8+ (推荐,用于加速训练)
  • 操作系统Windows / macOS / Linux (通用)
  • 依赖库torch, torchvision, numpy, matplotlib, tqdm, einops

学习目标和摘要

学习目标

本章将带领读者深入理解扩散模型的核心原理与实现。通过本章学习,你将能够:

  1. 理解扩散模型的前向过程与反向过程的数学原理
  2. 掌握DDPM(去噪扩散概率模型)的完整推导与实现
  3. 理解DDIM加速采样的核心思想
  4. 掌握条件扩散模型与引导采样技术(Classifier Guidance、CFG)
  5. 深入理解Stable Diffusion的潜在扩散模型架构
  6. 了解扩散模型在图像、音频、视频生成中的最新应用
  7. 理解扩散模型与流模型(Flow Matching)的联系
  8. 能够独立实现一个完整的扩散模型

文章摘要

扩散模型(Diffusion Models)是近年来生成式AI领域最重要的突破之一。从2020年DDPM的提出,到2022年Stable Diffusion的开源,再到2024年Diffusion Transformer(DiT)和Flow Matching的兴起,扩散模型经历了飞速发展。本章将从数学原理出发,系统讲解扩散模型的前向加噪过程、反向去噪过程,深入剖析DDPM、DDIM等经典算法,并介绍Stable Diffusion的潜在扩散模型架构。同时,我们将探讨2024-2025年的最新进展,包括SDXL、SD3、DiT以及Flow Matching等前沿技术。


1. 扩散模型概述

1.1 什么是扩散模型

扩散模型是一类生成模型,其核心思想来源于物理学中的扩散现象。想象一滴墨水滴入清水中:墨水分子会逐渐扩散,最终均匀分布在整个水体中。扩散模型正是借鉴了这一思想:

  • 前向过程(Forward Process):向原始数据逐步添加噪声,直到数据完全变成随机噪声
  • 反向过程(Reverse Process):学习如何从噪声中逐步恢复出原始数据

扩散模型的本质:通过模拟数据的"逐渐损坏"过程,并学习如何逆转这一过程来实现数据生成。

1.2 扩散模型的发展历程

时间 模型/技术 贡献
2015 扩散概率模型 Sohl-Dickstein等人首次提出扩散模型概念
2020 DDPM Ho等人改进训练目标,使扩散模型实用化
2021 DDIM Song等人提出确定性采样,加速推理
2021 LDM Rombach等人提出潜在扩散模型,降低计算成本
2022 Stable Diffusion Stability AI开源LDM,引发AIGC热潮
2022 DALL-E 2 OpenAI发布基于扩散的文本到图像模型
2023 SDXL 更高分辨率、更强语义的图像生成
2024 SD3 / DiT 采用Transformer架构,Flow Matching训练目标
2024 Flux Black Forest Labs发布,进一步提升生成质量
2025 视频扩散模型 Sora、可灵等视频生成模型兴起

1.3 扩散模型与其他生成模型的对比

特性 GAN VAE 扩散模型 流模型
训练稳定性 较差(模式崩溃) 较好
生成质量 中等 很高
训练速度 中等
采样速度 快(单次前向) 慢(多步迭代) 中等
似然估计 困难 变分下界 精确 精确
条件生成 需要特殊设计 较容易 容易 较容易
可编辑性 有限 较好 很好 较好

2. 前向扩散过程

2.1 前向过程的定义

前向过程是一个马尔可夫链,逐步向原始数据添加高斯噪声。给定原始数据 x0x_0x0,前向过程定义如下:

q(xt∣xt−1)=N(xt;1−βtxt−1,βtI)q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I)q(xtxt1)=N(xt;1βt xt1,βtI)

其中:

  • x0x_0x0 是原始数据(如图像)
  • xtx_txt 是第 ttt 步加噪后的数据
  • βt∈(0,1)\beta_t \in (0, 1)βt(0,1) 是噪声调度参数,控制每步添加的噪声量
  • TTT 是总步数,通常取1000

2.2 重参数化技巧

由于前向过程是马尔可夫链,我们可以利用重参数化技巧直接计算任意时刻 tttxtx_txt,而不需要逐步计算:

αt=1−βt\alpha_t = 1 - \beta_tαt=1βtαˉt=∏s=1tαs\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_sαˉt=s=1tαs,则有:

q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t) I)q(xtx0)=N(xt;αˉt x0,(1αˉt)I)

这意味着我们可以直接从 x0x_0x0 采样得到任意时刻 ttt 的加噪数据:

xt=αˉtx0+1−αˉtϵ,ϵ∼N(0,I)x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)xt=αˉt x0+1αˉt ϵ,ϵN(0,I)

2.3 噪声调度策略

噪声调度 βt\beta_tβt 的选择对模型性能有重要影响。常见的调度策略包括:

线性调度
βt=βmin+tT(βmax−βmin)\beta_t = \beta_{min} + \frac{t}{T}(\beta_{max} - \beta_{min})βt=βmin+Tt(βmaxβmin)

余弦调度( improved ):
αˉt=f(t)f(0),f(t)=cos⁡(t/T+s1+s⋅π2)2\bar{\alpha}_t = \frac{f(t)}{f(0)}, \quad f(t) = \cos\left(\frac{t/T + s}{1+s} \cdot \frac{\pi}{2}\right)^2αˉt=f(0)f(t),f(t)=cos(1+st/T+s2π)2

余弦调度在训练后期保持更多的信号,通常能获得更好的生成质量。

2.4 前向过程的直观理解

可以把前向过程想象成:

  1. 初始状态(t=0t=0t=0):清晰的原始图像
  2. 中间状态(t≈T/2t \approx T/2tT/2):模糊的、带有部分结构的噪声图像
  3. 最终状态(t=Tt=Tt=T):纯高斯噪声,没有任何可辨识的结构

ttt 接近 TTT 时,αˉt≈0\bar{\alpha}_t \approx 0αˉt0xTx_TxT 几乎完全是随机噪声,与原始数据 x0x_0x0 无关。


3. 反向去噪过程

3.1 反向过程的定义

反向过程的目标是从噪声 xTx_TxT 逐步恢复出原始数据 x0x_0x0。理论上,反向过程也是高斯分布:

q(xt−1∣xt,x0)=N(xt−1;μ~t(xt,x0),β~tI)q(x_{t-1} | x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I)q(xt1xt,x0)=N(xt1;μ~t(xt,x0),β~tI)

其中:
μ~t(xt,x0)=αˉt−1βt1−αˉtx0+αt(1−αˉt−1)1−αˉtxt\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t} x_0 + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_tμ~t(xt,x0)=1αˉtαˉt1 βtx0+1αˉtαt (1αˉt1)xt

β~t=1−αˉt−11−αˉtβt\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \beta_tβ~t=1αˉt1αˉt1βt

3.2 神经网络的引入

问题在于,在生成过程中我们不知道 x0x_0x0。因此,我们需要训练一个神经网络来近似这个反向过程:

pθ(xt−1∣xt)=N(xt−1;μθ(xt,t),Σθ(xt,t))p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))

3.3 参数化策略

Ho等人(DDPM论文)提出了两种参数化方式:

预测均值:直接让网络预测 μθ(xt,t)\mu_\theta(x_t, t)μθ(xt,t)

预测噪声:让网络预测噪声 ϵθ(xt,t)\epsilon_\theta(x_t, t)ϵθ(xt,t),然后通过下式计算均值:

μθ(xt,t)=1αt(xt−βt1−αˉtϵθ(xt,t))\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t)\right)μθ(xt,t)=αt 1(xt1αˉt βtϵθ(xt,t))

预测噪声的方式在实践中表现更好,是DDPM的标准做法。


4. DDPM:去噪扩散概率模型

4.1 变分下界推导

扩散模型的训练目标是最大化数据的对数似然的变分下界(ELBO):

log⁡p(x0)≥Eq(x1:T∣x0)[log⁡pθ(x0:T)q(x1:T∣x0)]=−L\log p(x_0) \geq \mathbb{E}_{q(x_{1:T}|x_0)}\left[\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}|x_0)}\right] = -Llogp(x0)Eq(x1:Tx0)[logq(x1:Tx0)pθ(x0:T)]=L

其中 LLL 可以分解为:

L=LT+∑t=2TLt−1+L0L = L_T + \sum_{t=2}^{T} L_{t-1} + L_0L=LT+t=2TLt1+L0

各项的含义:

  • LTL_TLT:最终时刻的KL散度,与参数无关
  • Lt−1L_{t-1}Lt1:每个时间步的KL散度
  • L0L_0L0:重构项

4.2 简化的训练目标

DDPM论文发现,可以进一步简化训练目标。经过推导,Lt−1L_{t-1}Lt1 可以写成:

Lt−1=Ex0,ϵ[βt22σt2αt(1−αˉt)∥ϵ−ϵθ(xt,t)∥2]L_{t-1} = \mathbb{E}_{x_0, \epsilon}\left[\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar{\alpha}_t)} \|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]Lt1=Ex0,ϵ[2σt2αt(1αˉt)βt2ϵϵθ(xt,t)2]

忽略权重系数,得到简化的目标函数:

Lsimple=Ex0,t,ϵ[∥ϵ−ϵθ(αˉtx0+1−αˉtϵ,t)∥2]L_{simple} = \mathbb{E}_{x_0, t, \epsilon}\left[\|\epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t}\epsilon, t)\|^2\right]Lsimple=Ex0,t,ϵ[ϵϵθ(αˉt x0+1αˉt ϵ,t)2]

这就是DDPM的训练目标:让网络预测的噪声与真实添加的噪声尽可能接近。

4.3 DDPM训练算法

# DDPM训练伪代码
for each batch of data x_0:
    # 随机采样时间步
    t ~ Uniform({1, 2, ..., T})
    
    # 采样噪声
    epsilon ~ N(0, I)
    
    # 前向加噪
    x_t = sqrt(alpha_bar[t]) * x_0 + sqrt(1 - alpha_bar[t]) * epsilon
    
    # 网络预测噪声
    epsilon_theta = model(x_t, t)
    
    # 计算损失并反向传播
    loss = MSE(epsilon, epsilon_theta)
    loss.backward()
    optimizer.step()

4.4 DDPM采样算法

# DDPM采样伪代码
x_T ~ N(0, I)  # 从纯噪声开始

for t = T, T-1, ..., 1:
    z ~ N(0, I) if t > 1 else 0  # 采样噪声(最后一步不加噪声)
    
    # 预测噪声
    epsilon_theta = model(x_t, t)
    
    # 计算均值
    mu_t = (x_t - beta[t] / sqrt(1 - alpha_bar[t]) * epsilon_theta) / sqrt(alpha[t])
    
    # 采样x_{t-1}
    x_{t-1} = mu_t + sqrt(beta[t]) * z

return x_0

5. DDIM:更快的采样

5.1 DDPM采样的问题

DDPM的一个主要缺点是采样速度慢。生成一张图像需要执行模型 TTT 次(通常 T=1000T=1000T=1000),这限制了其实际应用。

5.2 DDIM的核心思想

DDIM(Denoising Diffusion Implicit Models)的核心思想是:将随机采样过程转化为确定性过程。

DDIM定义了一个非马尔可夫的反向过程:

xt−1=αˉt−1x0(t)+1−αˉt−1−σt2⋅xt−αˉtx0(t)1−αˉt+σtϵtx_{t-1} = \sqrt{\bar{\alpha}_{t-1}} x_0^{(t)} + \sqrt{1-\bar{\alpha}_{t-1} - \sigma_t^2} \cdot \frac{x_t - \sqrt{\bar{\alpha}_t} x_0^{(t)}}{\sqrt{1-\bar{\alpha}_t}} + \sigma_t \epsilon_txt1=αˉt1 x0(t)+1αˉt1σt2 1αˉt xtαˉt x0(t)+σtϵt

其中 x0(t)x_0^{(t)}x0(t) 是从 xtx_txt 预测的原始数据:

x0(t)=xt−1−αˉtϵθ(xt,t)αˉtx_0^{(t)} = \frac{x_t - \sqrt{1-\bar{\alpha}_t} \epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}x0(t)=αˉt xt1αˉt ϵθ(xt,t)

5.3 确定性采样

σt=0\sigma_t = 0σt=0 时,采样过程完全确定:

xt−1=αˉt−1x0(t)+1−αˉt−1⋅xt−αˉtx0(t)1−αˉtx_{t-1} = \sqrt{\bar{\alpha}_{t-1}} x_0^{(t)} + \sqrt{1-\bar{\alpha}_{t-1}} \cdot \frac{x_t - \sqrt{\bar{\alpha}_t} x_0^{(t)}}{\sqrt{1-\bar{\alpha}_t}}xt1=αˉt1 x0(t)+1αˉt1 1αˉt xtαˉt x0(t)

这意味着给定 xTx_TxT,生成的 x0x_0x0 是确定的。

5.4 跳步采样

由于DDIM是确定性的,我们可以使用更少的步数进行采样(如50步、20步甚至10步),而不会显著降低生成质量。这是通过选择子序列的时间步实现的:

# DDIM采样(50步)
timesteps = [1000, 980, 960, ..., 20]  # 选择子序列

for t in timesteps:
    # 确定性更新
    x_{t_prev} = sqrt(alpha_bar[t_prev]) * x_0_pred + \
                 sqrt(1 - alpha_bar[t_prev]) * (x_t - sqrt(alpha_bar[t]) * x_0_pred) / sqrt(1 - alpha_bar[t])

6. 条件扩散与引导采样

6.1 条件扩散模型

在实际应用中,我们通常希望根据某些条件生成数据,如文本描述、类别标签等。条件扩散模型将条件 ccc 引入网络:

ϵθ(xt,t,c)\epsilon_\theta(x_t, t, c)ϵθ(xt,t,c)

条件可以通过多种方式注入:

  • 拼接:将条件向量与输入拼接
  • 注意力:使用交叉注意力机制
  • AdaGN:自适应组归一化

6.2 Classifier Guidance

Classifier Guidance通过分类器的梯度来引导生成过程朝向特定条件:

ϵ~θ(xt,t,c)=ϵθ(xt,t,c)−w1−αˉt∇xtlog⁡pϕ(c∣xt)\tilde{\epsilon}_\theta(x_t, t, c) = \epsilon_\theta(x_t, t, c) - w \sqrt{1-\bar{\alpha}_t} \nabla_{x_t} \log p_\phi(c | x_t)ϵ~θ(xt,t,c)=ϵθ(xt,t,c)w1αˉt xtlogpϕ(cxt)

其中:

  • pϕ(c∣xt)p_\phi(c | x_t)pϕ(cxt) 是训练好的分类器
  • www 是引导强度,控制条件遵循程度
  • w>1w > 1w>1 时,生成样本更符合条件但多样性降低

6.3 Classifier-Free Guidance (CFG)

Classifier Guidance需要额外训练分类器,且分类器需要在噪声数据上训练。CFG提出了一种更优雅的方案:

在训练时,以一定概率(如10%)丢弃条件,让模型学习无条件生成:

ϵθ(xt,t,∅)\epsilon_\theta(x_t, t, \emptyset)ϵθ(xt,t,)

在采样时,使用以下公式进行引导:

ϵ~θ(xt,t,c)=ϵθ(xt,t,∅)+w⋅(ϵθ(xt,t,c)−ϵθ(xt,t,∅))\tilde{\epsilon}_\theta(x_t, t, c) = \epsilon_\theta(x_t, t, \emptyset) + w \cdot (\epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \emptyset))ϵ~θ(xt,t,c)=ϵθ(xt,t,)+w(ϵθ(xt,t,c)ϵθ(xt,t,))

其中 www 是引导尺度:

  • w=1w = 1w=1:无引导,标准条件生成
  • w>1w > 1w>1:增强条件遵循(通常取7.5)
  • w<1w < 1w<1:减弱条件影响

CFG的优势:

  • 不需要额外训练分类器
  • 引导强度可以灵活调整
  • 已成为文本到图像生成的标准技术

7. Stable Diffusion:潜在扩散模型

7.1 问题背景

原始扩散模型直接在像素空间操作,存在以下问题:

  • 计算成本高:高分辨率图像需要巨大的显存
  • 训练效率低:大量计算浪费在不可见的像素级细节上

7.2 潜在扩散模型(LDM)

Stable Diffusion基于潜在扩散模型(Latent Diffusion Model, LDM),其核心思想是:在压缩的潜在空间中进行扩散过程。

架构组成:

  1. 自编码器(VAE):将图像编码到潜在空间

    • 编码器:E(x)\mathcal{E}(x)E(x),将图像压缩为潜在表示 zzz
    • 解码器:D(z)\mathcal{D}(z)D(z),从潜在表示重建图像
    • 通常压缩率为 8×8\times8×(如 512×512512\times512512×512 图像压缩为 64×64×464\times64\times464×64×4
  2. U-Net:在潜在空间进行去噪

    • 输入:带噪的潜在表示 ztz_tzt
    • 输出:预测的噪声
    • 包含交叉注意力层,用于注入文本条件
  3. 文本编码器:将文本提示编码为嵌入向量

    • 通常使用CLIP的文本编码器
    • 输出文本嵌入用于U-Net的交叉注意力

7.3 Stable Diffusion的推理流程

# Stable Diffusion推理流程

# 1. 文本编码
text_embeddings = text_encoder(prompt)

# 2. 初始化潜在噪声
latents = randn(batch_size, 4, height//8, width//8)

# 3. 在潜在空间进行去噪
for t in timesteps:
    # U-Net预测噪声
    noise_pred = unet(latents, t, text_embeddings)
    
    # 使用CFG
    noise_pred_uncond = unet(latents, t, null_embeddings)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
    
    # DDIM/DDPM更新
    latents = scheduler.step(noise_pred, t, latents)

# 4. VAE解码
images = vae_decoder(latents)

7.4 SDXL与SD3的改进

SDXL(Stable Diffusion XL)

  • 更大的U-Net(2.6B参数 vs 0.9B)
  • 双文本编码器(OpenCLIP ViT-bigG + CLIP ViT-L)
  • 引入条件注入到所有UNet块
  • 原生支持更高分辨率(1024x1024)

SD3(Stable Diffusion 3)

  • 采用Diffusion Transformer(DiT)架构替代U-Net
  • 使用MMDiT(Multimodal Diffusion Transformer)处理多模态输入
  • 引入Flow Matching训练目标
  • 更好的文本渲染能力和多主题生成

8. 扩散模型的应用

8.1 图像生成

应用 代表模型 特点
文本到图像 DALL-E 2/3, Midjourney, SD 根据文本描述生成图像
图像编辑 InstructPix2Pix, SDEdit 根据指令编辑图像
超分辨率 SR3, LDM-SR 图像超分辨率重建
图像修复 RePaint, DDRM 图像inpainting/outpainting
风格迁移 DiffusionCLIP 基于扩散的风格迁移

8.2 音频生成

扩散模型在音频领域也有广泛应用:

  • 音乐生成:MusicLM, AudioLDM
  • 语音合成:NaturalSpeech, VoiceLDM
  • 音效生成:AudioGen

音频扩散模型通常将音频转换为频谱图(如Mel频谱),然后在频谱图上进行扩散,最后通过声码器(如HiFi-GAN)转换为波形。

8.3 视频生成

视频扩散模型是2024-2025年的热点:

  • Sora(OpenAI):基于Transformer架构,生成长达60秒的高质量视频
  • 可灵(快手):国产视频生成模型
  • Lumiere(Google):时空扩散模型

视频扩散的挑战:

  • 计算成本极高(3D时空卷积)
  • 时序一致性
  • 长视频生成

8.4 其他应用

  • 分子生成:用于药物发现的3D分子生成
  • 点云生成:3D形状生成
  • 时间序列预测:基于扩散的概率预测

9. 扩散模型与流模型的联系

9.1 流匹配(Flow Matching)

2022-2023年,流匹配(Flow Matching)和Rectified Flow的提出,建立了扩散模型与流模型之间的联系。

核心思想:直接学习从噪声分布到数据分布的流(flow),而不是通过多步去噪。

流匹配的目标函数:

LFM=Et,x0,x1[∥vθ(xt,t)−(x1−x0)∥2]\mathcal{L}_{FM} = \mathbb{E}_{t, x_0, x_1}\left[\|v_\theta(x_t, t) - (x_1 - x_0)\|^2\right]LFM=Et,x0,x1[vθ(xt,t)(x1x0)2]

其中 xt=(1−t)x0+tx1x_t = (1-t)x_0 + t x_1xt=(1t)x0+tx1 是线性插值。

9.2 与扩散模型的关系

流匹配可以看作是扩散模型的推广:

  • 扩散模型:使用特定的噪声调度
  • 流匹配:更灵活的传输路径

优势:

  • 训练更稳定
  • 采样步数可以更少
  • 与ODE/PDE理论联系更紧密

9.3 一致性模型(Consistency Models)

一致性模型是另一种加速采样的方法:

  • 学习将任意时刻的 xtx_txt 直接映射到 x0x_0x0
  • 可以实现单步生成
  • 代表工作:Consistency Models, LCM(Latent Consistency Models)

10. 扩散模型完整实现

10.1 完整PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

class SinusoidalPositionEmbeddings(nn.Module):
    """正弦位置编码,用于时间步嵌入"""
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class ResidualBlock(nn.Module):
    """带时间嵌入的残差块"""
    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int, dropout: float = 0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_channels)
        )
        
        self.norm1 = nn.GroupNorm(8, in_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        
        self.dropout = nn.Dropout(dropout)
        
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(F.silu(self.norm1(x)))
        
        # 添加时间嵌入
        t = self.time_mlp(t_emb)[:, :, None, None]
        h = h + t
        
        h = self.conv2(self.dropout(F.silu(self.norm2(h))))
        return h + self.shortcut(x)


class AttentionBlock(nn.Module):
    """自注意力块"""
    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        h = self.norm(x)
        qkv = self.qkv(h)
        q, k, v = qkv.chunk(3, dim=1)
        
        # 重塑为多头注意力格式
        q = q.view(B, self.num_heads, C // self.num_heads, H * W).transpose(-2, -1)
        k = k.view(B, self.num_heads, C // self.num_heads, H * W).transpose(-2, -1)
        v = v.view(B, self.num_heads, C // self.num_heads, H * W).transpose(-2, -1)
        
        # 计算注意力
        scale = (C // self.num_heads) ** -0.5
        attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
        h = attn @ v
        
        # 重塑回原始形状
        h = h.transpose(-2, -1).view(B, C, H, W)
        h = self.proj(h)
        
        return x + h


class UNet(nn.Module):
    """用于扩散模型的U-Net架构"""
    def __init__(
        self,
        in_channels: int = 3,
        model_channels: int = 128,
        out_channels: int = 3,
        num_res_blocks: int = 2,
        attention_resolutions: Tuple[int, ...] = (2, 4),
        dropout: float = 0.1,
        channel_mult: Tuple[int, ...] = (1, 2, 4, 8),
        num_heads: int = 4,
    ):
        super().__init__()
        
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.num_res_blocks = num_res_blocks
        
        time_emb_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(model_channels),
            nn.Linear(model_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        
        # 下采样路径
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(in_channels, model_channels, 3, padding=1)
        ])
        
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [ResidualBlock(ch, model_channels * mult, time_emb_dim, dropout)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(AttentionBlock(ch, num_heads))
                self.input_blocks.append(nn.Sequential(*layers))
            
            if level != len(channel_mult) - 1:
                self.input_blocks.append(nn.Conv2d(ch, ch, 3, stride=2, padding=1))
                ds *= 2
        
        # 中间层
        self.middle_block = nn.Sequential(
            ResidualBlock(ch, ch, time_emb_dim, dropout),
            AttentionBlock(ch, num_heads),
            ResidualBlock(ch, ch, time_emb_dim, dropout),
        )
        
        # 上采样路径
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [ResidualBlock(ch * 2, model_channels * mult, time_emb_dim, dropout)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(AttentionBlock(ch, num_heads))
                if level != 0 and i == num_res_blocks:
                    layers.append(nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1))
                    ds //= 2
                self.output_blocks.append(nn.Sequential(*layers))
        
        self.out = nn.Sequential(
            nn.GroupNorm(8, ch),
            nn.SiLU(),
            nn.Conv2d(ch, out_channels, 3, padding=1),
        )

    def forward(self, x: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
        # 时间嵌入
        t_emb = self.time_embed(timesteps)
        
        # 下采样
        hs = []
        h = x
        for module in self.input_blocks:
            if isinstance(module, ResidualBlock):
                h = module(h, t_emb)
            elif isinstance(module, nn.Sequential):
                for layer in module:
                    if isinstance(layer, ResidualBlock):
                        h = layer(h, t_emb)
                    else:
                        h = layer(h)
            else:
                h = module(h)
            hs.append(h)
        
        # 中间层
        for layer in self.middle_block:
            if isinstance(layer, ResidualBlock):
                h = layer(h, t_emb)
            else:
                h = layer(h)
        
        # 上采样
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            for layer in module:
                if isinstance(layer, ResidualBlock):
                    h = layer(h, t_emb)
                else:
                    h = layer(h)
        
        return self.out(h)


class DiffusionModel(nn.Module):
    """扩散模型主类"""
    def __init__(
        self,
        model: nn.Module,
        timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        schedule: str = "linear",
    ):
        super().__init__()
        self.model = model
        self.timesteps = timesteps
        
        # 定义beta调度
        if schedule == "linear":
            betas = torch.linspace(beta_start, beta_end, timesteps)
        elif schedule == "cosine":
            s = 0.008
            x = torch.linspace(0, timesteps, timesteps + 1)
            alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            betas = torch.clip(betas, 0.0001, 0.9999)
        else:
            raise ValueError(f"Unknown schedule: {schedule}")
        
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        
        # 注册为缓冲区
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        self.register_buffer("sqrt_recip_alphas", torch.sqrt(1.0 / alphas))
        
        # 用于后验方差
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer("posterior_variance", posterior_variance)
        self.register_buffer("posterior_log_variance_clipped", 
                           torch.log(torch.clamp(posterior_variance, min=1e-20)))

    def q_sample(self, x_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
        """前向扩散过程:从x_0采样x_t"""
        if noise is None:
            noise = torch.randn_like(x_0)
        
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t][:, None, None, None]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t][:, None, None, None]
        
        return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise

    def predict_start_from_noise(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """从预测的噪声恢复x_0"""
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t][:, None, None, None]
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t][:, None, None, None]
        return (x_t - sqrt_one_minus_alphas_cumprod_t * noise) / sqrt_alphas_cumprod_t

    def q_posterior_mean_variance(self, x_0: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """计算后验分布的均值和方差 q(x_{t-1} | x_t, x_0)"""
        posterior_mean = (
            self.betas[t][:, None, None, None] * self.sqrt_alphas_cumprod_prev[t][:, None, None, None] * x_0 +
            (1.0 - self.alphas_cumprod_prev[t][:, None, None, None]) * self.sqrt_alphas[t][:, None, None, None] * x_t
        ) / (1.0 - self.alphas_cumprod[t][:, None, None, None])
        
        posterior_variance = self.posterior_variance[t][:, None, None, None]
        return posterior_mean, posterior_variance

    def p_mean_variance(self, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """模型预测p(x_{t-1} | x_t)的均值和方差"""
        # 预测噪声
        pred_noise = self.model(x_t, t)
        
        # 预测x_0
        x_0_pred = self.predict_start_from_noise(x_t, t, pred_noise)
        x_0_pred = torch.clamp(x_0_pred, -1.0, 1.0)
        
        # 计算后验均值和方差
        model_mean, model_variance = self.q_posterior_mean_variance(x_0_pred, x_t, t)
        return model_mean, model_variance

    def p_sample(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """从p(x_{t-1} | x_t)采样"""
        model_mean, model_variance = self.p_mean_variance(x_t, t)
        
        noise = torch.randn_like(x_t)
        # 当t=0时不添加噪声
        nonzero_mask = (t != 0).float()[:, None, None, None]
        
        return model_mean + nonzero_mask * torch.sqrt(model_variance) * noise

    @torch.no_grad()
    def sample(self, shape: Tuple[int, ...], device: str = "cuda") -> torch.Tensor:
        """生成样本"""
        b = shape[0]
        img = torch.randn(shape, device=device)
        
        for i in reversed(range(self.timesteps)):
            t = torch.full((b,), i, device=device, dtype=torch.long)
            img = self.p_sample(img, t)
        
        return img

    def training_loss(self, x_0: torch.Tensor) -> torch.Tensor:
        """计算训练损失"""
        b = x_0.shape[0]
        device = x_0.device
        
        # 随机采样时间步
        t = torch.randint(0, self.timesteps, (b,), device=device, dtype=torch.long)
        
        # 采样噪声
        noise = torch.randn_like(x_0)
        
        # 前向加噪
        x_t = self.q_sample(x_0, t, noise)
        
        # 模型预测
        pred_noise = self.model(x_t, t)
        
        # 简单损失(预测噪声)
        loss = F.mse_loss(pred_noise, noise)
        
        return loss


# 训练代码示例
def train_diffusion_model():
    """训练扩散模型的示例代码"""
    import torchvision
    from torchvision import transforms
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    
    # 超参数
    batch_size = 64
    epochs = 100
    lr = 2e-4
    timesteps = 1000
    image_size = 32
    channels = 3
    
    # 数据预处理
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * channels, [0.5] * channels)
    ])
    
    # 加载数据集(以CIFAR-10为例)
    dataset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    
    # 创建模型
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    unet = UNet(
        in_channels=channels,
        model_channels=128,
        out_channels=channels,
        num_res_blocks=2,
        attention_resolutions=(2, 4),
        channel_mult=(1, 2, 4),
    ).to(device)
    
    diffusion = DiffusionModel(unet, timesteps=timesteps, schedule="cosine").to(device)
    
    # 优化器
    optimizer = torch.optim.AdamW(diffusion.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.01)
    
    # 学习率调度
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    # 训练循环
    for epoch in range(epochs):
        diffusion.train()
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        
        for batch_idx, (images, _) in enumerate(pbar):
            images = images.to(device)
            
            # 计算损失
            loss = diffusion.training_loss(images)
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(diffusion.parameters(), 1.0)
            optimizer.step()
            
            # 更新进度条
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
        
        scheduler.step()
        
        # 每10个epoch保存样本
        if (epoch + 1) % 10 == 0:
            diffusion.eval()
            with torch.no_grad():
                samples = diffusion.sample((16, channels, image_size, image_size), device=device)
                # 保存或显示样本...
                print(f"Epoch {epoch+1}: Generated samples saved")
    
    # 保存模型
    torch.save(diffusion.state_dict(), "diffusion_model.pt")
    print("Training completed!")


if __name__ == "__main__":
    train_diffusion_model()

10.2 DDIM采样实现

class DDIMSampler:
    """DDIM采样器,用于加速推理"""
    def __init__(self, diffusion_model: DiffusionModel, ddim_timesteps: int = 50):
        self.model = diffusion_model
        self.ddim_timesteps = ddim_timesteps
        
        # 选择子序列的时间步
        c = diffusion_model.timesteps // ddim_timesteps
        self.timestep_seq = list(range(0, diffusion_model.timesteps, c))
        if self.timestep_seq[-1] != diffusion_model.timesteps - 1:
            self.timestep_seq.append(diffusion_model.timesteps - 1)
    
    @torch.no_grad()
    def sample(self, shape: Tuple[int, ...], eta: float = 0.0, device: str = "cuda") -> torch.Tensor:
        """
        DDIM采样
        eta: 随机性控制参数,0为确定性采样,1为随机采样
        """
        b = shape[0]
        x = torch.randn(shape, device=device)
        
        for i in reversed(range(len(self.timestep_seq))):
            t = torch.full((b,), self.timestep_seq[i], device=device, dtype=torch.long)
            
            # 预测噪声
            pred_noise = self.model.model(x, t)
            
            # 预测x_0
            alpha_t = self.model.alphas_cumprod[t][:, None, None, None]
            alpha_t_prev = self.model.alphas_cumprod_prev[t][:, None, None, None] if i > 0 else torch.ones_like(alpha_t)
            
            pred_x0 = (x - torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(alpha_t)
            pred_x0 = torch.clamp(pred_x0, -1.0, 1.0)
            
            # 计算方向
            sigma_t = eta * torch.sqrt((1 - alpha_t_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_t_prev))
            noise = torch.randn_like(x) if i > 0 else 0
            
            x = (
                torch.sqrt(alpha_t_prev) * pred_x0 +
                torch.sqrt(1 - alpha_t_prev - sigma_t ** 2) * pred_noise +
                sigma_t * noise
            )
        
        return x

避坑小贴士

1. 数值稳定性问题

问题:训练过程中出现NaN或Inf。

解决方案

  • 使用梯度裁剪:torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
  • 检查beta调度,确保beta值不会过大
  • 使用混合精度训练时注意数值溢出

2. 采样质量差

问题:生成的图像模糊或质量不高。

解决方案

  • 增加训练迭代次数
  • 尝试余弦噪声调度替代线性调度
  • 检查时间步编码是否正确实现
  • 确保输入数据归一化到[-1, 1]范围

3. 内存不足

问题:GPU内存溢出。

解决方案

  • 使用梯度累积
  • 降低batch size
  • 使用检查点技术(checkpointing)节省激活内存
  • 考虑使用潜在扩散模型降低计算成本

4. 训练速度慢

问题:训练时间过长。

解决方案

  • 使用DDIM或DPM-Solver++加速采样
  • 使用分布式训练
  • 考虑使用预训练的VAE编码器
  • 减少时间步数(如从1000降到100)

5. 模式崩溃与多样性

问题:生成样本缺乏多样性。

解决方案

  • 调整CFG的引导强度,避免过大
  • 增加训练数据的多样性
  • 使用分类器自由引导时,确保无条件训练充分

本章小结

本章系统介绍了扩散模型的核心原理与实现:

核心概念

  • 扩散模型通过前向加噪和反向去噪实现数据生成
  • DDPM使用简化的噪声预测目标进行训练
  • DDIM通过确定性采样实现快速推理

关键技术

  • Classifier-Free Guidance实现条件生成控制
  • 潜在扩散模型(LDM)大幅降低计算成本
  • Flow Matching提供了新的训练视角

2024-2025年最新进展

  • Diffusion Transformer(DiT)替代U-Net成为新架构
  • Stable Diffusion 3采用MMDiT和Flow Matching
  • 视频扩散模型(Sora、可灵)成为新热点

一句话总结:扩散模型通过"先破坏、后修复"的思想,结合深度学习的强大拟合能力,实现了高质量的图像、音频、视频生成,是生成式AI领域最重要的技术突破之一。


知识点回顾

知识点 核心内容
前向过程 马尔可夫链逐步加噪,可重参数化直接计算任意时刻
反向过程 神经网络学习去噪,预测噪声比预测均值效果更好
DDPM训练 最小化预测噪声与真实噪声的MSE损失
DDIM加速 确定性采样,可用50步甚至10步生成高质量图像
CFG引导 无条件和有条件预测的插值,控制生成多样性
LDM 在VAE压缩的潜在空间进行扩散,降低计算成本
DiT 使用Transformer替代U-Net,可扩展性更好
Flow Matching 直接学习流场,与扩散模型统一

参考资料

  1. Ho et al. “Denoising Diffusion Probabilistic Models” (NeurIPS 2020)
  2. Song et al. “Denoising Diffusion Implicit Models” (ICLR 2021)
  3. Rombach et al. “High-Resolution Image Synthesis with Latent Diffusion Models” (CVPR 2022)
  4. Peebles & Xie “Scalable Diffusion Models with Transformers” (ICCV 2023)
  5. Lipman et al. “Flow Matching for Generative Modeling” (ICLR 2023)
  6. Stable Diffusion官方文档:https://stability.ai/

本文档遵循CSDN专栏发布规范,使用Python 3.10+和PyTorch 2.0+环境。如有疑问,欢迎在评论区留言讨论。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐