论文信息

  • 标题:Deep Unsupervised Learning using Nonequilibrium Thermodynamics
  • 会议:ICML 2015
  • 单位:斯坦福大学、加州大学伯克利分校
  • 代码:https://github.com/Sohl-Dickstein/Diffusion-Probabilistic-Models
  • 论文:https://arxiv.org/pdf/1503.03585.pdf

一、前言:生成模型的“灵活性-可解性”死结

传统生成模型一直绕不开两大矛盾

  • 简单模型(高斯、拉普拉斯):好算、好训练,但表达能力极差,拟合不了复杂数据
  • 复杂灵活模型:能拟合任意分布,但归一化常数Z算不出来,采样、评估全卡死

这篇论文直接用热力学+马尔可夫链破局:
先慢慢给数据加噪声把结构毁掉(前向扩散),再学一步一步把噪声去掉(反向扩散)
从此,扩散模型(Diffusion Model) 正式诞生!


二、核心思想一句话讲透

  • 前向过程(固定):对数据一步步加高斯噪声,T步后完全变成标准高斯分布
  • 反向过程(学习):从高斯噪声出发,一步步去噪,还原成真实数据分布

通俗解释:
就像把一杯奶茶慢慢兑成白开水(前向),再学怎么把白开水兑回奶茶(反向)


三、整体框架

在这里插入图片描述

图1 扩散模型整体流程

  • 上排:前向扩散,数据分布→标准高斯
  • 中排:反向生成,标准高斯→数据分布
  • 下排:反向过程的漂移项

四、前向扩散过程(Forward Process)

4.1 数学定义

q(x(t)∣x(t−1))=N(x(t);x(t−1)1−βt,Iβt)q(x^{(t)}|x^{(t-1)}) = \mathcal{N}\left(x^{(t)}; x^{(t-1)}\sqrt{1-\beta_t}, I\beta_t\right)q(x(t)x(t1))=N(x(t);x(t1)1βt ,Iβt)

  • x(t)x^{(t)}x(t):第t步加噪后数据
  • βt\beta_tβt:第t步噪声率(越来越大)
  • 1−βt\sqrt{1-\beta_t}1βt :均值系数
  • IβtI\beta_tIβt:方差(单位阵乘噪声率)
  • N(⋅;μ,Σ)\mathcal{N}(\cdot;\mu,\Sigma)N(;μ,Σ):高斯分布

通俗解释:
每一步都在上一步基础上保留一部分信号,混入一部分高斯噪声

4.2 联合分布

q(x(0⋯T))=q(x(0))∏t=1Tq(x(t)∣x(t−1))q(x^{(0\cdots T)}) = q(x^{(0)})\prod_{t=1}^T q(x^{(t)}|x^{(t-1)})q(x(0T))=q(x(0))t=1Tq(x(t)x(t1))

  • x(0)x^{(0)}x(0):原始真实数据
  • 连乘:马尔可夫链一步步加噪

我们来推导扩散模型前向过程的核心闭式公式(一步采样),也就是从 x(0)x^{(0)}x(0) 直接得到任意 ttt 步的 x(t)x^{(t)}x(t),这是扩散模型训练的基础。


4.3 前向过程的一步闭式推导

我们从前向马尔可夫链的递推式开始:
q(x(t)∣x(t−1))=N(x(t);x(t−1)1−βt,Iβt) q(x^{(t)} | x^{(t-1)}) = \mathcal{N}\left(x^{(t)}; x^{(t-1)}\sqrt{1-\beta_t}, I\beta_t\right) q(x(t)x(t1))=N(x(t);x(t1)1βt ,Iβt)
其中,我们定义 αt=1−βt\alpha_t = 1-\beta_tαt=1βt,并记累积乘积 αˉt=∏i=1tαi\bar{\alpha}_t = \prod_{i=1}^t \alpha_iαˉt=i=1tαi

推导过程:
  • 第一步(t=1t=1t=1
    x(1)=α1x(0)+1−α1ϵ(1),ϵ(1)∼N(0,I) x^{(1)} = \sqrt{\alpha_1} x^{(0)} + \sqrt{1-\alpha_1}\epsilon^{(1)},\quad \epsilon^{(1)} \sim \mathcal{N}(0,I) x(1)=α1 x(0)+1α1 ϵ(1),ϵ(1)N(0,I)

  • 第二步(t=2t=2t=2
    x(1)x^{(1)}x(1) 代入递推式:
    x(2)=α2x(1)+1−α2ϵ(2)=α2(α1x(0)+1−α1ϵ(1))+1−α2ϵ(2)=α1α2x(0)+α2(1−α1)ϵ(1)+1−α2ϵ(2) \begin{align*} x^{(2)} &= \sqrt{\alpha_2}x^{(1)} + \sqrt{1-\alpha_2}\epsilon^{(2)} \\ &= \sqrt{\alpha_2}\left(\sqrt{\alpha_1}x^{(0)} + \sqrt{1-\alpha_1}\epsilon^{(1)}\right) + \sqrt{1-\alpha_2}\epsilon^{(2)} \\ &= \sqrt{\alpha_1\alpha_2}x^{(0)} + \sqrt{\alpha_2(1-\alpha_1)}\epsilon^{(1)} + \sqrt{1-\alpha_2}\epsilon^{(2)} \end{align*} x(2)=α2 x(1)+1α2 ϵ(2)=α2 (α1 x(0)+1α1 ϵ(1))+1α2 ϵ(2)=α1α2 x(0)+α2(1α1) ϵ(1)+1α2 ϵ(2)
    根据高斯分布的性质,两个独立高斯噪声项可以合并:
    N(0,α2(1−α1)I)+N(0,(1−α2)I)=N(0,(α2−α1α2+1−α2)I)=N(0,(1−α1α2)I) \mathcal{N}(0, \alpha_2(1-\alpha_1)I) + \mathcal{N}(0, (1-\alpha_2)I) = \mathcal{N}(0, (\alpha_2 - \alpha_1\alpha_2 + 1 - \alpha_2)I) = \mathcal{N}(0, (1-\alpha_1\alpha_2)I) N(0,α2(1α1)I)+N(0,(1α2)I)=N(0,(α2α1α2+1α2)I)=N(0,(1α1α2)I)
    因此,x(2)x^{(2)}x(2) 可以写成:
    x(2)=αˉ2x(0)+1−αˉ2ϵ(2),ϵ(2)∼N(0,I) x^{(2)} = \sqrt{\bar{\alpha}_2}x^{(0)} + \sqrt{1-\bar{\alpha}_2}\epsilon^{(2)},\quad \epsilon^{(2)} \sim \mathcal{N}(0,I) x(2)=αˉ2 x(0)+1αˉ2 ϵ(2),ϵ(2)N(0,I)

  • 归纳推广到第 ttt
    以此类推,我们可以得到任意 ttt 步的闭式表达:
    x(t)=αˉtx(0)+1−αˉtϵ x^{(t)} = \sqrt{\bar{\alpha}_t}x^{(0)} + \sqrt{1-\bar{\alpha}_t}\epsilon x(t)=αˉt x(0)+1αˉt ϵ

  • αˉtx(0)\sqrt{\bar{\alpha}_t}x^{(0)}αˉt x(0):保留的原始图片信号

  • 1−αˉtϵ\sqrt{1-\bar{\alpha}_t}\epsilon1αˉt ϵ:加入的噪声信号

  • x(t)x^{(t)}x(t):两个信号加权相加 → 得到加噪后的图片 x(t)x^{(t)}x(t)

    其中,ϵ∼N(0,I)\epsilon \sim \mathcal{N}(0, I)ϵN(0,I) 是一个与 ttt 相关的高斯噪声。ttt 越小 → αˉt\bar{\alpha}_tαˉt越大 → 保留原图多,噪声少;ttt 越大 →αˉt\bar{\alpha}_tαˉt越小 → 保留原图少,噪声多;t=Tt = Tt=T(最后一步) → 图片完全变成纯高斯噪声。


公式各符号解释

符号 含义
x(0)x^{(0)}x(0) 原始数据样本(无噪声)
x(t)x^{(t)}x(t) 经过 ttt 步扩散后的带噪声样本
αˉt\bar{\alpha}_tαˉt 累积乘积 ∏i=1t(1−βi)\prod_{i=1}^t (1-\beta_i)i=1t(1βi),表示信号的保留比例
ϵ\epsilonϵ 从标准高斯分布中采样的噪声

通俗解释
这个公式让我们不用一步步迭代,直接从原始数据 x(0)x^{(0)}x(0) 和随机噪声 ϵ\epsilonϵ,就能算出任意 ttt 步的带噪声数据 x(t)x^{(t)}x(t)。这极大地简化了训练过程,让我们可以在任意时间步 ttt 采样并训练模型。


关键推论:当 t→Tt \to TtT

当扩散步数 TTT 足够大时,αˉT→0\bar{\alpha}_T \to 0αˉT0,此时:
x(T)≈ϵ∼N(0,I) x^{(T)} \approx \epsilon \sim \mathcal{N}(0, I) x(T)ϵN(0,I)
这说明,经过足够多步的加噪,数据会完全变成一个标准高斯分布,为后续的反向生成过程提供了起点。


五、反向生成过程(Reverse Process)

5.1 数学定义

p(x(t−1)∣x(t))=N(x(t−1);fμ(x(t),t),fΣ(x(t),t))p(x^{(t-1)}|x^{(t)}) = \mathcal{N}\left(x^{(t-1)}; f_\mu(x^{(t)},t), f_\Sigma(x^{(t)},t)\right)p(x(t1)x(t))=N(x(t1);fμ(x(t),t),fΣ(x(t),t))

  • fμ(⋅)f_\mu(\cdot)fμ():网络预测均值
  • fΣ(⋅)f_\Sigma(\cdot)fΣ():网络预测方差
  • 网络只需要学这两个函数,就能从噪声生成数据

5.2 联合分布

p(x(0⋯T))=p(x(T))∏t=1Tp(x(t−1)∣x(t))p(x^{(0\cdots T)}) = p(x^{(T)})\prod_{t=1}^T p(x^{(t-1)}|x^{(t)})p(x(0T))=p(x(T))t=1Tp(x(t1)x(t))

  • p(x(T))=N(0,I)p(x^{(T)})=\mathcal{N}(0,I)p(x(T))=N(0,I):初始纯噪声

我们来推导反向过程的损失函数,也就是DDPM中使用的简化MSE损失,这是扩散模型训练的核心。


5.3 反向过程的损失函数推导

我们的目标是让模型学习反向过程 p(x(t−1)∣x(t))p(x^{(t-1)}|x^{(t)})p(x(t1)x(t)),去逼近前向过程的后验分布 q(x(t−1)∣x(t),x(0))q(x^{(t-1)}|x^{(t)},x^{(0)})q(x(t1)x(t),x(0))

步骤1:计算真实后验分布 q(x(t−1)∣x(t),x(0))q(x^{(t-1)}|x^{(t)},x^{(0)})q(x(t1)x(t),x(0))

利用贝叶斯公式,结合前向过程的马尔可夫链性质,我们可以推导出:
q(x(t−1)∣x(t),x(0))=N(x(t−1);μ~t(x(t),x(0)),β~tI) q(x^{(t-1)}|x^{(t)},x^{(0)}) = \mathcal{N}\left(x^{(t-1)}; \tilde{\mu}_t(x^{(t)},x^{(0)}), \tilde{\beta}_t I\right) q(x(t1)x(t),x(0))=N(x(t1);μ~t(x(t),x(0)),β~tI)
其中,均值和方差都是可以用已知参数计算出来的:

  • 均值:
    μ~t(x(t),x(0))=αˉt−1βt1−αˉtx(0)+αt(1−αˉt−1)1−αˉtx(t) \tilde{\mu}_t(x^{(t)},x^{(0)}) = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x^{(0)} + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x^{(t)} μ~t(x(t),x(0))=1αˉtαˉt1 βtx(0)+1αˉtαt (1αˉt1)x(t)
  • 方差:
    β~t=1−αˉt−11−αˉtβt \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t β~t=1αˉt1αˉt1βt
步骤2:用模型预测噪声来参数化均值

在DDPM中,模型不直接预测均值 μ\muμ,而是预测噪声 ϵθ(x(t),t)\epsilon_\theta(x^{(t)}, t)ϵθ(x(t),t)。我们可以利用前向过程的一步闭式公式 x(t)=αˉtx(0)+1−αˉtϵx^{(t)} = \sqrt{\bar{\alpha}_t}x^{(0)} + \sqrt{1-\bar{\alpha}_t}\epsilonx(t)=αˉt x(0)+1αˉt ϵ,把 x(0)x^{(0)}x(0) 表示为:
x(0)=x(t)−1−αˉtϵαˉt x^{(0)} = \frac{x^{(t)} - \sqrt{1-\bar{\alpha}_t}\epsilon}{\sqrt{\bar{\alpha}_t}} x(0)=αˉt x(t)1αˉt ϵ
将其代入 μ~t\tilde{\mu}_tμ~t 的表达式,可以发现:让模型预测噪声 ϵθ(x(t),t)\epsilon_\theta(x^{(t)}, t)ϵθ(x(t),t),等价于让模型预测反向过程的均值。

步骤3:简化损失函数为MSE

为了训练模型,我们需要最小化模型分布 pθ(x(t−1)∣x(t))p_\theta(x^{(t-1)}|x^{(t)})pθ(x(t1)x(t)) 和真实后验分布 q(x(t−1)∣x(t),x(0))q(x^{(t-1)}|x^{(t)},x^{(0)})q(x(t1)x(t),x(0)) 之间的KL散度。
Lt=DKL(q(x(t−1)∣x(t),x(0))∥pθ(x(t−1)∣x(t))) \mathcal{L}_t = D_{KL}\left(q(x^{(t-1)}|x^{(t)},x^{(0)}) \parallel p_\theta(x^{(t-1)}|x^{(t)})\right) Lt=DKL(q(x(t1)x(t),x(0))pθ(x(t1)x(t)))
当我们把 pθp_\thetapθ 的方差固定为常数(如DDPM中直接使用 β~t\tilde{\beta}_tβ~t),那么最小化KL散度,就等价于最小化预测噪声和真实噪声之间的均方误差(MSE):
L(θ)=Et,x(0),ϵ[∥ϵ−ϵθ(x(t),t)∥2] \mathcal{L}(\theta) = \mathbb{E}_{t,x^{(0)},\epsilon} \left[ \|\epsilon - \epsilon_\theta(x^{(t)}, t)\|^2 \right] L(θ)=Et,x(0),ϵ[ϵϵθ(x(t),t)2]
其中,x(t)=αˉtx(0)+1−αˉtϵx^{(t)} = \sqrt{\bar{\alpha}_t}x^{(0)} + \sqrt{1-\bar{\alpha}_t}\epsilonx(t)=αˉt x(0)+1αˉt ϵ


最终核心损失公式与解释

L(θ)=Et∼Uniform(1,T),x(0)∼q(x(0)),ϵ∼N(0,I)[∥ϵ−ϵθ(x(t),t)∥2] \mathcal{L}(\theta) = \mathbb{E}_{t \sim \text{Uniform}(1,T), x^{(0)} \sim q(x^{(0)}), \epsilon \sim \mathcal{N}(0,I)} \left[ \|\epsilon - \epsilon_\theta(x^{(t)}, t)\|^2 \right] L(θ)=EtUniform(1,T),x(0)q(x(0)),ϵN(0,I)[ϵϵθ(x(t),t)2]

符号 含义
L(θ)\mathcal{L}(\theta)L(θ) 模型的训练损失,越小越好
ϵ\epsilonϵ 前向过程中加入的真实噪声
ϵθ(x(t),t)\epsilon_\theta(x^{(t)}, t)ϵθ(x(t),t) 模型预测的噪声
x(t)x^{(t)}x(t) ttt步的带噪声数据
ttt 随机采样的时间步
x(0)x^{(0)}x(0) 原始无噪声数据

通俗解释:
这个损失函数的目标非常简单粗暴:让模型学会“猜噪声”。给它一张加了 ttt 步噪声的图片 x(t)x^{(t)}x(t),让它预测出图片里的噪声是什么。当模型能精准地猜出噪声,就能从噪声中一步一步还原出原始图片。



六、训练目标:对数似然下界

6.1 核心公式

K=−∑t=2TEq(x(0),x(t))[DKL(q(x(t−1)∣x(t),x(0))∥p(x(t−1)∣x(t)))]+CK = -\sum_{t=2}^T \mathbb{E}_{q(x^{(0)},x^{(t)})}\left[ D_{KL}\left( q(x^{(t-1)}|x^{(t)},x^{(0)}) \parallel p(x^{(t-1)}|x^{(t)}) \right) \right] + CK=t=2TEq(x(0),x(t))[DKL(q(x(t1)x(t),x(0))p(x(t1)x(t)))]+C

  • KKK:对数似然下界(越大越好)
  • DKLD_{KL}DKL:KL散度(衡量分布差异)
  • q(x(t−1)∣x(t),x(0))q(x^{(t-1)}|x^{(t)},x^{(0)})q(x(t1)x(t),x(0)):后验分布(可解析计算)
  • p(x(t−1)∣x(t))p(x^{(t-1)}|x^{(t)})p(x(t1)x(t)):模型预测分布
  • CCC:常数项

通俗解释:
让模型预测的去噪分布,尽可能逼近真实的后验分布


七、核心代码(PyTorch极简实现)

7.1 前向扩散(加噪)

import torch
import torch.nn as nn
import torch.nn.functional as F

def forward_diffusion(x0, t, betas):
    # 计算alpha和累积alpha
    # 给干净图片 x0,按时间步 t 加入对应强度的噪声,返回加噪图和真实噪声。
    alpha = 1. - betas
    alpha_bar = torch.cumprod(alpha, dim=0)
    alpha_bar_t = alpha_bar[t].reshape(-1, 1, 1, 1)
    
    # 采样噪声
    noise = torch.randn_like(x0)
    
    # 加噪结果
    xt = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise
    return xt, noise

7.2 反向去噪网络(简单MLP/CNN)

class ReverseNet(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, dim, 3, padding=1), nn.ReLU(),
            nn.Conv2d(dim, dim, 3, padding=1), nn.ReLU(),
            nn.Conv2d(dim, 1, 3, padding=1)
        )
    def forward(self, x, t):
        return self.net(x)  # 预测噪声

7.3 训练损失(简化版)

def loss_fn(pred_noise, true_noise):
    return F.mse_loss(pred_noise, true_noise)

八、实验结果:全任务SOTA

8.1 各数据集对数似然(表格1 出处:原论文Table 1)

数据集 K(下界) 相对提升
Swiss Roll 2.35 bits 6.45 bits
Binary Heartbeat -2.414 bits/seq 12.024 bits/seq
Bark -0.55 bits/pixel 1.5 bits/pixel
Dead Leaves 1.489 bits/pixel 3.536 bits/pixel
CIFAR-10 5.4 ± 0.2 bits/pixel 11.5 ± 0.2 bits/pixel

表格1 各数据集对数似然下界
分析:

  • 扩散模型在所有数据集上远超简单高斯模型
  • 高复杂度自然图像(CIFAR-10)提升最明显

8.2 MNIST对比SOTA(表格2 出处:原论文Table 2)

模型 对数似然(bits)
Stacked CAE 174 ± 2.3
DBN 199 ± 2.9
Deep GSN 309 ± 1.6
扩散模型 317 ± 2.7
GAN 325 ± 2.9

表格2 MNIST对数似然对比
分析:

  • 扩散模型远超传统生成模型,逼近GAN性能
  • 训练更稳定,无模式崩溃

8.3 图像生成与修复

在这里插入图片描述

图2 左:CIFAR-10样本;右:图像去噪
分析:

  • 扩散模型可直接用于去噪、修复、填充
  • 天然支持条件生成,无需修改架构

九、关键创新与贡献

  1. 首次提出扩散模型框架,用非平衡热力学解决生成模型难题
  2. 前向过程固定,只学反向过程,训练极其稳定
  3. 支持精确采样、概率计算、条件生成
  4. 无模式崩溃、无对抗训练不稳定问题
  5. 为后续DDPM、IDDPM、Stable Diffusion奠定全部理论基础

十、总结

这篇论文是整个扩散模型家族的开山鼻祖

  • 用最简单的高斯加噪+高斯去噪搞定复杂分布建模
  • 训练稳定、生成质量高、理论优美
  • 今天所有文生图、图生图扩散模型,全是这篇论文的直系后代

它证明了:不用GAN、不用VAE,一步一步去噪,就能生成最真实的图像

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐