Consistency Models:打破扩散模型生成速度瓶颈的新一代生成范式
论文信息
- 标题:Consistency Models
- 会议:Proceedings of the 40th International Conference on Machine Learning (ICML 2023)
- 单位:OpenAI, San Francisco, CA 94110, USA
- 代码:暂无官方开源代码
- 论文:https://arxiv.org/pdf/2303.01469.pdf
一、背景:扩散模型的甜蜜与烦恼
扩散模型在AIGC领域实现了跨越式突破,在图像、音频、视频生成任务上都交出了远超传统模型的答卷,但它有一个始终被行业诟病的致命短板:生成速度太慢。
扩散模型的核心是迭代式采样,需要几十到上千次的网络前向传播,才能从纯噪声逐步还原出清晰的图像,相比GAN、VAE这类单步生成模型,推理耗时高出10~2000倍,直接限制了它在实时生成场景的落地。而单步生成的GAN又存在训练不稳定、模式崩溃、泛化能力差的问题,始终无法替代扩散模型的生态位。
OpenAI这篇工作提出的一致性模型,直接解决了这个行业级痛点:它天生支持单步快速生成,同时完整保留了扩散模型的高质量生成能力、零样本图像编辑能力,还提供了两种训练范式——既可以从预训练扩散模型中蒸馏知识,也能作为完全独立的生成模型从零训练,彻底摆脱了对扩散模型的依赖,开辟了生成模型的全新赛道。
二、前置基础:连续时间扩散模型核心逻辑
一致性模型完全建立在连续时间扩散模型的理论体系之上,我们先拆解其核心公式与概念,同时为每个专业内容补充通俗解释。
扩散模型的核心是两个过程:前向加噪过程(把干净图片逐步变成纯噪声)和反向去噪过程(从纯噪声还原出干净图片)。论文中采用了Karras等人提出的连续时间扩散框架,核心公式如下。
2.1 前向扩散的随机微分方程(SDE)
dxt=μ(xt,t) dt+σ(t) dwtdx_{t}=\mu (x_{t},t)\, dt+\sigma (t)\, dw_{t}dxt=μ(xt,t)dt+σ(t)dwt
符号逐字解释:
- xtx_txt:t时刻的带噪数据(图像),t的取值范围是[0,T][0, T][0,T]。其中t=0t=0t=0对应干净的原始数据,t=Tt=Tt=T对应完全的高斯噪声,论文中设置T=80T=80T=80;
- μ(xt,t)\mu(x_t, t)μ(xt,t):SDE的漂移系数,决定了数据在t时刻的平均变化趋势;
- σ(t)\sigma(t)σ(t):扩散系数,决定了t时刻向数据中加入的高斯噪声强度;
- dtdtdt:时间的微小变化量;
- wtw_twt:标准布朗运动(维纳过程),可以通俗理解为连续时间维度上的纯随机高斯噪声序列,保证扩散过程的随机性。
通俗解释:这个公式描述了“给干净图片逐步加噪,最终变成纯噪声”的随机过程,就像往一杯清水里持续滴入墨汁,每一个瞬间都加入一点随机的墨滴,最终整杯水变成均匀的黑色,完全看不出原本清水的样子。
论文中采用了简化的设置:μ(x,t)=0\mu(x,t)=0μ(x,t)=0,σ(t)=2t\sigma(t)=\sqrt{2t}σ(t)=2t,此时加噪过程可以简化为:
xt=x0+t⋅z,z∼N(0,I)x_t = x_0 + t \cdot z, \quad z \sim N(0, I)xt=x0+t⋅z,z∼N(0,I)
也就是直接给干净图片x0x_0x0加上标准差为t的高斯噪声,就能得到t时刻的带噪图片,大幅简化了计算逻辑。
2.2 概率流常微分方程(PF ODE)
扩散模型的SDE过程,对应着一条确定性的常微分方程轨迹,也就是论文中核心的PF ODE:
dxt=[μ(xt,t)−12σ(t)2∇logpt(xt)]dtd x_{t}=\left[\mu\left(x_{t}, t\right)-\frac{1}{2} \sigma(t)^{2} \nabla log p_{t}\left(x_{t}\right)\right] d tdxt=[μ(xt,t)−21σ(t)2∇logpt(xt)]dt
符号逐字解释:
- ∇logpt(xt)\nabla log p_t(x_t)∇logpt(xt):t时刻数据分布pt(x)p_t(x)pt(x)的得分函数(score function),也是扩散模型中需要训练的核心网络,通俗来说,它的作用是“告诉模型当前带噪图片里,噪声应该往哪个方向去除”;
- 其余符号与上述SDE公式完全一致。
结合论文的简化设置,最终得到扩散模型采样的核心经验PF ODE:
dxtdt=−tsϕ(xt,t)\frac{d x_{t}}{ d t}=-t s_{\phi}\left(x_{t}, t\right)dtdxt=−tsϕ(xt,t)
符号补充解释:
- sϕ(xt,t)s_\phi(x_t, t)sϕ(xt,t):我们训练的得分网络,参数为ϕ\phiϕ,用来近似真实的得分函数∇logpt(xt)\nabla log p_t(x_t)∇logpt(xt)。
通俗解释:PF ODE是扩散过程对应的确定性平滑轨迹,它和随机SDE过程有着完全一致的边缘数据分布,但没有随机波动。就像同样是把清水变成黑水,SDE是随机滴墨,而PF ODE是精准控制墨水滴入的速度和位置,走出一条完全确定的路径,最终结果和随机滴墨的分布一模一样。扩散模型的采样,本质就是用数值求解器(Euler、Heun等)反向迭代求解这个ODE,从噪声还原出图片,这也是它需要多次网络前向传播、速度慢的根源。
三、核心定义:一致性模型到底是什么?
3.1 自一致性:模型的核心灵魂
一致性模型的核心思想,完全围绕PF ODE的轨迹特性展开,我们先通过论文中的核心示意图理解其逻辑。
图片1 概率流ODE轨迹与一致性模型的映射关系
出处:Consistency Models论文Figure 1
分析:这张图清晰展示了一致性模型的核心目标——PF ODE的每一条轨迹,都从干净数据x0x_0x0(t=0)出发,平滑过渡到纯噪声xTx_TxT(t=T)。而一致性模型要做的,就是把这条轨迹上任意一个时间点的状态,直接映射到轨迹的起点(干净数据)。
基于此,论文给出了一致性函数的严格定义:
对于PF ODE在区间t∈[ϵ,T]t \in [\epsilon, T]t∈[ϵ,T]上的任意一条解轨迹{xt}\{x_t\}{xt},一致性函数fff满足:
f:(xt,t)↦xϵf:(x_{t}, t) \mapsto x_{\epsilon}f:(xt,t)↦xϵ
其中ϵ\epsilonϵ是一个极小的正数(论文中取0.002),用于避免数值求解的不稳定性,xϵx_\epsilonxϵ几乎等价于干净的原始数据x0x_0x0。
而这个函数最核心的性质,就是自一致性:
f(xt,t)=f(xt′,t′)f(x_{t}, t)=f(x_{t'}, t')f(xt,t)=f(xt′,t′)
只要xtx_txt和xt′x_{t'}xt′属于PF ODE的同一条轨迹,无论t和t’取何值,一致性函数的输出都是同一个xϵx_\epsilonxϵ。
通俗解释:同一条轨迹上的所有点,不管是噪声多的状态还是噪声少的状态,最终都要映射到同一张干净图片上。就像同一颗种子,无论它长到幼苗、小树还是大树阶段,它的基因(最终要生成的图片)都是完全一致的,这也是“一致性模型”名字的由来。
我们训练的神经网络fθf_\thetafθ(参数为θ\thetaθ),就是用来近似这个一致性函数,这就是一致性模型的本体。
3.2 边界条件:训练稳定的关键约束
一致性模型必须满足一个硬约束,也就是边界条件:
f(xϵ,ϵ)=xϵf(x_{\epsilon}, \epsilon) = x_{\epsilon}f(xϵ,ϵ)=xϵ
通俗解释:当输入的已经是几乎无噪声的图片(t=ϵ),模型需要直接输出输入本身,无需做任何去噪处理。这个约束是训练稳定性的核心,还能避免模型学到“所有输入都输出0”的平凡无效解。
论文中给出了两种实现边界条件的参数化方式,最终实验验证效果最优的是带跳跃连接的参数化方案,公式如下:
fθ(x,t)=cskip(t)x+cout(t)Fθ(x,t)f_{\theta }(x,t)=c_{skip}(t)x+c_{out}(t)F_{\theta }(x,t)fθ(x,t)=cskip(t)x+cout(t)Fθ(x,t)
符号逐字解释:
- xxx:t时刻的输入带噪图像;
- ttt:当前输入对应的时间步;
- cskip(t)c_{skip}(t)cskip(t):跳跃连接的权重系数,必须满足cskip(ϵ)=1c_{skip}(\epsilon)=1cskip(ϵ)=1,保证t=ϵ时,输入能直接透传到输出;
- cout(t)c_{out}(t)cout(t):网络输出的权重系数,必须满足cout(ϵ)=0c_{out}(\epsilon)=0cout(ϵ)=0,保证t=ϵ时,网络分支的输出被完全清零;
- Fθ(x,t)F_\theta(x,t)Fθ(x,t):自由形式的神经网络(论文中使用NCSN++/EDM架构),无需满足边界条件约束,可以灵活设计网络结构。
通俗解释:这个结构和ResNet的跳跃连接逻辑一致,把输入直接加到输出上,通过两个权重函数控制输入和网络分支的占比。当t很小(几乎无噪声),输入直接透传;当t很大(几乎全是噪声),网络分支主导输出,完美满足边界条件,还能直接复用扩散模型中成熟的网络架构。
为了适配边界条件,论文对EDM的权重函数做了修改,最终使用的公式为:
cskip(t)=σdata2(t−ϵ)2+σdata2c_{skip }(t)=\frac{\sigma_{data }^{2}}{(t-\epsilon)^{2}+\sigma_{data }^{2}}cskip(t)=(t−ϵ)2+σdata2σdata2
cout(t)=σdata(t−ϵ)σdata2+t2c_{out }(t)=\frac{\sigma_{data }(t-\epsilon)}{\sqrt{\sigma_{data }^{2}+t^{2}}}cout(t)=σdata2+t2σdata(t−ϵ)
其中σdata=0.5\sigma_{data}=0.5σdata=0.5,是训练数据的标准差,这两个函数天然满足cskip(ϵ)=1c_{skip}(\epsilon)=1cskip(ϵ)=1、cout(ϵ)=0c_{out}(\epsilon)=0cout(ϵ)=0的边界条件。

图片2 一致性模型的映射逻辑示意图
出处:Consistency Models论文Figure 2
分析:这张图进一步可视化了一致性模型的核心能力——无论输入是轨迹上哪个时间点的带噪图片,模型都能直接映射到轨迹的起点(干净数据),无需迭代求解。
3.3 采样:单步极速生成,多步灵活调优
一致性模型的采样流程,完美实现了“速度与质量的自由权衡”,这也是它相比传统扩散模型的核心优势。
单步采样(核心能力)
单步采样的流程极其简洁,仅需一次网络前向传播:
- 从高斯分布中采样初始噪声:x^T∼N(0,T2I)\hat{x}_T \sim N(0, T^2 I)x^T∼N(0,T2I);
- 执行一次网络前向传播,直接得到生成结果:x^ϵ=fθ(x^T,T)\hat{x}_\epsilon = f_\theta(\hat{x}_T, T)x^ϵ=fθ(x^T,T)。
通俗解释:扩散模型需要走几十步的反向求解路径,一致性模型一步就跨过去了,生成速度直接提升几十上百倍,达到了和GAN一致的单步生成效率。
多步采样(算力换质量)
一致性模型同时保留了扩散模型“用算力换取样本质量”的能力,论文中给出了多步采样算法,通过交替执行“加噪-去噪”的步骤,提升生成样本的精细度。
通俗解释:单步采样是快速抓拍,能直接出片;多步采样是慢慢对焦、调整参数,花更多时间拍出更清晰的照片,用户可以根据自身需求,在生成速度和样本质量之间自由选择。
3.4 零样本数据编辑:意外的超强能力
一致性模型和扩散模型一样,无需针对特定任务做任何微调或额外训练,就能直接完成各类图像编辑任务,包括:图像修复、灰度图上色、超分辨率、图像去噪、隐空间插值、笔画引导生成等。
图片3 一致性模型的零样本图像编辑效果
出处:Consistency Models论文Figure 6
分析:这张图展示了三个核心零样本任务:(a)灰度图上色、(b)低分辨率图超分辨率重建、©用户笔画引导的图像生成。模型仅在LSUN卧室数据集上完成了生成训练,没有针对这些编辑任务做任何微调,就能直接输出高质量结果,完美继承了扩散模型的零样本编辑能力,这是GAN等传统单步生成模型完全无法实现的。
四、训练范式:两条路径都能实现SOTA效果
论文中提出了两种完全独立的训练方式,第一种是基于预训练扩散模型的一致性蒸馏(CD),第二种是完全不依赖扩散模型的一致性训练(CT),这也是一致性模型能成为独立生成范式的核心原因。
4.1 一致性蒸馏(CD):把扩散模型“榨干”成单步模型
这种训练方式的核心,是把一个预训练好的扩散模型(得分网络)的生成能力,通过蒸馏迁移到一致性模型中,最终得到一个单步就能生成高质量样本的模型。
核心逻辑
首先,我们把时间区间[ϵ,T][\epsilon, T][ϵ,T]离散成N个均匀分布的时间点:t1=ϵ<t2<⋯<tN=Tt_1=\epsilon < t_2 < \dots < t_N=Tt1=ϵ<t2<⋯<tN=T,论文中使用的离散化公式为:
ti=(ϵ1/ρ+i−1N−1(T1/ρ−ϵ1/ρ))ρt_{i}=(\epsilon^{1 / \rho}+\frac{i-1}{N-1}(T^{1 / \rho}-\epsilon^{1 / \rho}))^{\rho}ti=(ϵ1/ρ+N−1i−1(T1/ρ−ϵ1/ρ))ρ
其中ρ=7\rho=7ρ=7,和EDM扩散模型的设置保持一致。
对于同一条轨迹上相邻的两个时间点tnt_ntn和tn+1t_{n+1}tn+1,我们可以用预训练的扩散模型,通过ODE求解器,从xtn+1x_{t_{n+1}}xtn+1计算出它在同一条轨迹上对应的x^tnϕ\hat{x}_{t_n}^\phix^tnϕ,公式如下:
x^tnϕ:=xtn+1+(tn−tn+1)Φ(xtn+1,tn+1;ϕ)\hat{x}_{t_{n}}^{\phi}:=x_{t_{n+1}}+\left(t_{n}-t_{n+1}\right) \Phi\left(x_{t_{n+1}}, t_{n+1} ; \phi\right)x^tnϕ:=xtn+1+(tn−tn+1)Φ(xtn+1,tn+1;ϕ)
符号逐字解释:
- Φ(⋅;ϕ)\Phi(\cdot;\phi)Φ(⋅;ϕ):预训练扩散模型的ODE求解器单步更新函数,论文中对比了一阶Euler求解器和二阶Heun求解器,最终验证Heun求解器效果最优;
- ϕ\phiϕ:预训练扩散模型的参数;
- tn−tn+1t_n - t_{n+1}tn−tn+1:时间步的差值,反向求解时该值为负数。
通俗解释:这个公式就是用预训练的扩散模型,把tn+1t_{n+1}tn+1时刻的带噪图片,往前推一步,得到同一条轨迹上tnt_ntn时刻的带噪图片。根据自一致性性质,这两个点在同一条轨迹上,一致性模型对它们的输出应该完全相等,这就是蒸馏损失的核心思想。
一致性蒸馏损失函数
论文中定义的一致性蒸馏损失如下:
LCDN(θ,θ−;ϕ):=E[λ(tn)d(fθ(xtn+1,tn+1),fθ−(x^tnϕ,tn))]\mathcal{L}_{C D}^{N}\left(\theta, \theta^{-} ; \phi\right):= \mathbb{E}\left[\lambda\left(t_{n}\right) d\left(f_{\theta}\left(x_{t_{n+1}}, t_{n+1}\right), f_{\theta^{-}}\left(\hat{x}_{t_{n}}^{\phi}, t_{n}\right)\right)\right]LCDN(θ,θ−;ϕ):=E[λ(tn)d(fθ(xtn+1,tn+1),fθ−(x^tnϕ,tn))]
符号逐字解释:
- E[⋅]\mathbb{E}[\cdot]E[⋅]:数学期望,对所有随机变量取平均;
- 随机变量的采样规则:从训练数据集中采样干净图片x∼pdatax \sim p_{data}x∼pdata,从1到N-1均匀随机采样时间步索引n∼U[1,N−1]n \sim U[1, N-1]n∼U[1,N−1],给干净图片加噪得到xtn+1∼N(x,tn+12I)x_{t_{n+1}} \sim N(x, t_{n+1}^2 I)xtn+1∼N(x,tn+12I);
- λ(tn)\lambda(t_n)λ(tn):正的权重函数,论文实验验证恒等于1时效果最优;
- d(⋅,⋅)d(\cdot, \cdot)d(⋅,⋅):度量函数,用于衡量两个输出的差异,论文中测试了L2距离、L1距离、LPIPS感知距离,最终验证LPIPS效果最优;
- fθf_\thetafθ:在线网络,即我们正在训练的一致性模型,参数为θ\thetaθ;
- fθ−f_{\theta^-}fθ−:目标网络,是在线网络参数的指数移动平均(EMA),参数θ−\theta^-θ−不会通过梯度下降更新,仅通过EMA平滑更新。
通俗解释:这个损失函数的目标,是让模型对同一条轨迹上相邻两个点的输出尽可能保持一致。使用EMA目标网络而非直接用在线网络,是为了大幅稳定训练过程,避免模型训练崩溃,这和强化学习中的DQN、对比学习中的MoCo核心思路一致。
论文中还通过定理1给出了理论保障:当损失为0时,一致性模型的输出和真实一致性函数的误差,与ODE求解器的局部误差是同阶的,为蒸馏训练提供了严格的理论支撑。
4.2 一致性训练(CT):从零开始的独立生成范式
这是论文最具突破性的贡献:一致性模型完全可以不依赖任何预训练扩散模型,直接从数据集中从零开始训练,成为一个完全独立的生成模型家族,彻底摆脱了对扩散模型的依赖。
核心理论支撑
论文中的定理2证明:当使用Euler ODE求解器,且预训练得分模型完全拟合真实得分函数时,一致性蒸馏损失可以等价为一个完全不依赖扩散模型的损失函数,仅和训练数据本身相关。
这里的核心是一个得分函数的无偏估计器(论文Lemma 1):
∇logpt(xt)=−E[xt−xt2∣xt]\nabla log p_{t}\left(x_{t}\right)=-\mathbb{E}\left[ \frac {x_{t}-x}{t^{2}} | x_{t}\right]∇logpt(xt)=−E[t2xt−x∣xt]
符号解释:
- xxx:干净的原始数据,xtx_txt是给x加噪得到的t时刻带噪数据。
通俗解释:我们无需训练得分网络,直接用干净图片和带噪图片的差值,就能无偏估计出得分函数的值,这就是一致性模型能摆脱扩散模型的核心!
基于这个结论,论文推导出了完全不依赖扩散模型的一致性训练损失。
一致性训练损失函数
LCTN(θ,θ−):=E[λ(tn)d(fθ(x+tn+1z,tn+1),fθ−(x+tnz,tn))]\mathcal{L}_{C T}^{N}\left(\theta, \theta^{-}\right) := \mathbb{E}\left[\lambda (t_{n})d(f_{\theta }(x+t_{n+1}z,t_{n+1}),f_{\theta ^{-}}(x+t_{n}z,t_{n}))\right]LCTN(θ,θ−):=E[λ(tn)d(fθ(x+tn+1z,tn+1),fθ−(x+tnz,tn))]
符号逐字解释:
- z∼N(0,I)z \sim N(0, I)z∼N(0,I):标准高斯噪声;
- x+tn+1zx + t_{n+1} zx+tn+1z:给干净图片x加噪,得到t_{n+1}时刻的带噪图片;
- x+tnzx + t_n zx+tnz:使用同一个噪声z,给干净图片x加噪,得到同一条轨迹上t_n时刻的带噪图片;
- 其余符号与蒸馏损失完全一致。
通俗解释:这个设计堪称精妙!用同一个噪声z,给同一张干净图片加不同强度的噪声,得到的两个带噪图片,天然就处于PF ODE的同一条轨迹上!因此我们根本不需要预训练扩散模型来计算相邻轨迹点,直接用同一个噪声加不同强度的噪,就能得到符合自一致性要求的点对,然后让模型对它们的输出保持一致即可,完全摆脱了对扩散模型的依赖。
训练关键技巧:自适应调度函数
论文中发现,训练过程中使用自适应调度函数,逐步增加离散时间步的数量N和EMA衰减率μ\muμ,能极大提升训练的收敛速度和最终效果。
核心原因是:训练初期N较小,时间步间隔大,损失的方差小、偏差大,能让模型快速收敛;训练后期N变大,时间步间隔小,损失的方差大、偏差小,能让模型学到更精准的一致性函数。
论文中使用的调度函数如下:
N(k)=[kK((s1+1)2−s02)+s02−1]+1N(k)=\left[\sqrt{\frac{k}{K}\left(\left(s_{1}+1\right)^{2}-s_{0}^{2}\right)+s_{0}^{2}}-1\right]+1N(k)=[Kk((s1+1)2−s02)+s02−1]+1
μ(k)=exp(s0logμ0N(k))\mu(k)=exp \left(\frac{s_{0} log \mu_{0}}{N(k)}\right)μ(k)=exp(N(k)s0logμ0)
符号逐字解释:
- kkk:当前训练迭代次数;
- KKK:总训练迭代次数;
- s0s_0s0:初始的离散步数,s1s_1s1:训练结束时的目标离散步数;
- μ0\mu_0μ0:训练初始的EMA衰减率。
五、实验结果:碾压级的效果提升
论文在CIFAR-10、ImageNet 64×64、LSUN 256×256等主流生成模型基准数据集上完成了全面的消融实验和对比实验,核心评价指标为FID(Fréchet Inception Distance,数值越低代表生成质量越好)、IS(Inception Score,数值越高代表生成质量越好)、Precision/Recall。
5.1 超参消融实验
论文首先通过消融实验,验证了不同超参数对模型效果的影响,结果如下:
图片4 不同超参对一致性模型训练效果的影响
出处:Consistency Models论文Figure 3
结果分析:
- 度量函数对比(图a):LPIPS感知距离的效果远超L1和L2距离,因为LPIPS更贴合人眼对图像相似度的判断,这也是论文最终选择LPIPS作为核心度量函数的原因;
- ODE求解器对比(图b):二阶Heun求解器的效果远超一阶Euler求解器,和定理1的理论结论完全一致——更高阶的ODE求解器局部误差更小,最终一致性模型的精度更高;
- 离散步数N对比(图c):当N≥18之后,模型效果基本趋于稳定,无需设置过大的N,完美平衡了训练成本和最终效果;
- 自适应调度对比(图d):带自适应N和μ的训练方式,收敛速度和最终效果都远超固定参数的训练,验证了调度函数的有效性。
5.2 少步生成效果对比
论文将一致性蒸馏(CD)与当时SOTA的扩散模型蒸馏方法——渐进蒸馏(PD)做了全面对比,核心结果如下。
表格1 CIFAR-10数据集上的样本质量对比
出处:Consistency Models论文Table 1
| 方法 | 网络前向次数(NFE) | FID ↓ | IS ↑ |
|---|---|---|---|
| 渐进蒸馏(PD) | 1 | 8.34 | 8.69 |
| 一致性蒸馏(CD) | 1 | 3.55 | 9.48 |
| 渐进蒸馏(PD) | 2 | 5.58 | 9.05 |
| 一致性蒸馏(CD) | 2 | 2.93 | 9.75 |
结果分析:在CIFAR-10数据集上,单步生成的CD方法,FID直接从PD的8.34降至3.55,IS也显著更高,直接刷新了当时扩散模型蒸馏的SOTA记录;两步生成的CD,FID低至2.93,已经非常接近原始EDM扩散模型35步生成的效果(FID 2.04)。
表格2 ImageNet 64×64、LSUN数据集上的样本质量对比
出处:Consistency Models论文Table 2
| 数据集 | 方法 | 网络前向次数(NFE) | FID ↓ |
|---|---|---|---|
| ImageNet 64×64 | 渐进蒸馏(PD) | 1 | 15.39 |
| ImageNet 64×64 | 一致性蒸馏(CD) | 1 | 6.20 |
| ImageNet 64×64 | 渐进蒸馏(PD) | 2 | 8.95 |
| ImageNet 64×64 | 一致性蒸馏(CD) | 2 | 4.70 |
| LSUN Bedroom 256×256 | 渐进蒸馏(PD) | 1 | 16.92 |
| LSUN Bedroom 256×256 | 一致性蒸馏(CD) | 1 | 7.80 |
| LSUN Cat 256×256 | 渐进蒸馏(PD) | 1 | 29.6 |
| LSUN Cat 256×256 | 一致性蒸馏(CD) | 1 | 11.0 |
结果分析:在更高分辨率、更复杂的数据集上,CD的优势更加显著。ImageNet数据集上单步FID从15.39降至6.20,LSUN Bedroom和Cat数据集上,单步FID相比PD降低了一半以上,实现了碾压级的效果提升。
而独立训练的CT方法,在CIFAR-10数据集上单步FID达到8.70,远超VAE、流模型等其他非对抗单步生成模型,甚至和渐进蒸馏(PD)的单步效果相当,而CT完全没有用到预训练扩散模型,实现了零依赖下的SOTA效果。
六、核心代码实现
以下基于PyTorch实现一致性模型的核心模块,完全贴合论文的算法逻辑,包括带边界条件的模型参数化、单步/多步采样、一致性蒸馏/训练损失、EMA更新等核心功能。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg16
# -------------------------- 1. LPIPS度量函数(论文中效果最优的度量方式) --------------------------
class LPIPS(nn.Module):
def __init__(self):
super().__init__()
vgg = vgg16(pretrained=True).features
self.slice1 = nn.Sequential(*vgg[:4])
self.slice2 = nn.Sequential(*vgg[4:9])
self.slice3 = nn.Sequential(*vgg[9:16])
self.slice4 = nn.Sequential(*vgg[16:23])
for param in self.parameters():
param.requires_grad = False
# 特征线性层
self.lin0 = nn.Conv2d(64, 1, 1, bias=False)
self.lin1 = nn.Conv2d(128, 1, 1, bias=False)
self.lin2 = nn.Conv2d(256, 1, 1, bias=False)
self.lin3 = nn.Conv2d(512, 1, 1, bias=False)
def forward(self, x, y):
feats_x = self.get_features(x)
feats_y = self.get_features(y)
diff = 0
for fx, fy, lin in zip(feats_x, feats_y, [self.lin0, self.lin1, self.lin2, self.lin3]):
fx = fx / torch.norm(fx, dim=1, keepdim=True)
fy = fy / torch.norm(fy, dim=1, keepdim=True)
diff += lin((fx - fy) ** 2).mean()
return diff
def get_features(self, x):
h = self.slice1(x)
h1 = h
h = self.slice2(h)
h2 = h
h = self.slice3(h)
h3 = h
h = self.slice4(h)
h4 = h
return [h1, h2, h3, h4]
# -------------------------- 2. 一致性模型核心网络(满足边界条件的参数化) --------------------------
class UNetBlock(nn.Module):
def __init__(self, in_channels, out_channels, time_emb_dim):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.time_mlp = nn.Linear(time_emb_dim, out_channels)
self.res_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
def forward(self, x, t_emb):
h = F.silu(self.conv1(x))
# 注入时间嵌入
t_emb = F.silu(self.time_mlp(t_emb))
h = h + t_emb[..., None, None]
h = F.silu(self.conv2(h))
return h + self.res_conv(x)
class ConsistencyModel(nn.Module):
def __init__(self, in_channels=3, base_channels=64, time_emb_dim=256, eps=0.002, sigma_data=0.5):
super().__init__()
self.eps = eps
self.sigma_data = sigma_data
self.time_emb_dim = time_emb_dim
# 时间嵌入层
self.time_mlp = nn.Sequential(
nn.Linear(1, time_emb_dim),
nn.SiLU(),
nn.Linear(time_emb_dim, time_emb_dim)
)
# UNet编码器
self.down_blocks = nn.ModuleList([
UNetBlock(in_channels, base_channels, time_emb_dim),
UNetBlock(base_channels, base_channels*2, time_emb_dim),
UNetBlock(base_channels*2, base_channels*4, time_emb_dim),
])
self.downsample = nn.MaxPool2d(2)
# 中间层
self.mid_block = UNetBlock(base_channels*4, base_channels*4, time_emb_dim)
# UNet解码器
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.up_blocks = nn.ModuleList([
UNetBlock(base_channels*8, base_channels*2, time_emb_dim),
UNetBlock(base_channels*4, base_channels, time_emb_dim),
UNetBlock(base_channels*2, base_channels, time_emb_dim),
])
# 自由形式的F_theta输出头
self.final_conv = nn.Conv2d(base_channels, in_channels, 3, padding=1)
# 论文中满足边界条件的c_skip权重函数
def c_skip(self, t):
return self.sigma_data ** 2 / ((t - self.eps) ** 2 + self.sigma_data ** 2)
# 论文中满足边界条件的c_out权重函数
def c_out(self, t):
return self.sigma_data * (t - self.eps) / torch.sqrt(self.sigma_data ** 2 + t ** 2)
# 时间嵌入计算
def get_time_embedding(self, t):
return self.time_mlp(t)
# 一致性模型前向传播核心逻辑
def forward(self, x, t):
"""
x: 输入带噪图像 [batch_size, channels, height, width]
t: 时间步 [batch_size, 1]
return: 去噪后的图像
"""
batch_size = x.shape[0]
# 计算边界条件权重
c_skip = self.c_skip(t).view(batch_size, 1, 1, 1)
c_out = self.c_out(t).view(batch_size, 1, 1, 1)
# 时间嵌入
t_emb = self.get_time_embedding(t)
# 编码器前向
skips = []
h = x
for block in self.down_blocks:
h = block(h, t_emb)
skips.append(h)
h = self.downsample(h)
# 中间层
h = self.mid_block(h, t_emb)
# 解码器前向
for block in self.up_blocks:
h = self.upsample(h)
skip = skips.pop()
h = torch.cat([h, skip], dim=1)
h = block(h, t_emb)
# 自由网络F_theta输出
F_theta = self.final_conv(h)
# 最终一致性模型输出,满足边界条件
f_theta = c_skip * x + c_out * F_theta
return f_theta
# -------------------------- 3. 采样函数:单步+多步采样 --------------------------
@torch.no_grad()
def single_step_sample(model, batch_size, image_size=32, T=80, device='cuda'):
"""单步采样,论文核心能力"""
model.eval()
# 采样初始噪声
x_T = torch.randn(batch_size, 3, image_size, image_size, device=device) * T
t = torch.ones(batch_size, 1, device=device) * T
# 一次前向传播直接生成结果
x_0 = model(x_T, t)
# 归一化到[0,1]
x_0 = (x_0.clamp(-1, 1) + 1) / 2
return x_0
@torch.no_grad()
def multi_step_sample(model, batch_size, image_size=32, T=80, eps=0.002, time_steps=[40, 20, 10], device='cuda'):
"""多步采样,对应论文Algorithm 1,算力换质量"""
model.eval()
# 初始噪声
x = torch.randn(batch_size, 3, image_size, image_size, device=device) * T
t = torch.ones(batch_size, 1, device=device) * T
# 第一步去噪
x = model(x, t)
# 多步迭代
for tau in time_steps:
# 加噪
z = torch.randn_like(x)
x_tau = x + torch.sqrt(tau ** 2 - eps ** 2) * z
# 去噪
t_tau = torch.ones(batch_size, 1, device=device) * tau
x = model(x_tau, t_tau)
# 归一化到[0,1]
x = (x.clamp(-1, 1) + 1) / 2
return x
# -------------------------- 4. 训练损失函数 --------------------------
def consistency_distillation_loss(online_model, target_model, x, score_model, t_list, N=18, lpips_fn=None, device='cuda'):
"""一致性蒸馏损失,对应论文Algorithm 2"""
batch_size = x.shape[0]
# 随机采样时间步n
n = torch.randint(0, N-1, (batch_size,), device=device)
t_n = t_list[n].view(batch_size, 1)
t_n1 = t_list[n+1].view(batch_size, 1)
# 生成t_{n+1}时刻的带噪图片
z = torch.randn_like(x)
x_tn1 = x + t_n1 * z
# Heun二阶求解器计算x_tn^phi,论文中效果最优
# 第一步:Euler预测步
s_tn1 = score_model(x_tn1, t_n1)
x_tn_euler = x_tn1 + (t_n - t_n1) * (-t_n1 * s_tn1)
# 第二步:Heun修正步
s_tn_euler = score_model(x_tn_euler, t_n)
x_tn_phi = x_tn1 + (t_n - t_n1) * 0.5 * (-t_n1 * s_tn1 - t_n * s_tn_euler)
# 计算模型输出
f_online = online_model(x_tn1, t_n1)
f_target = target_model(x_tn_phi, t_n)
# 计算损失
if lpips_fn is not None:
loss = lpips_fn(f_online, f_target).mean()
else:
loss = F.mse_loss(f_online, f_target)
return loss
def consistency_training_loss(online_model, target_model, x, t_list, N=120, lpips_fn=None, device='cuda'):
"""一致性训练损失,对应论文Algorithm 3,无预训练扩散模型依赖"""
batch_size = x.shape[0]
# 随机采样时间步n
n = torch.randint(0, N-1, (batch_size,), device=device)
t_n = t_list[n].view(batch_size, 1)
t_n1 = t_list[n+1].view(batch_size, 1)
# 同一个噪声z生成同轨迹的两个带噪图片
z = torch.randn_like(x)
x_tn1 = x + t_n1 * z
x_tn = x + t_n * z
# 计算模型输出
f_online = online_model(x_tn1, t_n1)
f_target = target_model(x_tn, t_n)
# 计算损失
if lpips_fn is not None:
loss = lpips_fn(f_online, f_target).mean()
else:
loss = F.mse_loss(f_online, f_target)
return loss
# -------------------------- 5. EMA目标网络更新 --------------------------
def update_ema(target_model, online_model, mu=0.9999):
"""指数移动平均更新目标网络,对应论文EMA更新公式"""
with torch.no_grad():
for target_param, online_param in zip(target_model.parameters(), online_model.parameters()):
target_param.data = mu * target_param.data + (1 - mu) * online_param.data
七、总结与展望
这篇论文提出的一致性模型,彻底解决了扩散模型生成速度慢的核心痛点,同时完整保留了扩散模型的高质量生成、零样本编辑、训练稳定无模式崩溃等核心优势,还开辟了两种全新的训练范式:
- 作为扩散模型的蒸馏方法,一致性蒸馏刷新了当时少步生成的SOTA,在多个基准数据集上实现了碾压级的效果提升,将扩散模型的生成速度提升了几十上百倍;
- 作为独立的生成模型家族,一致性训练完全摆脱了对扩散模型、对抗训练的依赖,成为了一种全新的非对抗单步生成范式,效果远超VAE、流模型等传统单步生成方法。
更重要的是,一致性模型的核心思想,为后续的生成模型研究打开了全新的思路,后续的Rectified Flow、LCM等主流快速生成模型,都借鉴了一致性模型的核心思想,进一步推动了AIGC领域实时生成场景的落地。
论文中也指出了未来的研究方向:连续时间训练的方差优化、更高阶ODE求解器的适配、端到端的编辑任务训练、视频/音频等多模态生成的拓展等,这些方向也都成为了后续生成模型领域的研究热点。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)