论文链接:Elucidating the Design Space of Diffusion-Based Generative Models(22‘6)

扩散模型有很坚实的理论基础,但是模型、采样策略、训练策略、噪声参数化方法等等之间强耦合,且五花八门。作者的贡献在于

  1. 提出了一套对去噪分数匹配模型的统一框架EDM;

  2. 找到采样时最优性能的时间步离散方法,应用高阶Runge-Kutta方法,评估不同采样器;

  3. 为提升训练效果,对模型的输入、输出、损失函数进行预处理,调整训练期间的噪声水平分布,应用non-leaking数据增强。

首先在常见框架下表述一下diffusion model:

假设数据分布为p_{data}(x),其标准差为\sigma_{data}。通过往数据x中添加独立同分布的标准差为\sigma的高斯噪声,获得分布族p(x; \sigma),对于巨大的\sigma_{max} \gg\sigma_{data}p(x, \sigma_{max})几乎与纯高斯噪声无异。初始随机采样x_0 \sim \mathcal{N}(0, \sigma_{max}^2I),顺序去噪成x_i,这个过程中样本的噪声水平满足\sigma_0=\sigma_{max}>\sigma_1>\sigma_2>\cdots>\sigma_N=0,终点x_N便落在了数据分布中。

ODE公式

基于漂移系数f(t)和扩散系数g(t) 的原始叙事

Song et al.将扩散模型的前向随机微分方程定义为dx=f(x,t)dt+g(t)dw_tf(\cdot,t)g(\cdot)分别表示漂移和扩散系数,在方差保持VP和方差爆炸VE中设计存在差异,一般是f(x,t)=f(t)x, f(\cdot): \mathbb{R} \rightarrow \mathbb{R},前向SDE可写成dx=f(t)xdt+g(t)dw_t

布朗运动在事件反向时会改变统计结构,反向SDE并不是将时间反向那么简单,为dx=[f(t)x-g(t)^2 \bigtriangledown_x \log p(x; \sigma)]dt+g(t)d\bar{W},其中分数函数\bigtriangledown_x \log p(x; \sigma)是一个指示当前噪声水平下数据概率密度更高的方向的矢量(概率密度对数的梯度)。

前向扰动核的一般形式为p_{ot}\left(x_t|x_0\right)=\mathcal{N}\left(x_t; s(t)x_0,s(t)^2 \sigma(t)^2I\right),其中信号缩放系数s(t)=\exp \left( \int_0^t f(\xi)d\xi\right),噪声强度\sigma(t)=\sqrt{\int_0^t \frac{g(\xi)^2}{s(\xi)^2}d\xi},对应边缘分布p_t(x)=\int_{\mathbb{R}^d}p_{0t}(x|x_0)p_{data}(x_0)dx_0,将能求解出相同p_t(x)的常微分方程称作概率流ODE:dx=[f(t)x-\frac{1}{2}g(t)^2 \bigtriangledown_x \log p_t(x)]dt,在前向和反向过程中就只有时间反向的差异了。在每轮迭代中,可以通过一个随机求解器去噪加噪,也可以用一个ODE求解器去噪,整个流程中唯一的随机来源于采样x_0

基于信号缩放s(t)和噪声水平\sigma(t)的EMD叙事

这样的表述过程有个问题,PF ODE公式建立在本身没有多少实际意义的fg上,而对于训练模型、引导采样、理解ODE在实践中的具体意义至关重要的边缘分布却只能由这些公式间接推导。EDM换了一种表述,既然PF ODE是为了匹配特定集合的边缘分布,为什么不直接将边缘分布作为目标,基于\sigma(t)和s(t)定义ODE?拆解边缘分布\begin{matrix} p_t(x) &= \int_{\mathbb{R}^d}p_{0t}(x|x_0)p_{data}(x_0)dx_0\\ &= \int_{\mathbb{R}^d} p_{data}(x_0)\left[ \mathcal{N}\left(x; s(t)x_0,s(t)^2 \sigma(t)^2I\right)\right]dx_0 \\ &= \int_{\mathbb{R}^d} p_{data}(x_0)\left[ s(t)^{-d}\mathcal{N}\left(x/s(t); x_0,\sigma(t)^2I\right)\right]dx_0 \\ &=s(t)^{-d} \int_{\mathbb{R}^d} p_{data}(x_0) \mathcal{N}\left(x/s(t); x_0,\sigma(t)^2I\right)dx_0 \\ &= s(t)^{-d} \left[ p_{data} * \mathcal{N}\left(0, \sigma(t)^2I\right)\right](x/s(t)) \end{matrix}

p_a * p_b表示两个概率密度函数之间的卷积操作,定义分布p(x;\sigma)=p_{data} * \mathcal{N}\left(0, \sigma(t)^2I\right),有边缘分布p_t(x) = s(t)^{-1}p\left(x/s(t); \sigma(t)\right),同样用p(x;\sigma)表示PF ODE\begin{matrix} dx &=[f(t)x-\frac{1}{2}g(t)^2 \bigtriangledown_x \log p_t(x)]dt \\ &=[f(t)x-\frac{1}{2}g(t)^2 \bigtriangledown_x \log [s(t)^{-1}p(x/s(t); \sigma(t))]]dt \\ &=[f(t)x-\frac{1}{2}g(t)^2 [\bigtriangledown_x \log s(t)^{-1} + \bigtriangledown_x \log p(x/s(t); \sigma(t))]]dt \\ &= [f(t)x-\frac{1}{2}g(t)^2\bigtriangledown_x \log p(x/s(t); \sigma(t))]dt \end{matrix}

\sigma(t)和s(t)对f(t)和g(t)进行重写:f(t)=\dot{s}(t)/s(t)g(t)=s(t)\sqrt{2\dot{\sigma}(t)\sigma(t)},点表示导数。回代可得dx = \left[ \frac{\dot{s}(t)}{s(t)}x - s(t)^2 \dot{\sigma}(t) \sigma(t) \bigtriangledown_{x} \log p\left( \frac{x}{s(t)}; \sigma(t) \right) \right] dt。PF ODE的不同实现都是对同一个正则ODE的重新参数化,噪声强度\sigma(t)和信号缩放系数s(t)分别对t和x进行转换。

分数匹配

下图a是在不同噪声水平的分布p(x;\sigma)中采样,即往干净数据中添加不同强度高斯噪声后得到的图片。高噪声可能导致过饱和,所以

假设训练集包含有限个样本\{y_1, \dots,y_Y\},即p_{data}(x)=\frac{1}{Y}\sum_{i=1}^Y \delta(x-y_i),于是\begin{matrix} p(x;\sigma)&=p_{data} * \mathcal{N}\left(0, \sigma(t)^2I\right) \\ &= \int_{\mathbb{R}^d} p_{data}(x_0) \mathcal{N}\left(x; x_0,\sigma^2I\right)dx_0 \\ &= \int_{\mathbb{R}^d} \left[ \frac{1}{Y}\sum_{i=1}^Y \delta(x_0-y_i)\right] \mathcal{N}\left(x; x_0,\sigma^2I\right)dx_0 \\ &= \frac{1}{Y} \sum_{i=1}^Y \int_{\mathbb{R}^d} \mathcal{N}\left(x; x_0,\sigma^2I\right) \delta(x_0-y_i) dx_0 \\ &= \frac{1}{Y} \sum_{i=1}^Y \mathcal{N}\left(x; y_i,\sigma^2I\right) \end{matrix}

理想的去噪器D(x;\sigma)对原始样本y的最佳预测,能最小化在任何\sigma下的L_2去噪期望误差\mathbb{E}_{y \sim p_{data}} \mathbb{E}_{n \sim \mathcal{N}(0, \sigma^2I)} \| D(y+n; \sigma) - y\|_2^2,把期望展开,改写成含噪声样本x的积分\begin{matrix} \mathcal{L}(D; \sigma)&=\mathbb{E}_{y \sim p_{data}} \mathbb{E}_{n \sim \mathcal{N}(0, \sigma^2I)} \| D(y+n; \sigma) - y\|_2^2 \\ &=\mathbb{E}_{y \sim p_{data}} \mathbb{E}_{x \sim \mathcal{N}(y, \sigma^2I)} \| D(x; \sigma) - y\|_2^2 \\ &=\mathbb{E}_{y \sim p_{data}} \int_{\mathbb{R}^d} \mathcal{N}(x;y, \sigma^2I)\| D(x; \sigma) - y\|_2^2 dx \\ &= \frac{1}{Y} \sum_{i=1}^Y \int_{\mathbb{R}^d} \mathcal{N}(x;y_i, \sigma^2I)\| D(x; \sigma) - y_i\|_2^2 dx \\ &= \int_{\mathbb{R}^d} \frac{1}{Y} \sum_{i=1}^Y \mathcal{N}(x;y_i, \sigma^2I) \| D(x; \sigma) - y_i\|_2^2 dx \end{matrix}

说明可以通过独立最小化每个x的\mathcal{L}(D;x,\sigma)=\frac{1}{Y} \sum_{i=1}^Y \mathcal{N}(x;y_i, \sigma^2I) \| D(x; \sigma) - y_i\|_2^2来最小化\mathcal{L}(D; \sigma),求解D(x;\sigma)=\arg \min_{D(x;\sigma)}\mathcal{L}(D;x,\sigma)是一个凸优化问题,根据导数为0得到D(x; \sigma)=\frac{\sum_i \mathcal{N}(x;y_i, \sigma^2I)y_i}{\sum_i \mathcal{N}(x; y_i, \sigma^2I)},对于小数据集,例如CIFAR-10,这个值是可计算的。下列图a对图片添加不同水平的高斯噪声,得到x=x_0+\sigma \epsilon(高噪声下算出的像素值范围广,多数会被裁剪到0或255,导致过饱和的颜色,为了可视化需归一化每张图像的像素范围),图b是计算出的最优去噪器输出,可以看到噪声越来越大时,最优去噪器从“恢复原图“会逐渐退化为“输出数据集平均图像“

考虑分数函数\begin{matrix} \bigtriangledown_x \log p(x; \sigma) =\frac{\bigtriangledown_x p(x; \sigma)}{p(x; \sigma)} = \frac{\bigtriangledown_x \frac{1}{Y} \sum_i \mathcal{N}\left(x; y_i,\sigma^2I\right)}{\frac{1}{Y} \sum_i\mathcal{N}\left(x; y_i,\sigma^2I\right)} =\frac{\sum_i\bigtriangledown_x \mathcal{N}\left(x; y_i,\sigma^2I\right)}{ \sum_i \mathcal{N}\left(x; y_i,\sigma^2I\right)} \end{matrix},其中高斯的导数\bigtriangledown_x \mathcal{N}\left(x; y_i,\sigma^2I\right)= \mathcal{N}\left(x; y_i,\sigma^2I\right) \left[ \frac{y_i - x}{\sigma^2}\right],于是\bigtriangledown_x \log p(x; \sigma) =\frac{\sum_i\mathcal{N}\left(x; y_i,\sigma^2I\right) \left[ \frac{y_i - x}{\sigma^2}\right]}{ \sum_i \mathcal{N}\left(x; y_i,\sigma^2I\right)} = \left( \frac{\sum_i\mathcal{N}\left(x; y_i,\sigma^2I\right)y_i}{\sum_i \mathcal{N}\left(x; y_i,\sigma^2I\right)} - x\right)/\sigma^2 = (D(x; \sigma) - x)/\sigma^2

独立于\sigma,是一个不错的预测目标。实践中去噪器是一个根据“优化目标“训练出的神经网络D_{\theta}(\hat{x}; \sigma),其中\hat{x}=x/s(t),代入PF ODE

\begin{matrix} dx &= \left[ \frac{\dot{s}(t)}{s(t)}x - s(t) \dot{\sigma}(t) \sigma(t) \bigtriangledown_{\frac{x}{s(t)}} \log p\left( \frac{x}{s(t)}; \sigma(t) \right) \right] dt \\ &=\left[ \frac{\dot{s}(t)}{s(t)}x - \frac{\dot{\sigma}(t)s(t)}{\sigma(t)} \left(D_{\theta}(\hat{x}; \sigma(t))-\hat{x} \right)\right] dt \\ &= \left[\left( \frac{\dot{\sigma}(t)}{\sigma(t)}+\frac{\dot{s}(t)}{s(t)}\right) \cdot x - \frac{\dot{\sigma}(t)s(t)}{\sigma(t)}D_{\theta}\left(\frac{x}{s(t)}; \sigma(t)\right) \right] dt \end{matrix}

ODE求解器

为了更多地优化细节,推理时对噪声水平采用幂律调度,取\sigma_{i<N}=(Ai+B)^{\rho},代入边界条件\sigma_0=\sigma_{max}\sigma_{N-1}=\sigma_{min},得到\sigma_{i<N}=(\sigma_{max}^{\frac{1}{\rho}}+\frac{i}{N-1}(\sigma_{min}^{\frac{1}{\rho}}-\sigma_{max}^{\frac{1}{\rho}}))^{\rho},以及\sigma_N = 0,其中N表示ODE求解器的迭代步数,\rho用来调节步长分布,为1时就是普通的线性插值,步长均匀,\rho越大,越多的\sigma_i靠近\sigma_{min},实验选择\sigma_{min}=0.002\sigma_{max}=80\rho = 7。对应地参数化时间步t_i = \sigma^{-1}(\sigma_i)

之前的方案一般采用一阶Euler数值求解ODE,而EDM改为二阶Heun方法,能在局部误差和NFE(模型调用次数)之间取得更好的权衡,针对初值问题\frac{dy}{dx} = f(x, y)y(x_0) = y_0,每轮迭代:1. 计算当前点的斜率k_1 = f(x_n, y_n);2. 用欧拉法预测下一点的临时值y_p = y_n + h \cdot k_1;3. 用临时值计算下一点的斜率k_2 = f(x_n+h, y_p);4. 两斜率取平均,算最终值y_{n+1} = y_n+ h \cdot (k_1+k_2)/2。假设步长h,一阶欧拉求解器引入的局部误差为\mathcal{O}(h^2),Heun二阶方法多引入一次模型调用,将误差变成\mathcal{O}(h^3)

EMD选择s(t)=1,\sigma(t)=t

ODE求解轨迹由\sigma(t)和s(t)定义,为了减少截断误差(与曲率成比例),EDM选择\sigma(t)=t、s(t)=1,此时\sigma和t等价,ODE方程可简化成dx/dt=(x-D(x;t))/t。这样的ODE更容易数值求解?下图中构造了一个一维的toy数据集p_{data}(x) = \frac{1}{2}\delta(x-1)+\frac{1}{2}\delta(x+1),数据只有两个点x=\pm1,给数据加高斯噪声x=x_0+\sigma \epsilon,随着\sigma增大,两个峰会逐渐变成一个高斯分布,表示为在同一t下橙色的深浅显示,橙色线表示生成过程中的流场,黑色箭头表示局部梯度。VP在大噪声区域轨迹几乎水平,在小噪声区域轨迹突然完全并指向x=\pm1,因为分数\bigtriangledown_x \log p(x;\sigma)在小\sigma才明显;VE的轨迹一直都很弯;而DDIM或者EDM在大\sigma区域近似直线且指向数据均值(0),中\sigma区域稍微弯曲,变小后又直了,指向数据(±1)。

引入随机性

相较于确定性采样,引入随机性会更好吗?结合Huen二阶确定性求解器、显式的Langevin式“扰动“,每步去噪含两步:1. 按因子\gamma_i \ge 0往样本中加噪,噪声强度升至\hat{t}_i = t_i + \gamma_i t_i;2. 从x_i \sim \mathcal{N}(0, t_i^2I)变到\hat{x}_i \sim \mathcal{N}(0, \hat{t}_i^2I),只需添加一个\sqrt{\hat{t}_i^2-t_i^2}\epsilon \sim \mathcal{N}\left( 0, (\hat{t}_i^2-t_i^2)I \right),然后就是按Huen ODE求解器求解

加入随机性可以缓解前面引入的误差,但也可能导致细节丢失或者颜色过饱和,可能因为实践中去噪器存在误差,需要启发式方案进行修复,比如只在噪声水平区间t_i \in [S_{tmin}, S_{tmax}]加入随机项,定义\gamma_i=S_{churn}/N,由S_{churn}控制整体随机性,同时通过裁剪保证不要高于原本的噪声水平。针对细节丢失问题,主要因为模型倾向于去掉稍稍偏多的噪声,输出靠近数据均值,这也是L2损失常导致的“均值回归“的现象,可通过设置略高于1的S_{noise}缓解。消融实验:

启发式设置\{S_{churn}, S_{tmin}, S_{tmax}, S_{noise}\}的值最好根据模型定制,启用网格搜索逐个查找最优配置,能看出来,启用随机采样改进模型调用次数时可能会影响到模型架构和训练策略的选择。所以作者声明“We stress that this is not a general-purpose SDE solver, but a sampling procedure tailored for the specific problem.“。

预条件化

通过监督训练网络时,最好把输入输出固定在相同方差,各式模型的性能差异可以归结于对输入输出不同的scaling,而原始样本x=y+n, n \sim \mathcal{N}(0, \sigma^2I)的方差会随着\sigma剧烈变化,之前的工作对输入除\sigma,训练模型预测方差一致的噪声n,重建信号D_{\theta}(x; \sigma)=x-\sigma F_{\theta}(\cdot),问题是模型的输出误差会被放大\sigma倍,高噪声时直接预测信号D(x;\sigma)似乎更容易。EDM将去噪器统一成D_{\theta}(x; \sigma) = c_{skip}(\sigma)x+c_{out}(\sigma)F_{\theta}(c_{in}(\sigma)x; c_{noise}(\sigma)),其中c_{skip}(\sigma)x直接保留输入,c_{out}(\sigma)F_{\theta}表示网络预测残差,涵盖了模型预测n、y,或者介于两者之间的东西的所有情况。根据噪声水平调整总体损失\mathbb{E}_{\sigma, y,n} [\lambda(\sigma)\| D(y+n; \sigma) - y\|_2^2],其中\sigma \sim p_{train}y \sim p_{data}n \sim \mathcal{N}(0, \sigma^2I),替换回原始模型输出的表示:

这样的格式便于探索对网络的有效训练:希望模型输入的方差保持,设置Var_{y,n}[c_{in}(\sigma)(y+n)]=c_{in}^2(\sigma_{data}^2+\sigma^2)=1,得到c_{in} = 1/\sqrt{\sigma^2+\sigma_{data}^2};同样,对于模型的目标输出Var_{y,n}\left[\frac{1}{c_{out}(\sigma)}(y-c_{skip}(\sigma)(y+n)\right]=\frac{1}{c_{out}(\sigma)^2}\left[ (1-c_{skip}(\sigma))^2 \sigma_{data}^2-c_{skip}(\sigma)^2 \sigma^2\right]=1,于是c_{out}(\sigma)^2= (1-c_{skip}(\sigma))^2 \sigma_{data}^2-c_{skip}(\sigma)^2 \sigma^2,为了减少对模型输出误差的放大作用,需要尽可能选择较小的c_{out}(\sigma)(恒非负),成优化问题c_{skip}(\sigma)=\arg \min_{c_{skip}(\sigma)}c_{out}(\sigma)^2,求导数零解,即d\left[(1-c_{skip}(\sigma))^2 \sigma_{data}^2-c_{skip}(\sigma)^2 \sigma^2\right]/dc_{skip}(\sigma)=0,得到c_{skip}(\sigma)=\sigma_{data}^2/(\sigma^2+\sigma_{data}^2),代入可得c_{out} = \sigma \sigma_{data}/\sqrt{\sigma^2+\sigma_{data}^2};为了让不同噪声强度下梯度尺寸接近,有\lambda(\sigma)=\frac{1}{c_{out}(\sigma)^2}=(\sigma^2+\sigma_{data}^2)/(\sigma\sigma_{data})^2。遵循之前的工作,设置输出层权重为0,初始时F_{\theta}(\cdot)=0,可以看到在任意\sigma损失的期望都是1,固定\sigma,有

“输入网络的噪声条件变量“设置为c_{noise}=\ln(\sigma)/4,也是避免数值范围太大,经验选择log-scale。这些设置下训练更稳定,可以放心地关注损失的设计。实验中\sigma_{data}=0.5

训练时噪声分布

怎么选择p_{train}(\sigma),设计训练时的噪声水平?

看下面子图a,画了不同任务下loss随噪声强度的变化曲线,绿色表示训练开始时的loss,蓝色和橙色为训练完成时的,如果在某个\sigma下,训练前后loss下降很多,就说明模型学到很多,loss几乎不变则说明这个噪声水平很难学或者没价值,阴影区域表示随机采样1万个训练样本的标准差。EMD训练时采样\sigma的概率分布为红色虚线,采用log-normal采样分布,让训练样本集中在中间区域,即\ln(\sigma) \sim \mathcal{N}(p_{mean}, P_{std}^2),设置P_{mean}=-1.2P_{std}=1.2

为了避免扩散模型在小数据集上发生过拟合,可以在训练时引入GAN系列论文中的数据增强技术,其中对训练集中的原始图片做了各式几何转换,同时为了避免增强泄露给模型,将增强系数作为条件输入F_{\theta},推理时设置为0不做增强,结果是数据增强能提升FID。随着模型改进,随机采样的相关性似乎在减弱,参考上图中的b(无条件CIFAR-10任务)、c(类条件ImageNet-64任务)子图。

统一框架下的不同参数化结果

先前工作的原始实现之间的差异可以总结为模型输入输出、图片数据的动态范围、对x的缩放、对\sigma的插值,在EMD架构下可以归纳为下列表格,F_{\theta}表示模型输出,图片数据用[-1,1]中的连续值表示,x和\sigma始终满足dx = \left[ \frac{\dot{s}(t)}{s(t)}x - s(t)^2 \dot{\sigma}(t) \sigma(t) \bigtriangledown_{x} \log p\left( \frac{x}{s(t)}; \sigma(t) \right) \right] dt

N表示在近似求解微分方程时的离散时间步数,但原始的模型权重可能不是在任意时间网格都训练过,假设训练时的时间步为\{u_i\}_{i\in [0,M]}

在Song et al.定义的VP版本中

在原始叙事中,设置f(t)= -\frac{1}{2}\beta(t)g(t)=\sqrt{\beta(t)},其中\beta(t)控制着噪声的增长速度,原始设置中被定义为一个线性调度\beta(t)=\beta_{min}+t(\beta_{max}-\beta_{min}),还采用线性时间调度t_{i<N} =\sigma^{-1}(\sigma_i)= 1+\frac{i}{N-1}(\epsilon_s-1)。令\beta_d=\beta_{max}-\beta_{min},定义积分\alpha(t)=\int_0^t \beta(\xi)d\xi=\frac{1}{2}\beta_d t^2+\beta_{min}t,可以推导出表格中的调度函数s(t)=\exp\left(\int_0^t [f(\xi)]d\xi\right)=\exp(-\frac{1}{2} \alpha(t))\sigma(t)=\sqrt{\int_0^t \frac{g(\xi)^2}{s(\xi)^2}d\xi} = \sqrt{\int_0^t \beta(\xi)e^{\alpha(t)}d\xi}=\sqrt{e^{\alpha(t)}-1},有s(t)=1/\sqrt{\sigma(t)^2+1}

通过噪声网络近似分数函数\bigtriangledown_x \log p_t(x) \approx -\frac{1}{\bar{\sigma}(t)}F_{\theta}(x; (M-1)t),其中M=1000,\bar{\sigma}(t)对应扰动核p_{ot}\left(x_t|x_0\right)的标准差s(t) \sigma(t),替换成用原始样本\hat{x}=x/s(t)表示的形式\bigtriangledown_{\hat{x}} \log p_t(\hat{x}; \sigma(t)) \approx -\frac{1}{\sigma(t)}F_{\theta}(s(t)\hat{x}; (M-1)t),代入s(t),用去噪器表示左边,有\frac{D_{\theta}(\hat{x};\sigma(t))-\hat{x}}{\sigma(t)^2}= -\frac{1}{\sigma(t)} F_{\theta}\left(\frac{1}{\sqrt{\sigma(t)^2+1}}\hat{x};(M-1)t\right),转换一下\sigma(t)\rightarrow\sigmat \rightarrow \sigma^{-1}(\sigma),得到

损失函数为\mathbb{E}_{t,y,n}\left[\frac{1}{\sigma(t)^2}\|D_{\theta}(y+n;\sigma(t))-y\|_2^2\right],即

实验时采用CIFAR-10对应的“DDPM++ cont. (VP)”权重,含62M可训练参数,噪声强度范围\sigma \in [\sigma(\epsilon_t), \sigma(1)] \approx[0.001, 152],远宽于EDM偏好的范围[0.002, 80],直接将该模型应用到算法1、2中。

原始实现中存在些疏忽。在欧拉求解器中dx/dt乘的是-1/N而非(\epsilon_s-1)/(N-1),最后一步从t_{N-1}=\epsilon_st_N=\epsilon_s - 1/N,当N<1000t_N<0,这意味着比如当N=128时,生成的图片中会包含大量噪声,作者修复了这些错误,每步步长都从时间序列统一推导出来,且明确终点t_N=0

在Song et al.定义的VE版本中

在原始叙事中,设置f(t)=0,g(t)=\sigma_{min} \sigma_d^t \sqrt{2 \log \sigma_d},其中\sigma_d = \sigma_{max}/\sigma_{min},根据对应的SDE推导出与扰动核相匹配的s(t)=\exp\left(\int_0^t [f(\xi)]d\xi\right)=1\begin{matrix} &\sigma(t)=\sqrt{\int_0^t \frac{g(\xi)^2}{s(\xi)^2}d\xi} = \sqrt{\int_0^t [\sigma_{min} \sigma_d^{\xi} \sqrt{2 \log \sigma_d}]^2 d\xi}\\ &=\sqrt{\int_0^t \sigma_{min}^2 [2\log \sigma_d] \sigma_d^{2\xi} d\xi}=\sigma_{min} \sqrt{\int_0^t [\log (\sigma_d^2)] [(\sigma_d^2)^{\xi}]d\xi}=\sigma_{min} \sqrt{\sigma_d^{2t}-1} \end{matrix}

但实践时希望的是指数增长噪声\sigma(t)=\sigma_{min}(\sigma_{max}/\sigma_{min})^t,噪声水平呈log-uniform分布\ln(\sigma) \sim \mathcal{U}\left(\ln(\sigma_{min}), \ln(\sigma_{max}) \right),采用离散的噪声序列\bar{\sigma}_{i<N}=\sigma_{min} \left( \frac{\sigma_{max}}{\sigma_{min}}\right)^{1-i/(N-1)}\bar{\sigma}_N=0,通过x_{i+1}=x_i+\frac{1}{2}(\bar{\sigma}_i^2-\bar{\sigma}_{i+1}^2)\bigtriangledown_x \log \bar{p}_i(x)近似求解PF ODE。稍微解释一下这个迭代式,在s(t)=1的情况下,dx/dt = -1/2\cdot g(t)^2 \cdot \bigtriangledown_x \log p_t(x),再代入g(t)^2 = s(t)^2 (2 \dot{\sigma}\sigma) = d\sigma^2/dt,得dx = -1/2 \cdot \bigtriangledown_x \log p_t(x) d\sigma^2,欧拉离散化可得。

这和在EMD框架下,设置t_i = \bar{\sigma}_i^2、s(t)=1、\sigma(t)=\sqrt{t},替换\bar{p}_i(x)=p(x;\sigma(t_i))时得到的ODE的Euler迭代公式一致(参考算法1):

原始针对CIFAR-10的设置中\sigma_{min}=0.01\sigma_{max}=50,图片范围x \in [0,1],现在要调整成[-1,1],这两就也得乘2。

原始通过下式近似分数函数\bigtriangledown_x \log p_t(x) \approx \bar{F}_{\theta}(x; \sigma(t)),分数网络\bar{F}_{\theta}已包含预处理和后处理步骤,由负噪声网络F_{\theta}得到,即\bar{F}_{\theta}(x;\sigma)=\frac{1}{\sigma}F_{\theta}\left(2x-1; \log (\sigma) \right),其中2x-1取值范围在[-1,1],\log(\sigma)对跨越数量级的噪声水平取对数,都是更稳定的网络输入。

为了统一到EDM框架,希望用\{c_{skip}, c_{out}, c_{in}, c_{noise}\}表示预/后处理过程,而不是直接集成到网络本身,考虑到图片表示范围的差异,需对\bigtriangledown_x \log p_t(x) \approx \frac{1}{\sigma}F_{\theta}\left(2x-1; \log (\sigma) \right)进行替换:p_t(x) \rightarrow p_t(2x-1)x \rightarrow \frac{1}{2}x+\frac{1}{2}\sigma \rightarrow \sigma/2,得到

\begin{matrix} \bigtriangledown_{[\frac{1}{2}x+\frac{1}{2}]}\log p_t (2[\frac{1}{2}x+\frac{1}{2}]-1) &\approx \frac{1}{[\frac{1}{2}\sigma]}F_{\theta}\left(2[\frac{1}{2}x+\frac{1}{2}]-1; \log [\frac{1}{2}\sigma] \right) \\ \bigtriangledown_x \log p_t(x) &\approx \frac{1}{\sigma} F_{\theta}\left(x;\log(\frac{1}{2}\sigma)\right) \end{matrix}

再用去噪器表示左边,得到\left(D_{\theta}(x; \sigma)-x \right)/\sigma^2 = \frac{1}{\sigma} F_{\theta}\left(x;\log(\frac{1}{2}\sigma)\right),化解为

损失函数和VP一样为\mathbb{E}_{t,y,n}\left[\frac{1}{\sigma(t)^2}\|D_{\theta}(y+n;\sigma(t))-y\|_2^2\right],代入指数增长噪声,有

实验时采用CIFAR-10对应的“NCSN++ cont. (VE)”权重,含63M可训练参数,噪声强度范围\sigma \in [\sigma(\epsilon_t), \sigma(1)] \approx[0.02, 100],而EDM偏好的范围是[0.002, 80],在重新实现时直接调整\sigma_{min}=0.002的话,模型会遇到没见过的噪声强度,但EDM在重新设计了训练分布,将log-uniform的p_{train}(\sigma)改为log-normal,并调整了损失权重,还做了数据增强后,模型可以支持更小的\sigma了。再改用统一的条件处理框架,这样就能直接将该模型应用到算法1、2中了。另外将样本的精度从单精度调整成双精度,减少了高步数采样误差。

改进版DDPM和DDIM

Song et al.观察到确定性的DDIM采样器可以表示成dx(t)=\epsilon_{\theta}^{(t)} \left(\frac{x(t)}{\sqrt{\sigma(t)^2+1}} \right) d \sigma(t)的欧拉迭代过程,其中噪声网络的输入输出都是缩放过程的,也就是说对于x(t)=y(t)+n(t)而言,有\left(\frac{x(t)}{\sqrt{\sigma(t)^2+1}} \right) \approx \frac{n(t)}{\sigma(t)},代入预测原始数据的去噪器D_{\theta}(x(t); \sigma(t)) \approx y,可以得到\epsilon_{\theta}^{(t)} \left(\frac{x(t)}{\sqrt{\sigma(t)^2+1}} \right) =\frac{x(t)-D_{\theta}(x(t); \sigma(t))}{\sigma(t)}。对于最理想的\epsilon(\cdot)D(\cdot),有\epsilon(\cdot) = -\sigma(t)[(D(\cdot)/\sigma(t)^2]=-\sigma(t) \bigtriangledown_{x(t)} \log p\left(x(t); \sigma(t) \right),同时设置\sigma(t)=t,上列ODE便可简化成dx(t) = -t \bigtriangledown_{x(t)} \log p\left(x(t); \sigma(t) \right),和在EDM框架下设置s(t)=1、\sigma(t)=t得到的ODE一致。

原始DDPM的前向过程是一个逐步往数据中添加高斯噪声的马尔可夫链,根据离散的方差调度\{\beta_1, \dots,\beta_T\}(如一个线性调度),有q(\bar{x}_t|\bar{x}_{t-1})=\mathcal{N}(\bar{x}_t;\sqrt{1-\beta_t}\bar{x}_{t-1}, \beta_tI),得到从\bar{x}_0\bar{x}_t的转移概率为q(\bar{x}_t|\bar{x}_0)=\mathcal{N}(\bar{x}_t;\sqrt{\bar{\alpha}_t}\bar{x}_0, (1-\bar{\alpha}_t)I),其中\bar{\alpha}_t = \prod_{s=1}^t (1-\beta_s);也可以先定义\{\bar{\alpha}_t\}(比如一个余弦调度),再推导\beta_t = 1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}

在改进版iDDPM中便定义\bar{\alpha}_t = \frac{f(t)}{f(0)},其中f(t) = \cos^2 \left( \frac{t/T+s}{1+s} \cdot \frac{\pi}{2}\right),s=0.008,不过Nichol et al.在实现时省去了“除以f(0)“这点。同时为了避免靠近t=T 时出现奇异点(发生除零操作),原始实现中会对\beta_t做裁剪,改用\beta_s' = \min(\beta_s, 0.999),对应地\bar{\alpha}'_t = \prod_{s=1}^t (1-\beta_s')=\prod_{s=1}^t (1-\min(\beta_s,0.999))=\prod_{s=1}^t (1-\min(1-\frac{\bar{\alpha}_s}{\bar{\alpha}_{s-1}},0.999))=\prod_{s=1}^t \max (\frac{\bar{\alpha}_s}{\bar{\alpha}_{s-1}},0.001)

现在,在EDM框架下重新介绍一遍推理。定义iDDPM的采样步为\{u_i\}i \in \{0, \dots,M\}按噪声水平\sigma(u_j)降序排列,得对之前的式子替换T \rightarrow Mt \rightarrow M-j,设置常量C_1=0.001C_2=0.008,有q(\bar{x}_j|\bar{x}_M)=\mathcal{N}(\bar{x}_j;\sqrt{\bar{\alpha}'_j}\bar{x}_0, (1-\bar{\alpha}'_j)I)\bar{\alpha}_j = \cos^2 \left( \frac{(M-j)/M+C_2}{1+C_2} \cdot \frac{\pi}{2}\right)=\sin \left( \frac{\pi}{2} \frac{j}{M(1+C_2)}\right)\bar{\alpha}'_j = \prod_{s=M-1}^j \max \left(\frac{\bar{\alpha}_j}{\bar{\alpha}_{j-1}},C_1\right)=\bar{\alpha}'_{j+1} \max \left(\frac{\bar{\alpha}_j}{\bar{\alpha}_{j-1}},C_1\right)

为了能够匹配“扰动核“,即p_{0t}(x(u_j)|x(0)) = q(\bar{x}_j|\bar{x}_M),展开后为\mathcal{N}(x(u_j);s(t)x(0),s(u_j)^2\sigma(u_j)^2I)=\mathcal{N} \left(\bar{x}_j;\sqrt{\bar{\alpha}'_j}\bar{x}_M, (1-\bar{\alpha}'_j)I\right)。先替换s(t)=1\sigma(t)=t\bar{x}_M=x(0),得到\mathcal{N}(x(u_j);x(0),u_j^2I)=\mathcal{N} \left(\bar{x}_j;\sqrt{\bar{\alpha}'_j}x(0), (1-\bar{\alpha}'_j)I\right),通过定义\bar{x}_j = \sqrt{\bar{\alpha}'_j} x(u_j)可以匹配两个分布的均值,即

\mathcal{N}(x(u_j);x(0),u_j^2I)=\mathcal{N} \left(\sqrt{\bar{\alpha}'_j} x(u_j);\sqrt{\bar{\alpha}'_j}x(0), (1-\bar{\alpha}'_j)I\right)=\mathcal{N}\left(x(u_j); x(0), \frac{1-\bar{\alpha}'_j}{\bar{\alpha}'_j}I\right)

然后是匹配方差u_j^2 = \frac{1-\bar{\alpha}'_j}{\bar{\alpha}'_j},解得\bar{\alpha}'_j = \frac{1}{u_j^2+1},对等式左边进行替换,有

上式给出了\{u_j\}的递推公式,边界满足u_M=0,这就是采样时“时间步“的设计。

之前说过,网络预测的是缩放后的噪声,在设置\sigma(t)=t时有

找到最靠近当前噪声水平\sigmau_j,下标j作为输入网络的噪声条件,即

这和VP的预条件公式是一样的。

训练时,采用和VP同样的方案定义主损失L_{simple}(iDDPM中还有第二个损失项L_{vlb})和噪声分布,从\{u_j\}中均匀抽取\sigma,即\sigma = u_jj \sim u(0, M-1),设置\lambda(\sigma) = 1/\sigma^2

实验时采用ImageNet-64对应的“ADM (dropout)”权重,296M可训练参数,支持M=1000个离散噪声水平\sigma \in \{u_j\} \approx \{20291, 642, 321, 214, 160, 128, 106, 92, 80, 71, \dots, 0.0064\},所以在接入(假设训练时\sigma连续的)EDM统一采样框架时就存在一些问题:

  1. EDM的采样算法通常使用N个时间步进行采样,实际上N \ll M,需要重新选择时间步。可以采用线性映射,同时前8个\sigma于EDM的\sigma_{max} \approx 80而言太大了,选择直接忽略,从第9个开始映射,取j=\left \lfloor j_0 + (M-1-j_0)/(N-1) \cdot i \right \rfloorj_0=8t_i=u_j

  2. EDM采样离散时间步时,需要将\sigma_i映射到最接近的离散噪声,即\sigma_i \leftarrow u_{\arg \min_j}|u_j-\sigma_i|,同时设定\sigma_{min}=0.0064 \approx u_{N-1}

  3. 随机采样(算法2)产生的中间时间步\hat{t}_i也有不在\{u_j\}的可能,也映射为\hat{t}_i \leftarrow u_{\arg \min_j}|u_j-(t_i + \gamma_i t_i)|

于是便可以直接将原网络应用到算法1、2,请忽略网络预测的方差。

消融实验

下图中,蓝线表示采用原始配套的采样器,橙色表示作者处理了原始实现中一些疏忽后的结果。绿色是采用Heun求解器和EDM时间步,在所有样例中,能比欧拉方法在达到相同FID的情况下调用更少次模型。红色线表示使用EDM设定的\sigma(t)和s(t)。还有黑色的虚线,在ODE的调度设计上用了很复杂的ODE求解器,得不偿失。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐