论文信息

  • 标题:Score-Based Generative Modeling through Stochastic Differential Equations
  • 会议:ICLR 2021
  • 单位:斯坦福大学、Google Brain
  • 代码:https://github.com/yang-song/score_sde
  • 论文:https://arxiv.org/pdf/2011.13456.pdf

一、开篇导读:这篇论文到底炸在哪?

如果把扩散模型比作离散步数的加噪去噪游戏,这篇论文直接把它升级成了连续时间的流体动力学

它用随机微分方程(SDE) 统一了两大生成流派:

  • DDPM系列(离散扩散,一步步加噪)
  • NCSN系列(分数匹配,估计分布梯度)

不仅如此,它还顺手搞出:
✅ 连续时间扩散,理论更优美
✅ Predictor-Corrector采样,精度飙升
✅ 概率流ODE,支持精确似然计算
✅ 无条件模型就能做图像修复、上色、条件生成
✅ CIFAR-10刷到FID 2.20,史上最强 unconditional 生成之一

全文看完,你会彻底懂:扩散模型 = 正向SDE + 反向SDE + 分数估计


二、核心思想一句话讲透

把数据慢慢搅成噪声(正向SDE),再用神经网络把噪声慢慢还原成数据(反向SDE)。
还原的关键,是学会分数函数∇xlog⁡pt(x)\nabla_x \log p_t(x)xlogpt(x),也就是数据分布的对数密度梯度。

在这里插入图片描述

图1:解决反向时间的随机微分方程会得出一种基于得分的生成模型。将数据转换为简单的噪声分布可以通过连续时间的随机微分方程来实现。如果我们知道在每个中间时间步的分布得分 Vxlog pt(x),那么这个随机微分方程就可以逆向求解。

三、数学基石:正向SDE与反向SDE(全文最重要公式)

在这里插入图片描述

图 2:通过 SDE 进行基于分数的生成建模概述。我们可以用一个 SDE 将数据映射到噪声分布(先验分布)(第 3.1 节),并反向使用此 SDE 进行生成建模(第 3.2 节)。我们还可以反向关联的概率流 ODE(第 4.3 节),这会产生一个确定性过程,从与 SDE 相同的分布中采样。反向时间 SDE 和概率流 ODE 都可以通过估计分数 Vx log p(x) 来获得(第 3.3 节)。

1. 正向扩散SDE(数据→噪声)

dx=f(x,t)dt+g(t)dwdx = f(x,t)dt + g(t)dwdx=f(x,t)dt+g(t)dw

  • dxdxdx:无穷小时间内数据的变化
  • f(x,t)f(x,t)f(x,t)漂移系数(确定性变化趋势,通俗:数据自己要飘的方向)
  • g(t)g(t)g(t)扩散系数(噪声强度,通俗:加噪力度)
  • dtdtdt:无穷小时间步
  • www:标准布朗运动(高斯噪声的连续版本)

直观理解:每一瞬间,数据既按固定方向漂移,又被高斯噪声轻轻扰动,最后完全变成高斯分布。

2. 反向时间SDE(噪声→数据)

dx=[f(x,t)−g(t)2∇xlog⁡pt(x)]dt+g(t)dwˉdx = \left[f(x,t) - g(t)^2 \nabla_x \log p_t(x)\right]dt + g(t)d\bar{w}dx=[f(x,t)g(t)2xlogpt(x)]dt+g(t)dwˉ

  • ∇xlog⁡pt(x)\nabla_x \log p_t(x)xlogpt(x)分数函数(全文灵魂!网络要学的目标)
  • wˉ\bar{w}wˉ:时间倒流的布朗运动
  • 多出来的 −g2∇log⁡p-g^2\nabla\log pg2logp 项:去噪修正项,把噪声往数据分布拉

通俗翻译:
反向过程 = 正向过程倒着走 + 用分数函数修正方向,把噪声“掰回”原图。


四、三大经典SDE:VE、VP、sub-VP

论文给出3种最实用的正向SDE,覆盖所有现代扩散模型👇

1. VE SDE(方差爆炸型,对应NCSN)

dx=d[σ2(t)]dtdwdx = \sqrt{\frac{d[\sigma^2(t)]}{dt}}dwdx=dtd[σ2(t)] dw

  • 特点:方差越来越大,最后炸成无穷大噪声
  • 适合:高质量图像生成
  • 对应模型:NCSN、一致性模型基础

2. VP SDE(方差守恒型,对应DDPM)

dx=−12β(t)xdt+β(t)dwdx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)}dwdx=21β(t)xdt+β(t) dw

  • 特点:方差始终接近1,稳定好训练
  • 对应模型:DDPM、Stable Diffusion底层
  • 漂移项把数据往原点拉,噪声平稳加入

3. sub-VP SDE(方差有界型,论文新提出)

dx=−12β(t)xdt+β(t)(1−e−2∫0tβ(s)ds)dwdx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)\left(1-e^{-2\int_0^t \beta(s)ds}\right)}dwdx=21β(t)xdt+β(t)(1e20tβ(s)ds) dw

  • 特点:方差比VP更小,似然效果最好
  • 贡献:CIFAR-10似然达到2.99 bits/dim,SOTA

五、训练目标:去噪分数匹配

网络 sθ(x,t)s_\theta(x,t)sθ(x,t) 要拟合分数 ∇xlog⁡pt(x)\nabla_x \log p_t(x)xlogpt(x),损失函数如下:
L(θ)=Et,x0,xt[∥sθ(xt,t)−∇xtlog⁡p0t(xt∣x0)∥2]\mathcal{L}(\theta) = \mathbb{E}_{t,x_0,x_t}\left[\left\|s_\theta(x_t,t) - \nabla_{x_t}\log p_{0t}(x_t|x_0)\right\|^2\right]L(θ)=Et,x0,xt[sθ(xt,t)xtlogp0t(xtx0)2]

  • x0x_0x0:干净原图
  • xtx_txt:t时刻加噪图
  • p0t(xt∣x0)p_{0t}(x_t|x_0)p0t(xtx0):从x0到xt的扩散分布
  • 直观:让网络预测的分数,等于真实加噪分布的梯度

六、超级采样器:Predictor-Corrector(PC)

算法流程

  1. Predictor:用反向SDE走一步(数值求解器,如Euler-Maruyama)
  2. Corrector:用朗之万MCMC矫正一步,修复离散误差
    交替执行,效果碾压纯采样!

直观理解

  • Predictor:粗略还原
  • Corrector:精细修图
    就像先画草稿,再用橡皮擦精修。

七、概率流ODE: deterministic 采样 + 精确似然

在这里插入图片描述

图 3:概率流常微分方程能够以自适应步长实现快速采样,且随着数值精度的变化而变化(左),在不损害质量的情况下减少分数函数评估次数(NFE)(中)。从潜在变量到图像的可逆映射允许进行插值(右)。

从SDE推导出等价确定性ODE,无随机噪声:
dx=[f(x,t)−12g(t)2∇xlog⁡pt(x)]dtdx = \left[f(x,t) - \frac{1}{2}g(t)^2\nabla_x \log p_t(x)\right]dtdx=[f(x,t)21g(t)2xlogpt(x)]dt
✅ 完全确定性,可复现
✅ 支持精确对数似然计算
✅ 自适应步长,速度超快


八、逆天功能:无条件模型做条件生成

不用重训,单个无条件模型直接做:

  • 类别条件生成
  • 图像修复(inpainting)
  • 图像上色(colorization)

核心公式(条件反向SDE):
$dx=[f(x,t)−g(t)2(∇xlog⁡pt(x)+∇xlog⁡pt(y∣x))]dt+g(t)dwˉdx = \left[f(x,t)-g(t)^2\left(\nabla_x\log p_t(x)+\nabla_x\log p_t(y|x)\right)\right]dt+g(t)d\bar{w}dx=[f(x,t)g(t)2(xlogpt(x)+xlogpt(yx))]dt+g(t)dwˉ

  • 加了一项 ∇xlog⁡pt(y∣x)\nabla_x\log p_t(y|x)xlogpt(yx)条件引导,把生成约束到满足y的区域

九、实验结果:强到不讲理

表格1 CIFAR-10 样本质量(FID越低越好)

模型 FID
StyleGAN2-ADA 2.92
DDPM 3.17
NCSN++ cont. (deep, VE) 2.20

出处:论文表3
分析:无条件分数SDE,打败有条件StyleGAN,生成质量巅峰。

表格2 CIFAR-10 似然结果(bits/dim 越低越好)

模型 NLL
Glow 3.35
DDPM 3.70
DDPM++ cont. (deep, sub-VP) 2.99

出处:论文表2
分析:sub-VP SDE似然屠榜,精确密度估计最强。

高分辨率生成效果(1024×1024 CelebA-HQ)

在这里插入图片描述

分析:首次从分数模型生成1K高清人脸,细节拉满,证明 scalability。

图像修复与上色

在这里插入图片描述

分析:单无条件模型,直接补全缺失区域、给灰度图上色,效果逼真。


十、核心PyTorch代码(可直接跑)

1. 时间正弦编码(标准实现)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half).float() / half)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

2. 三大SDE基类(VE/VP/sub-VP)

class SDE:
    def __init__(self):
        pass
    def sde(self, x, t):
        raise NotImplementedError()
    def prior_sampling(self, shape):
        return torch.randn(*shape)

# VP SDE (对应DDPM)
class VPSDE(SDE):
    def __init__(self, beta_min=0.1, beta_max=20):
        self.beta_min = beta_min
        self.beta_max = beta_max
    def beta_t(self, t):
        return self.beta_min + t*(self.beta_max-self.beta_min)
    def sde(self, x, t):
        beta_t = self.beta_t(t)
        drift = -0.5 * beta_t[:, None, None, None] * x
        diffusion = torch.sqrt(beta_t)[:, None, None, None]
        return drift, diffusion

# VE SDE (对应NCSN)
class VESDE(SDE):
    def __init__(self, sigma_min=0.01, sigma_max=50):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
    def sde(self, x, t):
        sigma = self.sigma_min * (self.sigma_max/self.sigma_min)**t
        dsigma_dt = sigma * math.log(self.sigma_max/self.sigma_min)
        drift = torch.zeros_like(x)
        diffusion = torch.sqrt(2*dsigma_dt)[:, None, None, None]
        return drift, diffusion

3. 分数网络(带时间嵌入)

class ScoreNet(nn.Module):
    def __init__(self, dim=64, emb_dim=128):
        super().__init__()
        self.emb = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.SiLU(),
            nn.Linear(emb_dim, emb_dim)
        )
        self.conv1 = nn.Conv2d(1, dim, 3, padding=1)
        self.conv2 = nn.Conv2d(dim, dim, 3, padding=1)
        self.conv3 = nn.Conv2d(dim, 1, 3, padding=1)
        self.act = nn.SiLU()
    def forward(self, x, t):
        emb = timestep_embedding(t, 128)
        emb = self.emb(emb).reshape(x.shape[0], -1, 1, 1)
        h = self.act(self.conv1(x) + emb)
        h = self.act(self.conv2(h) + emb)
        return self.conv3(h)

4. PC采样核心(Predictor+Corrector)

@torch.no_grad()
def pc_sampler(score_fn, sde, img_shape, steps=1000):
    x = sde.prior_sampling(img_shape)
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.ones(img_shape[0], device=x.device) * (1.0 - i*dt)
        f, g = sde.sde(x, t)
        # Predictor
        score = score_fn(x, t)
        x = x - f*dt + g**2 * score * dt + g*torch.randn_like(x)*math.sqrt(dt)
        # Corrector (Langevin)
        score = score_fn(x, t)
        eps = 1e-4
        x = x + eps*score + math.sqrt(2*eps)*torch.randn_like(x)
    return x

十一、全文总结

  1. 统一框架:用SDE统一离散扩散与分数匹配模型
  2. 正向SDE:数据→噪声,3种范式VE/VP/sub-VP
  3. 反向SDE:噪声→数据,靠分数函数∇log⁡pt(x)\nabla\log p_t(x)logpt(x)
  4. 训练:去噪分数匹配,MSE损失
  5. 采样:Predictor-Corrector,精度最高
  6. 概率流ODE:确定性采样+精确似然
  7. 条件生成:无条件模型直接用,修复/上色/条件生成全搞定

这篇是扩散模型理论天花板,吃透它,后续所有改进(Consistency Model、Flow Matching、LCM)全是它的变体!


Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐