Latent Consistency Models:潜空间一致性模型,实现高分辨率图像的少步极速生成
论文信息
- 标题:LATENT CONSISTENCY MODELS: SYNTHESIZING HIGH-RESOLUTION IMAGES WITH FEW-STEP INFERENCE
- 会议:arXiv:2310.04378
- 单位:清华大学交叉信息研究院
- 代码:https://github.com/luosiallen/latent-consistency-model
- 论文:https://arxiv.org/pdf/2310.04378.pdf
开篇:扩散模型的“速度困局”与破局方案
以Stable Diffusion为代表的潜扩散模型(LDM),早已成为AI图像生成的绝对主流,768×768甚至更高分辨率的写实、创意图像生成都不在话下。但用过的开发者都懂它的致命短板:慢。
传统扩散模型的生成过程,就像爬楼梯,必须从顶层的纯噪声开始,一步一步往下走,20~50步迭代才能得到一张清晰的图片;哪怕是DPM-Solver++这类优化后的ODE求解器,也至少要10步才能保证画质不崩。想做实时生成?比如用户输入文案立刻出图,传统扩散模型根本做不到。
行业里也不是没试过解决这个问题:
- 训练-free的加速方法:靠优化ODE求解器提速,但步数低于10步就会出现画面崩坏、结构错乱;
- 蒸馏方法:把预训练扩散模型蒸馏成少步模型,但经典的Guided-Distill需要两阶段蒸馏,光2步模型就要训45个A100天,成本高到离谱,还会有两阶段的误差累积;
- 一致性模型(CM):2023年提出的新范式,能实现单步生成,但只支持像素空间的小图生成(最高256×256),不支持文本条件生成,也没解决扩散模型里核心的无分类器引导(CFG)问题,根本没法用在Stable Diffusion这类工业级模型上。
而这篇论文提出的Latent Consistency Models(潜空间一致性模型,简称LCM),直接把所有问题一锅端了:
- 把一致性模型搬进了Stable Diffusion的潜空间,完美适配高分辨率文本-图像生成;
- 提出单阶段引导蒸馏,把CFG无缝融入蒸馏过程,告别两阶段蒸馏的高成本和误差累积;
- 发明SKIPPING-STEP跳步技术,让蒸馏收敛速度提升数十倍,768×768分辨率的2~4步模型,只需要32个A100小时就能训完;
- 设计潜空间一致性微调(LCF),自定义数据集微调也能保留少步生成能力,不用重新蒸馏。
最终效果有多夸张?LCM能在2~4步内生成和原版SD 50步效果相当的768×768高清图像,甚至1步就能生成结构完整、细节丰富的画面,在LAION数据集上的FID、CLIP Score全面吊打同期所有少步生成方法。
图1 LCM不同步数的生成效果(出处:原论文Figure 1)
图中展示了CFG尺度8.0下,LCM的1/2/4步推理生成的768×768图像。可以看到,4步生成的图像细节拉满、光影自然、文本对齐度极高;2步生成依然保持了优秀的画质和结构完整性;即便是1步生成,也能输出主体清晰、构图合理的图像,这是传统扩散模型在10步以内都无法实现的效果。
前置知识:先搞懂核心基础概念
在拆解LCM的核心方法前,我们先把底层概念掰碎了讲,每个公式、每个专业术语都配上大白话解释,新手也能完全看懂。
3.1 扩散模型与概率流ODE(PF-ODE)
扩散模型的核心,是先给真实图片一步步加高斯噪声,直到变成纯噪声(前向过程),再让模型学逆过程——从纯噪声一步步去噪,还原出真实图片。
从数学上,前向加噪过程可以用一个随机微分方程(SDE)描述:
dxt=f(t)xt dt+g(t)dwtd x_{t}=f(t) x_{t} ~d t+g(t) d w_{t}dxt=f(t)xt dt+g(t)dwt
x0∼pdata(x0)x_{0} \sim p_{data}(x_{0})x0∼pdata(x0)
| 符号 | 数学含义 | 通俗解释 |
|---|---|---|
| xtx_txt | t时刻的加噪数据(图像/潜变量) | 加了t步噪声的图片,t=0是真实图片,t=T是纯噪声 |
| f(t)f(t)f(t) | 漂移系数,由噪声调度决定 | 控制数据均值随时间的变化率,相当于加噪时的“均值偏移速度” |
| g(t)g(t)g(t) | 扩散系数,由噪声调度决定 | 控制噪声的添加幅度,相当于加噪时的“噪声注入速度” |
| dwtdw_tdwt | 标准布朗运动(维纳过程) | 随机高斯噪声的数学表达,是扩散过程随机性的来源 |
| pdata(x0)p_{data}(x_0)pdata(x0) | 真实数据的分布 | 我们训练用的真实图片数据集的分布 |
扩散模型的逆过程,既可以用SDE求解,也可以用确定性的常微分方程(ODE)求解,这个ODE就是概率流ODE(PF-ODE),它和前向SDE有完全一样的边缘分布,也是LCM和一致性模型的核心基础。
PF-ODE的数学表达式为:
dxtdt=f(t)xt−12g2(t)∇xlogqt(xt)\frac{d x_{t}}{ d t}=f(t) x_{t}-\frac{1}{2} g^{2}(t) \nabla_{x} log q_{t}\left(x_{t}\right)dtdxt=f(t)xt−21g2(t)∇xlogqt(xt)
xT∼qT(xT)x_{T} \sim q_{T}(x_{T})xT∼qT(xT)
| 符号 | 数学含义 | 通俗解释 |
|---|---|---|
| dxtdt\frac{d x_t}{dt}dtdxt | x_t对时间t的导数 | 数据点在PF-ODE轨迹上的移动速度和方向 |
| ∇xlogqt(xt)\nabla_x log q_t(x_t)∇xlogqt(xt) | t时刻数据分布的对数概率梯度,也叫得分函数 | 扩散模型要学的核心目标,告诉模型“该往哪个方向去噪” |
| qT(xT)q_T(x_T)qT(xT) | t=T时刻的分布,通常是标准高斯分布 | 生成的起点,也就是输入的纯噪声 |
扩散模型里,我们训练的噪声预测网络ϵθ(xt,t)\epsilon_\theta(x_t, t)ϵθ(xt,t),其实就是在拟合缩放后的得分函数:−∇logqt(xt)=ϵθ(xt,t)σt-\nabla log q_t(x_t) = \frac{\epsilon_\theta(x_t, t)}{\sigma_t}−∇logqt(xt)=σtϵθ(xt,t)。把它代入PF-ODE,就得到了我们实际用来采样的经验PF-ODE:
dxtdt=f(t)xt+g2(t)2σtϵθ(xt,t)\frac{d x_{t}}{ d t}=f(t) x_{t}+\frac{g^{2}(t)}{2 \sigma_{t}} \epsilon_{\theta}\left(x_{t}, t\right)dtdxt=f(t)xt+2σtg2(t)ϵθ(xt,t)
通俗来说,PF-ODE描述了一条从纯噪声(t=T)到真实图片(t=0)的平滑轨迹,扩散模型的采样过程,就是用数值求解器一步步沿着这条轨迹走,从起点走到终点。而一致性模型和LCM的核心,就是学会“从轨迹上任意一点,直接跳到终点”,不用一步步走。
3.2 无分类器引导(CFG)
CFG是文本-图像扩散模型里的核心技术,能大幅提升生成图片和输入文本的对齐度,几乎所有工业级Stable Diffusion模型都离不开它。
CFG的核心逻辑,是把**条件噪声预测(带文本提示)和无条件噪声预测(空文本)**做线性加权,公式如下:
ϵ~θ(zt,ω,c,t):=(1+ω)ϵθ(zt,c,t)−ωϵθ(zt,ø,t)\tilde{\epsilon}_{\theta}\left( z_{t},\omega ,c,t\right) :=(1+\omega )\epsilon _{\theta }\left( z_{t},c,t\right) -\omega \epsilon _{\theta }\left( z_{t},ø ,t\right)ϵ~θ(zt,ω,c,t):=(1+ω)ϵθ(zt,c,t)−ωϵθ(zt,ø,t)
| 符号 | 数学含义 | 通俗解释 |
|---|---|---|
| ϵ~θ\tilde{\epsilon}_\thetaϵ~θ | CFG加权后的最终噪声预测 | 模型最终用来去噪的目标值 |
| ω\omegaω | CFG引导尺度,通常取2~14 | 控制生成的“文本贴合度”,值越大,模型越听文本的话,但可能牺牲多样性 |
| ccc | 条件信息,也就是文本提示词的embedding | 用户输入的生成文案 |
| øøø | 空条件,也就是空文本的embedding | 无任何提示词的条件 |
| ϵθ(zt,c,t)\epsilon_\theta(z_t,c,t)ϵθ(zt,c,t) | 条件噪声预测 | 带文本提示时,模型预测的噪声 |
| ϵθ(zt,ø,t)\epsilon_\theta(z_t,ø,t)ϵθ(zt,ø,t) | 无条件噪声预测 | 不带文本提示时,模型预测的噪声 |
CFG效果虽好,但给少步蒸馏带来了巨大麻烦:之前的蒸馏方法需要先训一个适配CFG的模型,再做少步蒸馏,两阶段流程成本极高,还会累积误差。这也是LCM要解决的核心痛点之一。
3.3 一致性模型(CM)
一致性模型是LCM的理论基础,它的核心思想是自一致性性质:同一条PF-ODE轨迹上的任意一个点,都能映射到同一个终点(也就是t=0的真实图片)。
我们定义一致性函数 fθ(xt,t)↦x0f_\theta(x_t, t) \mapsto x_0fθ(xt,t)↦x0,它的作用是:给轨迹上任意t时刻的点xtx_txt,直接预测出它对应的终点x0x_0x0。自一致性性质就是:同一条轨迹上的所有点,经过一致性函数映射后,得到的结果必须完全一致。
为了保证t=0时,一致性函数输出就是输入本身,一致性函数的参数化形式为:
fθ(x,t)=cskip(t)x+cout(t)Fθ(x,t)f_{\theta}(x, t)=c_{skip }(t) x+c_{out }(t) F_{\theta}(x, t)fθ(x,t)=cskip(t)x+cout(t)Fθ(x,t)
| 符号 | 数学含义 | 通俗解释 |
|---|---|---|
| cskip(t)c_{skip}(t)cskip(t) | 跳跃连接系数,满足cskip(ϵ)=1c_{skip}(\epsilon)=1cskip(ϵ)=1 | 保留输入本身的权重,t趋近于0时,这个系数趋近于1,保证输入等于输出 |
| cout(t)c_{out}(t)cout(t) | 输出系数,满足cout(ϵ)=0c_{out}(\epsilon)=0cout(ϵ)=0 | 控制网络输出的权重,t趋近于0时,这个系数趋近于0,消除网络的影响 |
| Fθ(x,t)F_\theta(x,t)Fθ(x,t) | 我们训练的主干神经网络(UNet) | 核心学习模块,用来预测去噪后的图片 |
一致性模型的训练,靠的是一致性蒸馏:先有一个训好的教师扩散模型,用它的PF-ODE生成轨迹上的相邻点,让一致性函数对这两个点的预测结果保持一致,损失函数为:
L(θ,θ−;Φ)=Ex,t[d(fθ(xtn+1,tn+1),fθ−(x^tnϕ,tn))]\mathcal{L}\left(\theta, \theta^{-} ; \Phi\right)=\mathbb{E}_{x, t}\left[d\left(f_{\theta}\left(x_{t_{n+1}}, t_{n+1}\right), f_{\theta^{-}}\left(\hat{x}_{t_{n}}^{\phi}, t_{n}\right)\right)\right]L(θ,θ−;Φ)=Ex,t[d(fθ(xtn+1,tn+1),fθ−(x^tnϕ,tn))]
| 符号 | 数学含义 | 通俗解释 |
|---|---|---|
| θ\thetaθ | 学生模型的可学习参数 | 我们正在训练的一致性模型参数 |
| θ−\theta^-θ− | 目标模型的参数,由θ\thetaθ的EMA更新得到 | 学生模型的“平滑版本”,用来稳定训练,避免震荡 |
| Φ\PhiΦ | 教师模型的ODE数值求解器 | 用来从xtn+1x_{t_{n+1}}xtn+1估算出同轨迹上的xtnx_{t_n}xtn |
| x^tnϕ\hat{x}_{t_n}^\phix^tnϕ | 用ODE求解器从xtn+1x_{t_{n+1}}xtn+1一步估算出的xtnx_{t_n}xtn | 同一条轨迹上,前一个时间步的估算值 |
| d(⋅,⋅)d(\cdot,\cdot)d(⋅,⋅) | 距离度量函数,通常用L2均方误差 | 衡量两个预测结果的一致性,值越小,一致性越好 |
但原版一致性模型有三个致命缺陷:只支持像素空间、不支持CFG条件引导、不适配Stable Diffusion的潜空间和1000步长调度,这就是LCM要解决的全部问题。
核心方法:LCM的四大创新设计
LCM的核心目标,是把一致性模型完美适配到Stable Diffusion这类潜扩散模型上,实现高分辨率文本-图像的少步生成。下面我们拆解它的四大核心创新,每个公式都做完整的符号解释和通俗解读。
4.1 潜空间一致性蒸馏(LCD)
原版一致性模型只能在像素空间训小图,而Stable Diffusion的核心优势,就是用VAE把高分辨率图片压缩到低维潜空间,大幅降低计算量。LCM的第一步,就是把一致性蒸馏从像素空间搬到潜空间,也就是潜空间一致性蒸馏(LCD)。
首先,Stable Diffusion里,我们先训好一个VAE,编码器E(x)E(x)E(x)把图片xxx压缩成潜向量z=E(x)z=E(x)z=E(x),解码器D(z)D(z)D(z)把潜向量还原成图片x^=D(z)\hat{x}=D(z)x^=D(z)。扩散模型的前向、反向过程,全部在潜空间zzz上进行,对应的PF-ODE也在潜空间定义:
dztdt=f(t)zt+g2(t)2σtϵθ(zt,c,t)\frac{d z_{t}}{ d t}=f(t) z_{t}+\frac{g^{2}(t)}{2 \sigma_{t}} \epsilon_{\theta}\left(z_{t}, c, t\right)dtdzt=f(t)zt+2σtg2(t)ϵθ(zt,c,t)
zT∼N(0,σ‾2I)z_{T} \sim \mathcal{N}\left(0, \overline{\sigma}^{2} I\right)zT∼N(0,σ2I)
| 符号 | 数学含义 | 通俗解释 |
|---|---|---|
| ztz_tzt | t时刻的潜向量 | 潜空间里的加噪数据,对应像素空间的xtx_txt |
| ccc | 文本条件的embedding | 输入提示词的编码,用来做条件生成 |
| ϵθ(zt,c,t)\epsilon_\theta(z_t,c,t)ϵθ(zt,c,t) | 潜空间的条件噪声预测网络,也就是Stable Diffusion的UNet | 教师模型,我们要蒸馏的预训练SD模型 |
接下来,我们定义潜空间一致性函数,它接收潜向量ztz_tzt、文本条件ccc、时间步ttt,直接预测轨迹的终点z0z_0z0:
fθ(z,c,t)=cskip(t)z+cout(t)(z−σtϵ^θ(z,c,t)αt)f_{\theta}(z, c, t)=c_{skip }(t) z+c_{out }(t)\left(\frac{z-\sigma_{t} \hat{\epsilon}_{\theta}(z, c, t)}{\alpha_{t}}\right)fθ(z,c,t)=cskip(t)z+cout(t)(αtz−σtϵ^θ(z,c,t))
这个参数化方式叫ϵ-Prediction,和原版Stable Diffusion的参数化完全对齐,能直接用预训练SD的权重初始化,大幅降低训练成本。
| 符号 | 数学含义 | 通俗解释 |
|---|---|---|
| ϵ^θ(z,c,t)\hat{\epsilon}_\theta(z,c,t)ϵ^θ(z,c,t) | 我们训练的潜空间噪声预测UNet,用预训练SD的权重初始化 | LCM的主干网络,和SD的UNet结构完全兼容 |
| αt\alpha_tαt、σt\sigma_tσt | 扩散模型的噪声调度系数,满足αt2+σt2=1\alpha_t^2 + \sigma_t^2 = 1αt2+σt2=1 | 预训练SD的固定噪声调度,不用额外修改 |
| z−σtϵ^θ(z,c,t)αt\frac{z-\sigma_{t} \hat{\epsilon}_{\theta}(z, c, t)}{\alpha_{t}}αtz−σtϵ^θ(z,c,t) | 由预测噪声还原出的z0z_0z0 | SD里经典的“去噪公式”,从加噪潜向量和预测噪声,还原出干净潜向量 |
有了一致性函数,我们就能定义潜空间一致性蒸馏损失,和原版一致性蒸馏的逻辑一致,但完全适配潜空间和条件生成:
LCD(θ,θ−;Ψ)=Ez,c,n[d(fθ(ztn+1,c,tn+1),fθ−(z^tnΨ,c,tn))]\mathcal{L}_{\mathcal{C D}}\left(\theta, \theta^{-} ; \Psi\right)=\mathbb{E}_{z, c, n}\left[d\left(f_{\theta}\left(z_{t_{n+1}}, c, t_{n+1}\right), f_{\theta^{-}}\left(\hat{z}_{t_{n}}^{\Psi}, c, t_{n}\right)\right)\right]LCD(θ,θ−;Ψ)=Ez,c,n[d(fθ(ztn+1,c,tn+1),fθ−(z^tnΨ,c,tn))]
这里的Ψ\PsiΨ是潜空间PF-ODE的数值求解器,我们可以用DDIM、DPM-Solver、DPM-Solver++,作用是从ztn+1z_{t_{n+1}}ztn+1一步估算出同轨迹上的ztnz_{t_n}ztn,公式为:
z^tnΨ−ztn+1=∫tn+1tn(f(t)zt+g2(t)2σtϵθ(zt,c,t))dt≈Ψ(ztn+1,tn+1,tn,c)\hat {z}_{t_{n}}^{\Psi }-z_{t_{n+1}}=\int _{t_{n+1}}^{t_{n}}\left( f(t)z_{t}+\frac {g^{2}(t)}{2\sigma _{t}}\epsilon _{\theta }\left( z_{t},c,t\right) \right) d{t}\approx \Psi (z_{t_{n+1}},t_{n+1},t_{n},c)z^tnΨ−ztn+1=∫tn+1tn(f(t)zt+2σtg2(t)ϵθ(zt,c,t))dt≈Ψ(ztn+1,tn+1,tn,c)
通俗来说,潜空间一致性蒸馏的逻辑就是:用预训练SD当老师,生成潜空间里的PF-ODE轨迹,让LCM学会“轨迹上任意一点都能直接跳到终点”,同时保证同一条轨迹的预测结果完全一致。
4.2 单阶段引导蒸馏:解决CFG的核心难题
CFG是文本生成的灵魂,但原版一致性蒸馏根本不支持CFG。之前的Guided-Distill方法,需要先训一个“单步CFG模型”,再做一致性蒸馏,两阶段流程不仅成本高,还会累积误差。
LCM提出了单阶段引导蒸馏,通过增强PF-ODE,把CFG直接融入蒸馏过程,一步到位,彻底告别两阶段蒸馏。
首先,我们把CFG加权后的噪声预测,直接代入潜空间PF-ODE,得到增强PF-ODE:
dztdt=f(t)zt+g2(t)2σtϵ~θ(zt,ω,c,t)\frac{d z_{t}}{ d t}=f(t) z_{t}+\frac{g^{2}(t)}{2 \sigma_{t}} \tilde{\epsilon}_{\theta}\left(z_{t}, \omega, c, t\right)dtdzt=f(t)zt+2σtg2(t)ϵ~θ(zt,ω,c,t)
zT∼N(0,σ‾2I)z_{T} \sim \mathcal{N}\left(0, \overline{\sigma}^{2} I\right)zT∼N(0,σ2I)
这里的ϵ~θ\tilde{\epsilon}_\thetaϵ~θ就是CFG加权后的噪声预测,这个增强PF-ODE,就是带CFG引导的生成轨迹,也是我们希望LCM学会的目标轨迹。
接下来,我们定义增强一致性函数,把CFG尺度ω\omegaω也作为模型的输入:
fθ(z,ω,c,t)=cskip(t)z+cout(t)(z−σtϵ^θ(z,ω,c,t)αt)f_{\theta}(z, \omega, c, t)=c_{skip }(t) z+c_{out }(t)\left(\frac{z-\sigma_{t} \hat{\epsilon}_{\theta}(z, \omega, c, t)}{\alpha_{t}}\right)fθ(z,ω,c,t)=cskip(t)z+cout(t)(αtz−σtϵ^θ(z,ω,c,t))
和之前的一致性函数相比,唯一的变化就是噪声预测网络ϵ^θ\hat{\epsilon}_\thetaϵ^θ新增了一个输入ω\omegaω(CFG尺度)。我们对ω\omegaω做傅里叶编码,融入到UNet的时间嵌入里,用零初始化保证训练稳定性,完全不用修改SD的UNet主干结构。
对应的,蒸馏损失也升级为带CFG的引导蒸馏损失:
LCD(θ,θ−;Ψ)=Ez,c,ω,n[d(fθ(ztn+1,ω,c,tn+1),fθ−(z^tnΨ,ω,ω,c,tn)]\mathcal{L}_{\mathcal{C D}}\left(\theta, \theta^{-} ; \Psi\right)=\mathbb{E}_{z, c, \omega, n}\left[d\left(f_{\theta}\left(z_{t_{n+1}}, \omega, c, t_{n+1}\right), f_{\theta^{-}}\left(\hat{z}_{t_{n}}^{\Psi, \omega}, \omega, c, t_{n}\right)\right]\right.LCD(θ,θ−;Ψ)=Ez,c,ω,n[d(fθ(ztn+1,ω,c,tn+1),fθ−(z^tnΨ,ω,ω,c,tn)]
其中,带CFG的轨迹点估算公式为:
z^tnΨ,ω←ztn+1+(1+ω)Ψ(ztn+1,tn+1,tn,c)−ωΨ(ztn+1,tn+1,tn,ø)\hat{z}_{t_{n}}^{\Psi, \omega} \leftarrow z_{t_{n+1}}+(1+\omega) \Psi\left(z_{t_{n+1}}, t_{n+1}, t_{n}, c\right)-\omega \Psi\left(z_{t_{n+1}}, t_{n+1}, t_{n}, ø\right)z^tnΨ,ω←ztn+1+(1+ω)Ψ(ztn+1,tn+1,tn,c)−ωΨ(ztn+1,tn+1,tn,ø)
训练时,我们从[ωmin,ωmax][\omega_{min}, \omega_{max}][ωmin,ωmax](论文里用[2,14])里均匀采样CFG尺度,让模型同时适配不同的引导强度,推理时用户可以自由调整CFG尺度,不用重新训练。
这个单阶段引导蒸馏的优势是碾压级的:
- 训练成本极低:只需要单阶段蒸馏,32个A100小时就能完成768×768模型的训练,而Guided-Distill需要45个A100天;
- 无误差累积:一步到位学习带CFG的目标轨迹,没有两阶段蒸馏的误差累积;
- 完全兼容预训练SD:直接用任意预训练SD模型做老师,不用修改主干结构,社区里的所有SD模型都能蒸馏成LCM。
4.3 SKIPPING-STEP跳步技术:大幅加速蒸馏收敛
Stable Diffusion用的是1000步的离散噪声调度,原版一致性蒸馏只能在相邻的时间步(比如tnt_ntn和tn+1t_{n+1}tn+1)之间做约束,这就带来了一个致命问题:相邻时间步的潜向量ztnz_{t_n}ztn和ztn+1z_{t_{n+1}}ztn+1几乎一模一样,一致性损失的值极小,模型几乎学不到东西,收敛速度慢到离谱。
LCM提出了SKIPPING-STEP跳步技术,完美解决了这个问题。核心逻辑很简单:不再约束相邻时间步,而是约束间隔k步的两个时间步tn+kt_{n+k}tn+k和tnt_ntn。
对应的,蒸馏损失修改为:
LCD(θ,θ−;Ψ)=Ez,c,ω,n[d(fθ(ztn+k,ω,c,tn+k) ,fθ−(z^tnΨ,ω,ω,c,tn))]\mathcal {L}_{\mathcal {C}\mathcal {D}}\left( \theta ,\theta ^{-};\Psi \right) =\mathbb {E}_{z,c,\omega ,n}\left[ d\left(f_{\theta }(z_{t_{n+k}},\omega ,c,t_{n+k})\, ,f_{\theta ^{-}}(\hat {z}_{t_{n}}^{\Psi ,\omega },\omega ,c,t_{n})\right) \right]LCD(θ,θ−;Ψ)=Ez,c,ω,n[d(fθ(ztn+k,ω,c,tn+k),fθ−(z^tnΨ,ω,ω,c,tn))]
轨迹点的估算公式也对应修改为:
z^tnΨ,ω←ztn+k+(1+ω)Ψ(ztn+k,tn+k,tn,c)−ωΨ(ztn+k,tn+k,tn,ø)\hat{z}_{t_{n}}^{\Psi, \omega} \leftarrow z_{t_{n+k}}+(1+\omega) \Psi\left(z_{t_{n+k}}, t_{n+k}, t_{n}, c\right)-\omega \Psi\left(z_{t_{n+k}}, t_{n+k}, t_{n}, ø\right)z^tnΨ,ω←ztn+k+(1+ω)Ψ(ztn+k,tn+k,tn,c)−ωΨ(ztn+k,tn+k,tn,ø)
| 符号 | 数学含义 | 通俗解释 |
|---|---|---|
| kkk | 跳步间隔,论文里默认取20 | 两个时间步之间的间隔,k=1就是原版的相邻步约束 |
跳步技术的效果有多夸张?原版k=1时,模型训10000步都收敛不了,而k=20时,模型2000步就能快速收敛,训练速度直接提升了数十倍。
论文里也做了消融实验,k太小会导致收敛慢,k太大会导致ODE求解器的估算误差变大,k=20是兼顾收敛速度和精度的最优值。同时,DPM-Solver和DPM-Solver++这类高阶求解器,能承受更大的k值(比如k=50),而DDIM作为一阶求解器,k=20是最优选择。
4.4 潜空间一致性微调(LCF):自定义数据集的少步适配
预训练的SD模型,经常需要在自定义数据集上微调,来生成特定风格的图片(比如二次元、Pokemon、辛普森一家)。但传统的微调方法,会破坏模型的少步生成能力,微调后又得重新蒸馏,成本极高。
LCM提出了潜空间一致性微调(LCF),专门解决这个问题:不用教师扩散模型,直接在预训练LCM的基础上,对自定义数据集做微调,同时完美保留少步生成能力。
LCF的核心逻辑,和一致性训练(CT)一致:对同一张干净图片z0z_0z0,用同一个噪声ϵ\epsilonϵ生成两个不同时间步的加噪潜向量ztn+kz_{t_{n+k}}ztn+k和ztnz_{t_n}ztn,让一致性函数对这两个加噪向量的预测结果保持一致。
加噪潜向量的生成公式为:
ztn+k=α(tn+k)z+σ(tn+k)ϵ,ztn=α(tn)z+σ(tn)ϵz_{t_{n+k}}=\alpha (t_{n+k})z+\sigma (t_{n+k})\epsilon , \quad z_{t_{n}}=\alpha (t_{n})z+\sigma (t_{n})\epsilonztn+k=α(tn+k)z+σ(tn+k)ϵ,ztn=α(tn)z+σ(tn)ϵ
| 符号 | 数学含义 | 通俗解释 |
|---|---|---|
| zzz | 自定义数据集图片的干净潜向量 | 微调数据集里的图片,经过VAE编码得到的潜向量 |
| ϵ\epsilonϵ | 随机高斯噪声,和两个时间步共用 | 保证两个加噪潜向量在同一条PF-ODE轨迹上 |
LCF的损失函数为:
L(θ,θ−)=Ez,c,ω,n[d(fθ(ztn+k,tn+k,c,ω),fθ−(ztn,tn,c,ω))]\mathcal{L}\left(\theta, \theta^{-}\right)=\mathbb{E}_{z, c, \omega, n}\left[d\left(f_{\theta}\left(z_{t_{n+k}}, t_{n+k}, c, \omega\right), f_{\theta^{-}}\left(z_{t_{n}}, t_{n}, c, \omega\right)\right)\right]L(θ,θ−)=Ez,c,ω,n[d(fθ(ztn+k,tn+k,c,ω),fθ−(ztn,tn,c,ω))]
整个微调过程,完全不需要预训练的教师扩散模型,只需要自定义数据集,就能让LCM适配新的风格,同时保留2~4步的少步生成能力,大幅降低了自定义模型的开发成本。
实验结果与深度分析
论文做了全面的实验,从定量指标、消融实验、定性效果、自定义微调四个维度,全面验证了LCM的优越性。
5.1 实验设置
- 数据集:512×512分辨率用LAION-Aesthetics-6+(12M图文对),768×768分辨率用LAION-Aesthetics-6.5+(650K图文对);
- 教师模型:512×512用Stable Diffusion V2.1-Base(ϵ-Prediction),768×768用Stable Diffusion V2.1(v-Prediction);
- 训练配置:batch size 72(512×512)/16(768×768),100K迭代,8张A100 GPU,学习率8e-6,EMA系数0.999943;
- 基线方法:DDIM、DPM-Solver、DPM-Solver++(训练-free加速方法),Guided-Distill(两阶段蒸馏方法);
- 评估指标:FID(弗雷歇 inception 距离,衡量生成质量和多样性,越小越好)、CLIP Score(衡量图文对齐度,越大越好)。
5.2 核心定量结果
表格1 512×512分辨率下各方法的性能对比(出处:原论文Table 1)
所有结果均在CFG尺度8.0下测试,核心指标为FID(↓)和CLIP Score(↑)
| 模型 | FID ↓ | CLIP Score ↑ | ||||||
|---|---|---|---|---|---|---|---|---|
| 1步 | 2步 | 4步 | 8步 | 1步 | 2步 | 4步 | 8步 | |
| DDIM | 183.29 | 81.05 | 22.38 | 13.83 | 6.03 | 14.13 | 25.89 | 29.29 |
| DPM | 185.78 | 72.81 | 18.53 | 12.24 | 6.35 | 15.10 | 26.64 | 29.54 |
| DPM++ | 185.78 | 72.81 | 18.43 | 12.20 | 6.35 | 15.10 | 26.64 | 29.55 |
| Guided-Distill | 108.21 | 33.25 | 15.12 | 13.89 | 12.08 | 22.71 | 27.25 | 28.17 |
| LCM (Ours) | 35.36 | 13.31 | 11.10 | 11.84 | 24.14 | 27.83 | 28.69 | 28.84 |
结果分析:
- 少步性能碾压所有基线:LCM在1-4步的核心区间,性能和其他方法拉开了数量级差距。4步推理时,LCM的FID低至11.10,比DPM++低了66%,比Guided-Distill低了26.6%;CLIP Score高达28.69,超过了其他方法8步的水平;
- 1步生成依然可用:其他方法1步生成的FID都在100以上,完全是噪声图,而LCM 1步的FID只有35.36,CLIP Score达到24.14,能生成结构完整、图文对齐的图像;
- 4步就达到收敛:LCM在4步时FID就降到了最低点,8步性能几乎没有变化,而其他方法要8步才能接近LCM 4步的水平,验证了LCM的少步生成能力。
表格2 768×768分辨率下各方法的性能对比(出处:原论文Table 2)
所有结果均在CFG尺度8.0下测试,核心指标为FID(↓)和CLIP Score(↑)
| 模型 | FID ↓ | CLIP Score ↑ | ||||||
|---|---|---|---|---|---|---|---|---|
| 1步 | 2步 | 4步 | 8步 | 1步 | 2步 | 4步 | 8步 | |
| DDIM | 186.83 | 77.26 | 24.28 | 15.66 | 6.93 | 16.32 | 26.48 | 29.49 |
| DPM | 188.92 | 67.14 | 20.11 | 14.08 | 7.40 | 17.11 | 27.25 | 29.80 |
| DPM++ | 188.91 | 67.14 | 20.08 | 14.11 | 7.41 | 17.11 | 27.26 | 29.84 |
| Guided-Distill | 120.28 | 30.70 | 16.70 | 14.12 | 12.88 | 24.88 | 28.45 | 29.16 |
| LCM (Ours) | 34.22 | 16.32 | 13.53 | 14.97 | 25.32 | 27.92 | 28.60 | 28.49 |
结果分析:
768×768高分辨率下,LCM的优势依然保持:4步推理FID低至13.53,远超其他方法;1步生成的FID仅34.22,CLIP Score25.32,依然保持了极高的可用性。这证明LCM完全适配高分辨率生成场景,没有因为分辨率提升而出现性能衰减。
图2 2/4步推理的生成效果对比(出处:原论文Figure 2)
图中对比了LCM、Guided-Distill、DPM-Solver++的2/4步生成结果。可以清晰看到,LCM 2步生成的图像,细节和真实感已经超过了其他方法4步的效果;LCM 4步生成的图像,光影、纹理、结构都达到了原版SD 50步的水平,而其他方法4步依然存在模糊、结构错乱、细节缺失的问题。
5.3 消融实验
5.3.1 ODE求解器与跳步间隔k的影响

图3 不同ODE求解器和跳步间隔k的收敛曲线(出处:原论文Figure 3)
左图是DDIM求解器,中图是DPM-Solver,右图是DPM-Solver++,横轴是训练迭代次数,纵轴是4步推理的FID值。
结果分析:
- 跳步技术是收敛的核心:k=1时,所有求解器的收敛速度都极慢,12000次迭代FID还在25以上;而k=20时,2000次迭代FID就降到了15以下,收敛速度提升了数倍;
- 高阶求解器适配更大的k:DDIM作为一阶求解器,k=50时会出现精度下降;而DPM-Solver和DPM-Solver++作为二阶求解器,k=50时依然能保持优秀的收敛速度和精度;
- 三者最终性能接近:当k=20时,三种求解器的最终FID几乎一致,都能达到11左右,论文最终选择DDIM作为默认求解器,兼顾速度和精度。
5.3.2 CFG尺度的影响

图4 不同CFG尺度下的FID和CLIP Score变化(出处:原论文Figure 4)
左图是FID随CFG尺度的变化,右图是CLIP Score随CFG尺度的变化。
结果分析:
- CFG尺度平衡质量和对齐度:随着CFG尺度增大,CLIP Score持续上升,说明图文对齐度越来越好;但FID先降后升,说明超过一定阈值后,多样性会下降;
- 少步性能差距极小:2/4/8步的曲线几乎重合,说明LCM在不同步数下,对CFG尺度的适配性都极好,不会因为步数减少而出现性能暴跌;
- 1步性能有差距:1步的FID和CLIP Score和多步有一定差距,还有优化空间,这也是后续工作的改进方向。

图5 不同CFG尺度的生成效果对比(出处:原论文Figure 5)
从左到右CFG尺度依次增大,可以看到,随着CFG尺度提升,生成的图像和文本的贴合度越来越高,细节越来越精准,验证了LCM单阶段引导蒸馏的有效性。
5.4 自定义数据集微调结果

图6 LCF自定义微调的生成效果(出处:原论文Figure 6)
左图是Pokemon数据集,右图是辛普森一家数据集,展示了原始LCM、微调1K/10K/30K步的4步生成效果。
结果分析:
经过LCF微调后,LCM快速适配了自定义数据集的风格,30K步微调后,就能用4步推理生成风格统一、细节完整的Pokemon和辛普森一家风格图像,同时完全保留了少步生成能力,不用重新蒸馏,验证了LCF方法的有效性。
核心代码实现
下面是基于PyTorch和diffusers库的LCM核心实现,包含潜空间一致性蒸馏损失、带CFG编码的UNet、LCM采样算法,完全对齐论文里的公式和算法逻辑。
环境依赖
pip install torch diffusers transformers accelerate
完整核心代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import UNet2DConditionModel, AutoencoderKL, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator
from tqdm import tqdm
# ===================== 1. CFG尺度的傅里叶编码模块 =====================
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.act = nn.SiLU()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def forward(self, sample: torch.Tensor):
sample = self.linear_1(sample)
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class FourierEmbedding(nn.Module):
def __init__(self, embedding_dim: int, scale: float = 1.0):
super().__init__()
self.embedding_dim = embedding_dim
self.scale = scale
def forward(self, x: torch.Tensor):
x = x * self.scale
half_dim = self.embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
emb = x.unsqueeze(-1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# ===================== 2. LCM UNet:扩展原版SD UNet,支持CFG尺度编码 =====================
class LCMUNet(nn.Module):
def __init__(
self,
base_unet: UNet2DConditionModel,
time_embed_dim: int = 1280,
cfg_embed_dim: int = 256,
):
super().__init__()
# 原版SD的UNet主干
self.unet = base_unet
# CFG尺度的编码模块
self.cfg_encoder = nn.Sequential(
FourierEmbedding(cfg_embed_dim),
TimestepEmbedding(cfg_embed_dim, time_embed_dim)
)
# 零初始化的投影层,保证训练稳定性
self.cfg_proj = nn.Linear(time_embed_dim, time_embed_dim)
nn.init.zeros_(self.cfg_proj.weight)
nn.init.zeros_(self.cfg_proj.bias)
def forward(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
cfg_scale: torch.Tensor,
return_dict: bool = False,
):
# 1. 计算原版UNet的时间嵌入
t_emb = self.unet.time_proj(timestep)
t_emb = self.unet.time_embedding(t_emb)
# 2. 计算CFG尺度的嵌入,加到时间嵌入里
cfg_emb = self.cfg_encoder(cfg_scale)
cfg_emb = self.cfg_proj(cfg_emb)
t_emb = t_emb + cfg_emb
# 3. 经过UNet的下采样、中间层、上采样
# 预处理
if self.unet.config.in_channels != sample.shape[1]:
sample = self.unet.conv_in(sample)
# 下采样
down_block_res_samples = (sample,)
for downsample_block in self.unet.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=t_emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=t_emb)
down_block_res_samples += res_samples
# 中间层
if self.unet.mid_block is not None:
sample = self.unet.mid_block(
sample, temb=t_emb, encoder_hidden_states=encoder_hidden_states
)
# 上采样
for upsample_block in self.unet.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=t_emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample = upsample_block(
hidden_states=sample, temb=t_emb, res_hidden_states_tuple=res_samples
)
# 输出层
sample = self.unet.conv_norm_out(sample)
sample = self.unet.conv_act(sample)
sample = self.unet.conv_out(sample)
if return_dict:
return {"sample": sample}
return sample
# ===================== 3. 潜空间一致性蒸馏(LCD)损失函数 =====================
def get_skip_out_coeffs(timesteps: torch.Tensor, scheduler: DDIMScheduler, sigma_data: float = 0.5):
"""
计算一致性函数的c_skip和c_out系数
"""
sqrt_alpha_prod = scheduler.alphas_cumprod[timesteps] ** 0.5
sqrt_one_minus_alpha_prod = (1 - scheduler.alphas_cumprod[timesteps]) ** 0.5
c_skip = sigma_data ** 2 / ((sqrt_one_minus_alpha_prod ** 2) + (sqrt_alpha_prod ** 2) * sigma_data ** 2)
c_out = sqrt_one_minus_alpha_prod * sigma_data / ((sqrt_one_minus_alpha_prod ** 2) + (sqrt_alpha_prod ** 2) * sigma_data ** 2) ** 0.5
return c_skip.view(-1, 1, 1, 1), c_out.view(-1, 1, 1, 1)
def lcm_consistency_loss(
student_model: LCMUNet,
target_model: LCMUNet,
teacher_unet: UNet2DConditionModel,
scheduler: DDIMScheduler,
latents: torch.Tensor,
text_embeds: torch.Tensor,
uncond_embeds: torch.Tensor,
num_train_timesteps: int = 1000,
skip_step: int = 20,
cfg_min: float = 2.0,
cfg_max: float = 14.0,
):
"""
计算LCM的一致性蒸馏损失,对齐论文公式
:param student_model: 学生LCM模型,正在训练
:param target_model: 目标EMA模型,稳定训练
:param teacher_unet: 预训练SD的教师UNet
:param scheduler: DDIM调度器
:param latents: 干净的潜向量z0,shape [batch_size, 4, 64, 64]
:param text_embeds: 文本条件embedding,shape [batch_size, 77, 768]
:param uncond_embeds: 无条件embedding,shape [batch_size, 77, 768]
:param skip_step: 跳步间隔k,默认20
:param cfg_min: CFG尺度最小值
:param cfg_max: CFG尺度最大值
:return: 标量损失值
"""
batch_size = latents.shape[0]
device = latents.device
# 1. 随机采样时间步:t_n+k和t_n
max_step = num_train_timesteps - skip_step
n = torch.randint(0, max_step, (batch_size,), device=device)
t_n = n
t_nk = n + skip_step
# 2. 随机采样CFG尺度
cfg_scale = torch.rand((batch_size,), device=device) * (cfg_max - cfg_min) + cfg_min
# 3. 采样随机噪声,生成加噪潜向量
noise = torch.randn_like(latents)
z_nk = scheduler.add_noise(latents, noise, t_nk)
# 4. 用教师UNet计算DDIM一步求解,得到z_n的估计值
with torch.no_grad():
# 条件和无条件噪声预测
noise_cond = teacher_unet(z_nk, t_nk, encoder_hidden_states=text_embeds).sample
noise_uncond = teacher_unet(z_nk, t_nk, encoder_hidden_states=uncond_embeds).sample
# CFG加权
noise_pred = (1 + cfg_scale.view(-1, 1, 1, 1)) * noise_cond - cfg_scale.view(-1, 1, 1, 1) * noise_uncond
# DDIM一步求解,从t_nk到t_n
alpha_t_n = scheduler.alphas_cumprod[t_n].view(-1, 1, 1, 1)
alpha_t_nk = scheduler.alphas_cumprod[t_nk].view(-1, 1, 1, 1)
sigma_t_n = (1 - alpha_t_n) ** 0.5
sigma_t_nk = (1 - alpha_t_nk) ** 0.5
pred_x0 = (z_nk - sigma_t_nk * noise_pred) / alpha_t_nk ** 0.5
z_n_hat = alpha_t_n ** 0.5 * pred_x0 + sigma_t_n * noise_pred
# 5. 计算学生模型对z_nk的预测
c_skip_nk, c_out_nk = get_skip_out_coeffs(t_nk, scheduler)
noise_pred_student_nk = student_model(z_nk, t_nk, text_embeds, cfg_scale)
pred_x0_student_nk = (z_nk - scheduler.sqrt_one_minus_alphas_cumprod[t_nk].view(-1, 1, 1, 1) * noise_pred_student_nk) / scheduler.alphas_cumprod[t_nk].sqrt().view(-1, 1, 1, 1)
f_nk = c_skip_nk * z_nk + c_out_nk * pred_x0_student_nk
# 6. 计算目标EMA模型对z_n_hat的预测
with torch.no_grad():
c_skip_n, c_out_n = get_skip_out_coeffs(t_n, scheduler)
noise_pred_target_n = target_model(z_n_hat, t_n, text_embeds, cfg_scale)
pred_x0_target_n = (z_n_hat - scheduler.sqrt_one_minus_alphas_cumprod[t_n].view(-1, 1, 1, 1) * noise_pred_target_n) / scheduler.alphas_cumprod[t_n].sqrt().view(-1, 1, 1, 1)
f_n = c_skip_n * z_n_hat + c_out_n * pred_x0_target_n
# 7. 计算MSE损失
loss = F.mse_loss(f_nk, f_n)
return loss
# ===================== 4. LCM采样算法 =====================
@torch.no_grad()
def lcm_sample(
model: LCMUNet,
scheduler: DDIMScheduler,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
vae: AutoencoderKL,
prompt: str,
num_inference_steps: int = 4,
cfg_scale: float = 8.0,
height: int = 512,
width: int = 512,
device: torch.device = torch.device("cuda"),
):
"""
LCM少步采样,对齐论文的多步采样算法
"""
# 1. 文本编码
text_input = tokenizer(
prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
)
text_embeds = text_encoder(text_input.input_ids.to(device))[0]
uncond_input = tokenizer([""], padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")
uncond_embeds = text_encoder(uncond_input.input_ids.to(device))[0]
# 2. 初始化潜向量
latents = torch.randn((1, 4, height // 8, width // 8), device=device)
latents = latents * scheduler.init_noise_sigma
# 3. 设置时间步
scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = scheduler.timesteps.flip(0) # 从大到小
# 4. 迭代采样
for i, t in enumerate(timesteps):
# 预测噪声
noise_pred = model(latents, t, text_embeds, torch.tensor([cfg_scale], device=device))
# 预测x0
pred_x0 = (latents - scheduler.sqrt_one_minus_alphas_cumprod[t] * noise_pred) / scheduler.alphas_cumprod[t].sqrt()
# 如果不是最后一步,加噪到下一个时间步
if i < len(timesteps) - 1:
next_t = timesteps[i+1]
noise = torch.randn_like(latents)
latents = scheduler.alphas_cumprod[next_t].sqrt() * pred_x0 + scheduler.sqrt_one_minus_alphas_cumprod[next_t] * noise
else:
latents = pred_x0
# 5. VAE解码
latents = 1 / 0.18215 * latents
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
return image
# ===================== 5. 训练主函数示例 =====================
def train_lcm():
# 加速器初始化
accelerator = Accelerator(mixed_precision="fp16")
device = accelerator.device
# 加载预训练SD模型
model_id = "runwayml/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
teacher_unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device)
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
# 冻结VAE、文本编码器、教师UNet
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
teacher_unet.requires_grad_(False)
# 初始化学生和目标LCM模型
student_unet = LCMUNet(UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")).to(device)
target_unet = LCMUNet(UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")).to(device)
target_unet.load_state_dict(student_unet.state_dict())
target_unet.requires_grad_(False)
# 优化器
optimizer = torch.optim.AdamW(student_unet.parameters(), lr=8e-6, weight_decay=1e-5)
ema_decay = 0.999943
# 这里需要替换成你自己的数据集加载逻辑
# dataloader = ...
# 加速器准备
student_unet, optimizer, dataloader = accelerator.prepare(student_unet, optimizer, dataloader)
# 训练循环
num_epochs = 10
global_step = 0
for epoch in range(num_epochs):
progress_bar = tqdm(dataloader, disable=not accelerator.is_local_main_process)
for batch in progress_bar:
with accelerator.accumulate(student_unet):
# 图片编码为潜向量
pixel_values = batch["pixel_values"].to(device)
latents = vae.encode(pixel_values).latent_dist.sample() * 0.18215
# 文本编码
text_input = tokenizer(
batch["text"], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
)
text_embeds = text_encoder(text_input.input_ids.to(device))[0]
uncond_input = tokenizer(
[""] * len(batch["text"]), padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)
uncond_embeds = text_encoder(uncond_input.input_ids.to(device))[0]
# 计算LCM损失
loss = lcm_consistency_loss(
student_unet, target_unet, teacher_unet, scheduler,
latents, text_embeds, uncond_embeds
)
# 反向传播
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# 更新EMA目标模型
for student_param, target_param in zip(student_unet.parameters(), target_unet.parameters()):
target_param.data.mul_(ema_decay).add_(student_param.data, alpha=1 - ema_decay)
global_step += 1
progress_bar.set_postfix({"loss": loss.item(), "step": global_step})
# 每个epoch保存模型
accelerator.save(student_unet.state_dict(), f"lcm_epoch_{epoch}.pth")
if __name__ == "__main__":
train_lcm()
总结与展望
LCM的出现,彻底解决了扩散模型少步生成的行业难题,让AI图像生成从“几秒出图”迈入了“实时出图”的时代。我们来回顾它的核心贡献:
- 首次将一致性模型适配到潜空间,完美兼容Stable Diffusion的生态,让高分辨率文本-图像的少步生成成为可能;
- 提出单阶段引导蒸馏,把CFG无缝融入一致性蒸馏,告别了两阶段蒸馏的高成本和误差累积,训练成本从“数十天”降到了“32小时”;
- 发明SKIPPING-STEP跳步技术,让蒸馏收敛速度提升数十倍,解决了SD 1000步长调度下的训练难题;
- 设计LCF微调方法,让自定义数据集的微调也能保留少步生成能力,大幅降低了定制化模型的开发门槛。
LCM的落地价值已经被行业充分验证:现在几乎所有的AI绘画平台都集成了LCM,ControlNet、IP-Adapter等插件也都完美适配LCM,真正实现了“输入文案,实时出图”。
未来,LCM还有广阔的探索空间:
- 扩展到更多生成任务,比如图像编辑、超分辨率、视频生成、3D生成;
- 优化1步生成的效果,进一步提升生成速度;
- 和最新的扩散模型架构结合,比如SDXL、Flux,实现更高质量的少步生成;
- 探索无教师的从头训练方法,不用依赖预训练SD模型,直接训LCM。
LCM不仅是一个优秀的少步生成算法,更是为生成建模领域提供了一个全新的思路:不用执着于一步步迭代求解ODE,直接学会从轨迹上任意一点跳到终点,才是实现极速生成的终极方案。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)