论文信息

  • 标题:FLOW MATCHING FOR GENERATIVE MODELING
  • 会议:arXiv:2210.02747
  • 单位:Meta AI (FAIR)、Weizmann Institute of Science
  • 代码:github.com/atong01/conditional-flow-matching(社区主流开源实现)
  • 论文:https://arxiv.org/pdf/2210.02747.pdf

开篇:生成模型的困局与破局

这两年AI生成领域,扩散模型(Diffusion Model)可谓是独领风骚,从AI绘画到文生视频,几乎所有爆款应用都离不开它的身影。但用过扩散模型的开发者都懂它的痛:训练周期长到离谱,采样要几十上百步才能出一张清晰的图,想加速还得各种魔改,稍不注意就会出现画面崩坏。

扩散模型的本质,是用一个随机微分方程(SDE),给图片一步步加噪,再让模型学逆过程去噪。但这种随机过程的设定,天然就把生成路径限制在了一个很窄的范围里,就像你开车从北京到上海,只能走一条弯弯曲曲的国道,不仅绕路,还得频繁踩刹车油门,效率极低。

那有没有更高效的方案?其实早就有了——连续归一化流(Continuous Normalizing Flows, CNF)。它用确定性的常微分方程(ODE)来建模从噪声到真实图片的完整过程,理论上能建模任意生成路径,还能精确计算样本的似然,完美避开了扩散模型的天生缺陷。

但CNF一直有个致命问题:训不动。传统的最大似然训练要反复做ODE数值仿真,计算量爆炸,别说ImageNet这种大数据集,就连CIFAR10都能训到天荒地老。之前的无仿真训练方法,要么高维场景下算不动,要么梯度有偏,训出来的模型效果稀烂。

而这篇论文提出的Flow Matching(流匹配,简称FM),直接把CNF的训练难题一锅端了。它设计了一套无仿真、无偏、可无限扩展的训练框架,不仅能把扩散模型的路径作为特例纳入其中,还能用上更高效的最优传输(Optimal Transport, OT)路径,实现了更快的训练、更少的采样步数、更好的生成效果,甚至只用35%的训练迭代量,就能超过扩散模型训几百万步的效果。
在这里插入图片描述

图1 无条件ImageNet-128生成样本(出处:原论文Figure 1)
这些样本均由FM+OT路径训练的CNF生成,只用了50万次训练迭代,就达到了同期SOTA扩散模型400万+迭代的效果,样本细节丰富、真实感拉满。


背景与前置知识:先搞懂核心概念

在讲FM的核心方法之前,我们先把基础概念掰碎了讲,每个专业术语都配上大白话解释,保证新手也能看懂。

连续归一化流(CNF)到底是什么?

CNF的核心思想,就是用一个连续的、平滑的微分方程,把简单的高斯噪声分布,一步步“流”成我们想要的真实图片分布。

你可以把它想象成一场精准的车队运输:

  • 起点(t=0):一个巨大的停车场,所有车都停在高斯噪声分布的车位上,对应我们的输入噪声;
  • 终点(t=1):真实图片的分布区域,我们要让所有车精准开到对应的位置,形成和真实数据一模一样的分布;
  • 行驶路线:由ODE方程定义的流映射ϕt(x)\phi_t(x)ϕt(x),每辆车的行驶轨迹完全由这个映射决定;
  • 导航指令:时间依赖的向量场vt(x)v_t(x)vt(x),告诉模型在t时刻、位置x处,应该往哪个方向、以多大速度行驶,这就是我们要用神经网络拟合的目标。
核心公式1:流映射的ODE定义

ddtϕt(x)=vt(ϕt(x))\frac{d}{d t} \phi_{t}(x)=v_{t}\left(\phi_{t}(x)\right)dtdϕt(x)=vt(ϕt(x))
ϕ0(x)=x\phi _{0}(x)=xϕ0(x)=x

符号 数学含义 通俗解释
ϕt(x)\phi_t(x)ϕt(x) 流映射函数,将t=0时刻的输入x映射到t时刻的状态 车队在t时刻的位置,从起点x出发,沿着路线开到的位置
vt(⋅)v_t(\cdot)vt() 时间依赖的向量场,由神经网络参数化,是CNF的核心学习目标 导航给的行驶指令,告诉t时刻在位置ϕt(x)\phi_t(x)ϕt(x)的车,该往哪个方向开、开多快
ttt 时间变量,取值范围固定为[0,1][0,1][0,1] 从起点到终点的行驶时间,t=0是起点,t=1是终点
ϕ0(x)=x\phi_0(x)=xϕ0(x)=x ODE的初始条件 起点位置就是初始的噪声x,车还没开的时候,位置就是初始车位
核心公式2:分布的Push-forward(前推)

pt=[ϕt]∗p0p_{t}=[\phi _{t}]_{*} p_{0}pt=[ϕt]p0

符号 数学含义 通俗解释
ptp_tpt t时刻的概率密度分布 t时刻车队的整体分布,所有车在t时刻的位置形成的分布
p0p_0p0 t=0时刻的先验分布,通常为标准高斯分布N(0,I)\mathcal{N}(0,I)N(0,I) 起点停车场的车辆分布,也就是我们输入的噪声分布
[ϕt]∗[\phi_t]_*[ϕt] Push-forward(前推)算子,通过流映射ϕt\phi_tϕt将分布p0p_0p0转换为ptp_tpt 按照规划的路线,把所有车从起点开到t时刻的位置,形成新分布的过程

简单说,只要我们学会了正确的向量场vtv_tvt,就能通过解这个ODE,把任意高斯噪声,转换成一张真实的图片。但问题来了:我们怎么让模型学到正确的vtv_tvt

传统CNF的训练方法,要反复解ODE来计算似然,计算量直接拉满,这就是FM要解决的核心痛点。


核心方法:Flow Matching的完整逻辑

FM的核心创新,分为两步走:

  1. 定义Flow Matching(FM)目标:让模型学的向量场,和我们想要的目标向量场完全匹配,损失就是两者的均方误差;
  2. 提出Conditional Flow Matching(CFM)目标:解决了原始FM目标无法直接计算的问题,证明了CFM和FM的梯度完全等价,让训练变得简单、高效、无偏。

3.1 原始Flow Matching(FM)目标

我们的最终目标,是让模型学的向量场vt(x)v_t(x)vt(x),完美匹配能生成目标分布的真实向量场ut(x)u_t(x)ut(x)。所以FM的损失函数,就是两个向量场的均方误差,在时间和分布上取期望。

核心公式3:Flow Matching损失函数

LFM(θ)=Et,pt(x)∥vt(x)−ut(x)∥2\mathcal{L}_{FM}(\theta)=\mathbb{E}_{t, p_{t}(x)}\left\| v_{t}(x)-u_{t}(x)\right\| ^{2}LFM(θ)=Et,pt(x)vt(x)ut(x)2

符号 数学含义 通俗解释
LFM(θ)\mathcal{L}_{FM}(\theta)LFM(θ) FM的损失函数,θ\thetaθ是向量场网络的可学习参数 我们要最小化的训练损失,值越小,模型学的向量场越接近真实值
Et,pt(x)\mathbb{E}_{t, p_t(x)}Et,pt(x) 对时间t和t时刻的分布pt(x)p_t(x)pt(x)取数学期望 我们要在整个行驶时间[0,1],和所有可能的车辆位置上,都让导航指令尽可能准确
ttt 从均匀分布U[0,1]U[0,1]U[0,1]中随机采样的时间变量 随机选一个行驶时间点,来计算当前的损失
vt(x)v_t(x)vt(x) 模型学习的参数化向量场,由神经网络实现 我们的模型输出的导航指令
ut(x)u_t(x)ut(x) 能生成目标概率路径ptp_tpt的真实目标向量场 完美的导航指令,能让所有车精准从起点开到终点
∣⋅∣2|\cdot|^22 L2范数的平方,也就是均方误差 衡量模型输出的导航指令和完美指令的差距

这个损失函数看起来简单又完美,但有个致命问题:我们根本不知道真实的ut(x)u_t(x)ut(x)pt(x)p_t(x)pt(x)长什么样。我们只有真实图片的样本,没有完整的分布表达式,更没法直接计算对应的向量场,所以这个损失根本没法直接算。

这时候,论文的第一个核心定理就来了:我们可以把复杂的边际分布,拆成无数个单样本的条件分布的叠加,把不可计算的边际向量场,拆成可计算的条件向量场的加权和。

3.2 条件概率路径与边际向量场的构建

我们手里只有真实图片的样本x1x_1x1(服从未知的数据分布q(x1)q(x_1)q(x1)),那我们就给每个真实样本x1x_1x1,定义一个条件概率路径pt(x∣x1)p_t(x|x_1)pt(xx1),满足两个边界条件:

  1. t=0时,p0(x∣x1)=p(x)=N(0,I)p_0(x|x_1) = p(x) = \mathcal{N}(0,I)p0(xx1)=p(x)=N(0,I),也就是所有样本的条件路径,起点都是同一个标准高斯分布;
  2. t=1时,p1(x∣x1)p_1(x|x_1)p1(xx1)是一个集中在x1x_1x1附近的高斯分布,也就是终点就是真实样本本身。

有了每个样本的条件路径,我们把所有样本的条件路径积分起来,就得到了整体的边际概率路径:
pt(x)=∫pt(x∣x1)q(x1)dx1p_{t}(x)=\int p_{t}\left(x | x_{1}\right) q\left(x_{1}\right) d x_{1}pt(x)=pt(xx1)q(x1)dx1

通俗说,就是给每一张真实图片,都规划一条从噪声到它的专属路线,把所有这些路线合起来,就是从噪声分布到整个真实图片分布的完整路线。

对应的,整体的边际向量场ut(x)u_t(x)ut(x),也可以由每个样本的条件向量场ut(x∣x1)u_t(x|x_1)ut(xx1)加权求和得到:
ut(x)=∫ut(x∣x1)pt(x∣x1)q(x1)pt(x)dx1u_{t}(x)=\int u_{t}\left(x | x_{1}\right) \frac{p_{t}\left(x | x_{1}\right) q\left(x_{1}\right)}{p_{t}(x)} d x_{1}ut(x)=ut(xx1)pt(x)pt(xx1)q(x1)dx1

论文的定理1严格证明了:这个加权求和得到的边际向量场,恰好就是能生成边际概率路径pt(x)p_t(x)pt(x)的真实向量场。

这就把一个不可解的全局问题,拆成了无数个可解的单样本问题。但还是有个问题:边际向量场的积分还是没法直接算,损失还是没法求。

这时候,论文的第二个核心创新——Conditional Flow Matching(CFM) 就登场了。

3.3 Conditional Flow Matching(CFM):可计算的等价训练目标

既然边际损失算不了,那我们直接在每个样本的条件路径上算损失行不行?论文的定理2给出了肯定的答案:CFM损失和原始FM损失,对模型参数θ\thetaθ的梯度完全相等

也就是说,优化CFM损失,就完全等价于优化原始的FM损失,而且CFM损失是完全可计算的!

核心公式4:Conditional Flow Matching损失函数

LCFM(θ)=Et,q(x1),pt(x∣x1)∥vt(x)−ut(x∣x1)∥2\mathcal{L}_{CFM}(\theta)=\mathbb{E}_{t, q\left(x_{1}\right), p_{t}\left(x | x_{1}\right)}\left\| v_{t}(x)-u_{t}\left(x | x_{1}\right)\right\| ^{2}LCFM(θ)=Et,q(x1),pt(xx1)vt(x)ut(xx1)2

符号 数学含义 通俗解释
LCFM(θ)\mathcal{L}_{CFM}(\theta)LCFM(θ) 条件流匹配损失函数 我们实际训练时用的损失,和原始FM损失梯度完全等价,可直接计算
$\mathbb{E}_{t, q(x_1), p_t(x x_1)}$ 对时间t、真实数据分布q(x1)q(x_1)q(x1)、条件路径分布$p_t(x
q(x1)q(x_1)q(x1) 真实图片的数据分布,我们有它的样本 真实图片的数据集,比如ImageNet、CIFAR10
$u_t(x x_1)$ 真实样本x1x_1x1对应的条件向量场,有闭式解,可直接计算

定理2的核心意义,就是把一个原本不可计算的全局优化问题,变成了一个可批量计算的逐样本优化问题。我们只需要给每个真实样本,定义好它的条件概率路径pt(x∣x1)p_t(x|x_1)pt(xx1)和对应的条件向量场ut(x∣x1)u_t(x|x_1)ut(xx1),就能用随机梯度下降训练模型了!

3.4 通用高斯条件路径:一套公式覆盖所有场景

论文里进一步给出了通用高斯条件路径的完整形式,把扩散模型的路径、OT路径都纳入了同一个框架里,还给出了条件向量场的闭式解。

我们定义条件概率路径是一个随时间变化的高斯分布:
pt(x∣x1)=N(x∣μt(x1),σt(x1)2I)p_{t}\left(x | x_{1}\right)=\mathcal{N}\left(x | \mu_{t}\left(x_{1}\right), \sigma_{t}\left(x_{1}\right)^{2} I\right)pt(xx1)=N(xμt(x1),σt(x1)2I)

符号 数学含义 通俗解释
μt(x1)\mu_t(x_1)μt(x1) 条件高斯分布的均值,是关于时间t和真实样本x1x_1x1的函数 单样本路径在t时刻的中心位置,也就是车队在t时刻应该开到的核心位置
σt(x1)\sigma_t(x_1)σt(x1) 条件高斯分布的标准差,是关于时间t的函数 单样本路径在t时刻的分布宽度,t=0时是1(标准高斯),t=1时趋近于0(集中在真实样本上)
III 单位矩阵,代表高斯分布是各向同性的 每个维度的方差都一样,简化计算

同时,我们定义条件流映射是一个简单的仿射变换,把标准高斯噪声x0x_0x0映射到条件路径的t时刻状态:
ψt(x)=σt(x1)x+μt(x1)\psi_{t}(x)=\sigma_{t}(x_{1}) x+\mu_{t}\left(x_{1}\right)ψt(x)=σt(x1)x+μt(x1)

这个仿射变换,就是单样本的专属行驶路线:t=0时,ψ0(x)=x\psi_0(x)=xψ0(x)=x,就是起点噪声;t=1时,ψ1(x)=σ1x1+μ1(x1)≈x1\psi_1(x)=\sigma_1 x_1 + \mu_1(x_1) \approx x_1ψ1(x)=σ1x1+μ1(x1)x1,就是终点真实样本。

有了这个通用形式,论文的定理3直接给出了条件向量场的闭式解,不用再做任何复杂的推导:
ut(x∣x1)=σt′(x1)σt(x1)(x−μt(x1))+μt′(x1)u_{t}\left(x | x_{1}\right)=\frac{\sigma_{t}'\left(x_{1}\right)}{\sigma_{t}\left(x_{1}\right)}\left(x-\mu_{t}\left(x_{1}\right)\right)+\mu_{t}'\left(x_{1}\right)ut(xx1)=σt(x1)σt(x1)(xμt(x1))+μt(x1)

符号 数学含义 通俗解释
σt′\sigma_t'σt σt\sigma_tσt对时间t的导数 标准差随时间的变化率,也就是分布宽度的收缩速度
μt′\mu_t'μt μt\mu_tμt对时间t的导数 均值随时间的变化率,也就是路线中心的移动速度
x−μt(x1)x-\mu_t(x_1)xμt(x1) 当前位置和路径中心的偏移量 车辆当前位置和规划路线中心的距离

这个公式太重要了!只要我们定义好均值函数μt(x1)\mu_t(x_1)μt(x1)和标准差函数σt(x1)\sigma_t(x_1)σt(x1),就能直接算出对应的条件向量场,代入CFM损失就能训练模型,不用再做任何复杂的数学推导。

3.5 两大核心实例:扩散路径 vs 最优传输(OT)路径

基于上面的通用公式,论文里给出了两个最关键的实例,一个是和现有扩散模型对齐的扩散路径,另一个是论文主推的、效率拉满的OT路径。

实例1:扩散条件路径

扩散模型的加噪过程,本质上就是一个高斯条件路径,我们可以直接把它套进通用公式里。比如最常用的方差保持(VP)扩散路径,它的均值和标准差函数为:
μt(x1)=α1−tx1,σt(x1)=1−α1−t2\mu_{t}\left(x_{1}\right)=\alpha_{1-t} x_{1}, \quad \sigma_{t}\left(x_{1}\right)=\sqrt{1-\alpha_{1-t}^{2}}μt(x1)=α1tx1,σt(x1)=1α1t2
其中αt=e−12T(t)\alpha_t = e^{-\frac{1}{2} T(t)}αt=e21T(t)T(t)=∫0tβ(s)dsT(t)=\int_{0}^{t} \beta(s) d sT(t)=0tβ(s)dsβ(s)\beta(s)β(s)是扩散模型的噪声调度函数。

把它代入定理3的公式,就能直接算出扩散路径的条件向量场,和扩散模型里的概率流ODE完全一致。这意味着,FM框架完全可以用来训练扩散模型,而且比传统的得分匹配方法更稳定、更鲁棒

但扩散路径有个天生的缺陷:它的均值和标准差函数是非线性的,对应的向量场方向会随时间不断变化,生成轨迹是弯曲的,甚至会出现“过冲”——也就是开过了终点,还要往回开,不仅增加了拟合难度,还需要更多的采样步数才能收敛。

实例2:最优传输(OT)条件路径

这是论文的核心亮点,也是FM框架能超越扩散模型的关键。既然我们能自由定义路径,那为什么不选一条最简单、最高效的直线路径?

OT路径的均值和标准差函数,直接用最简单的线性变化:
μt(x1)=tx1,σt(x1)=1−(1−σmin)t\mu_{t}(x_1)=t x_{1}, \quad \sigma_{t}(x_1)=1-\left(1-\sigma_{min }\right) tμt(x1)=tx1,σt(x1)=1(1σmin)t

符号 数学含义 通俗解释
σmin\sigma_{min}σmin t=1时刻的最小标准差,通常设为0.001,保证t=1时分布集中在真实样本上 终点位置的分布宽度,几乎就是一个点,精准落在真实样本上

把它代入定理3的公式,就能得到OT路径的条件向量场:
ut(x∣x1)=x1−(1−σmin)x1−(1−σmin)tu_{t}\left(x | x_{1}\right)=\frac{x_{1}-\left(1-\sigma_{min }\right) x}{1-\left(1-\sigma_{min }\right) t}ut(xx1)=1(1σmin)tx1(1σmin)x

对应的条件流映射,也就是生成轨迹,是完美的直线:
ψt(x)=(1−(1−σmin)t)x+tx1\psi_{t}(x)=\left(1-\left(1-\sigma_{min }\right) t\right) x+t x_{1}ψt(x)=(1(1σmin)t)x+tx1
在这里插入图片描述

图2 扩散路径与OT路径的条件向量场对比(出处:原论文Figure 2)
左图是扩散路径的条件得分函数(也就是扩散模型的拟合目标),向量方向随时间不断变化;右图是OT路径的条件向量场,方向全程恒定不变。显然,OT路径的拟合难度要低得多,神经网络更容易学准。

在这里插入图片描述

图3 扩散路径与OT路径的生成轨迹对比(出处:原论文Figure 3)
左图是扩散路径的轨迹,弯弯曲曲,还有大量过冲和折返;右图是OT路径的轨迹,全程是笔直的直线,从起点直接到终点,没有任何多余的移动。

OT路径的优势是碾压级的:

  1. 轨迹是直线,向量场方向恒定:神经网络的拟合难度大幅降低,训练收敛更快,更容易学准;
  2. 无过冲、无折返:采样时只需要更少的步数就能得到高质量样本,甚至10步以内就能出图;
  3. 闭式解极其简单:计算量比扩散路径小得多,训练和推理都更快;
  4. 理论最优:这个线性路径,恰好是两个高斯分布之间的Wasserstein-2最优传输位移插值,是理论上最短、最省力的路径。

对应的,OT路径的CFM损失也简化到了极致,连时间相关的系数都消掉了:
LCFM(θ)=Et,q(x1),p(x0)∥vt(ψt(x0))−(x1−(1−σmin)x0)∥2\mathcal{L}_{CFM}(\theta)=\mathbb{E}_{t, q\left(x_{1}\right), p\left(x_{0}\right)}\left\| v_{t}\left(\psi_{t}\left(x_{0}\right)\right)-\left(x_{1}-\left(1-\sigma_{min }\right) x_{0}\right)\right\| ^{2}LCFM(θ)=Et,q(x1),p(x0)vt(ψt(x0))(x1(1σmin)x0)2

这个损失函数,就是我们实际训练时最常用的形式,简单、高效、效果拉满。


实验结果与深度分析

论文做了极其全面的实验,从2D玩具数据到CIFAR10、ImageNet 32×32/64×64/128×128的无条件生成,再到64×64→256×256的条件超分任务,全面验证了FM框架的优越性,尤其是OT路径的碾压级表现。

4.1 核心性能对比:无条件生成任务

论文用完全相同的UNet架构、超参数和训练轮数,对比了FM框架和主流扩散模型方法(DDPM、Score Matching、ScoreFlow)的性能,核心指标包括:

  • NLL(负对数似然,单位BPD):衡量模型对数据分布的拟合能力,越小越好;
  • FID(弗雷歇 inception 距离):衡量生成样本的质量和多样性,越小越好;
  • NFE(函数评估次数):衡量采样时的计算量,也就是采样步数,越小越快。

表格1 不同方法在CIFAR10、ImageNet 32×32/64×64上的核心性能对比(出处:原论文Table 1)

模型 CIFAR-10 ImageNet 32 × 32 ImageNet 64 × 64
NLL ↓ FID ↓ NFE ↓ NLL ↓ FID ↓ NFE ↓ NLL ↓ FID ↓ NFE ↓
扩散基线
DDPM 3.12 7.48 274 3.54 6.99 262 3.32 17.36 264
Score Matching 3.16 19.94 242 3.56 5.68 178 3.40 19.74 441
ScoreFlow 3.09 20.78 428 3.55 14.14 195 3.36 24.95 601
本文方法
FM w / Diffusion 3.10 8.06 183 3.54 6.37 193 3.33 16.88 187
FM w / OT 2.99 6.35 142 3.53 5.02 122 3.31 14.45 138

核心结果分析

  1. 全面超越扩散基线:FM-OT在所有数据集、所有指标上,都全面超越了主流扩散方法。CIFAR10上FID降到6.35,比DDPM低1.13;ImageNet 64×64上FID降到14.45,比DDPM低近3个点;
  2. 采样效率碾压:FM-OT的NFE仅122-142,比DDPM的260+少了近一半,比ScoreFlow少了70%以上,采样速度直接翻倍;
  3. 扩散路径也有提升:即使用和扩散模型一样的路径,FM训练的模型也比传统得分匹配方法更稳定,FID更低,NFE更少,证明了FM框架本身的优越性;
  4. 似然拟合更优:FM-OT的NLL在所有数据集上都是最低的,证明它对数据分布的拟合更精准,这是扩散模型很难做到的。

在这里插入图片描述

图4 2D棋盘格数据的生成轨迹对比(出处:原论文Figure 4)
左图可以看到,FM-OT在t=1/3时就已经形成了棋盘格的雏形,而扩散路径要到t=2/3才开始出现结构;右图可以看到,FM-OT在NFE=10时就已经生成了清晰的棋盘格,而扩散方法需要NFE=20以上才能达到类似效果,直观体现了OT路径的高效性。

4.2 训练收敛速度对比

论文还对比了不同方法的训练收敛速度,结果堪称降维打击。
在这里插入图片描述

图5 ImageNet 64×64训练过程中的FID变化曲线(出处:原论文Figure 5)
横轴是训练轮次,纵轴是FID值。FM-OT的FID下降速度远超所有扩散基线,只用了不到100个epoch,就达到了扩散方法200个epoch都达不到的FID值。

更夸张的是,ImageNet 128×128的实验中,主流扩散模型ADM需要训436万次迭代,而FM-OT只用了50万次迭代(仅11%的计算量),就达到了20.9的FID,超过了同期几乎所有无条件生成模型。

4.3 低步数采样效率对比

工业场景里,最看重的就是低步数下的采样质量,毕竟用户不想等几十秒才出一张图。论文专门测试了不同方法在低NFE下的采样误差和样本质量。
在这里插入图片描述

图7 低NFE下的采样误差与FID对比(出处:原论文Figure 7)
左图是ODE求解的数值误差,FM-OT在NFE=60时,误差就已经低于扩散方法NFE=100的水平;右图是FID随NFE的变化,FM-OT在NFE=20时,FID就已经降到了扩散方法NFE=100的水平,哪怕NFE=10,也能生成可用的样本。

这意味着,FM-OT模型可以用10-20步采样,达到扩散模型100步的效果,推理速度直接提升5-10倍,这对落地应用来说是颠覆性的提升。

4.4 条件生成任务:图像超分

论文还验证了FM框架在条件生成任务上的能力,做了64×64→256×256的ImageNet超分任务,和同期SOTA超分扩散模型SR3做了对比。

表格2 图像超分任务性能对比(出处:原论文Table 2)

模型 FID ↓ IS ↑ PSNR ↑ SSIM ↑
Reference(原图) 1.9 240.8
Regression(双三次插值) 15.2 121.1 27.9 0.801
SR3(扩散SOTA) 5.2 180.1 26.4 0.762
FM w / OT 3.4 200.8 24.7 0.747

结果分析:FM-OT的FID比SR3低了1.8,IS高了20.7,生成的超分图片更接近原图,真实性和多样性都远超扩散模型,证明了FM框架在条件生成任务上同样有极强的泛化能力。


核心代码实现

下面是基于PyTorch和torchdiffeq的FM-OT完整实现,完全对齐论文里的公式,包含核心损失计算、向量场网络、训练循环和采样函数,开箱即用。

环境依赖

pip install torch torchvision torchdiffeq tqdm

完整代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

# ===================== 1. 工具函数:OT路径的核心计算 =====================
def ot_conditional_flow(x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor, sigma_min: float = 0.001):
    """
    计算OT路径的条件流映射ψ_t(x0),对齐论文公式(22)
    :param x0: 初始噪声,shape [batch_size, dim]
    :param x1: 真实样本,shape [batch_size, dim]
    :param t: 时间变量,shape [batch_size, 1]
    :param sigma_min: t=1时刻的最小标准差
    :return: ψ_t(x0):t时刻的流状态,shape [batch_size, dim]
    """
    sigma_t = 1 - (1 - sigma_min) * t
    mu_t = t * x1
    return sigma_t * x0 + mu_t

def ot_conditional_vector_field(x0: torch.Tensor, x1: torch.Tensor, sigma_min: float = 0.001):
    """
    计算OT路径的目标条件向量场,对齐论文公式(21)的简化形式
    :param x0: 初始噪声,shape [batch_size, dim]
    :param x1: 真实样本,shape [batch_size, dim]
    :param sigma_min: t=1时刻的最小标准差
    :return: 目标向量场u_t,shape [batch_size, dim]
    """
    return x1 - (1 - sigma_min) * x0

def cfm_ot_loss(model: nn.Module, x1: torch.Tensor, sigma_min: float = 0.001):
    """
    计算OT路径的CFM损失,对齐论文公式(23)
    :param model: 向量场网络v_t(x),输入t和x,输出向量场
    :param x1: 真实样本,shape [batch_size, *data_shape]
    :param sigma_min: t=1时刻的最小标准差
    :return: 标量损失值
    """
    batch_size = x1.shape[0]
    data_shape = x1.shape[1:]
    dim = np.prod(data_shape)
    
    # 1. 随机采样时间t ~ U[0,1]
    t = torch.rand(batch_size, device=x1.device).unsqueeze(1)  # [batch_size, 1]
    
    # 2. 采样初始噪声x0 ~ N(0,I)
    x0 = torch.randn_like(x1)  # [batch_size, *data_shape]
    
    # 3. 计算OT路径的中间状态ψ_t(x0)
    psi_t = ot_conditional_flow(x0, x1, t.view(-1, *([1]*len(data_shape))), sigma_min)
    
    # 4. 计算模型预测的向量场
    v_pred = model(t.squeeze(), psi_t)
    
    # 5. 计算目标向量场
    v_target = ot_conditional_vector_field(x0, x1, sigma_min)
    
    # 6. 计算均方误差损失
    loss = nn.MSELoss()(v_pred, v_target)
    return loss

# ===================== 2. 向量场网络:UNet(图像生成用) =====================
class SinusoidalPositionEmbedding(nn.Module):
    """时间t的正弦位置编码,和扩散模型的一致"""
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor):
        half_dim = self.dim // 2
        emb = np.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class UNetBlock(nn.Module):
    """UNet基础块"""
    def __init__(self, in_channels: int, out_channels: int, time_emb_dim: int):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
        h = self.act(self.norm1(self.conv1(x)))
        # 加入时间编码
        time_emb = self.act(self.time_mlp(t_emb))
        h = h + time_emb[(...,) + (None,) * 2]
        h = self.act(self.norm2(self.conv2(h)))
        return h

class VectorFieldUNet(nn.Module):
    """用于图像生成的向量场UNet,输入时间t和图像x,输出向量场"""
    def __init__(self, in_channels: int = 3, base_channels: int = 64, time_emb_dim: int = 256):
        super().__init__()
        self.time_embedding = nn.Sequential(
            SinusoidalPositionEmbedding(base_channels),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        
        # 下采样
        self.down_blocks = nn.ModuleList([
            UNetBlock(in_channels, base_channels, time_emb_dim),
            UNetBlock(base_channels, base_channels*2, time_emb_dim),
            UNetBlock(base_channels*2, base_channels*4, time_emb_dim),
        ])
        self.downsample = nn.MaxPool2d(2)
        
        # 瓶颈层
        self.bottleneck = UNetBlock(base_channels*4, base_channels*4, time_emb_dim)
        
        # 上采样
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up_blocks = nn.ModuleList([
            UNetBlock(base_channels*8, base_channels*2, time_emb_dim),
            UNetBlock(base_channels*4, base_channels, time_emb_dim),
            UNetBlock(base_channels*2, base_channels, time_emb_dim),
        ])
        
        # 输出头
        self.out_conv = nn.Conv2d(base_channels, in_channels, 1)

    def forward(self, t: torch.Tensor, x: torch.Tensor):
        """
        前向传播
        :param t: 时间张量,shape [batch_size]
        :param x: 输入图像,shape [batch_size, in_channels, H, W]
        :return: 向量场输出,shape和x一致
        """
        # 时间编码
        t_emb = self.time_embedding(t)
        
        # 下采样
        hiddens = []
        h = x
        for block in self.down_blocks:
            h = block(h, t_emb)
            hiddens.append(h)
            h = self.downsample(h)
        
        # 瓶颈层
        h = self.bottleneck(h, t_emb)
        
        # 上采样
        for block in self.up_blocks:
            h = self.upsample(h)
            h = torch.cat([h, hiddens.pop()], dim=1)
            h = block(h, t_emb)
        
        # 输出
        return self.out_conv(h)

# ===================== 3. 采样函数:ODE求解生成样本 =====================
@torch.no_grad()
def sample(model: nn.Module, batch_size: int, data_shape: tuple, device: torch.device, atol: float = 1e-5, rtol: float = 1e-5):
    """
    从训练好的模型中采样,解ODE从t=0到t=1
    :param model: 训练好的向量场模型
    :param batch_size: 生成样本数量
    :param data_shape: 数据形状,比如(3, 32, 32)
    :param device: 计算设备
    :param atol: ODE求解的绝对误差
    :param rtol: ODE求解的相对误差
    :return: 生成的样本,shape [batch_size, *data_shape]
    """
    # 初始噪声
    x0 = torch.randn(batch_size, *data_shape, device=device)
    
    # 定义ODE函数:dphi/dt = v_t(phi)
    def ode_func(t: torch.Tensor, x: torch.Tensor):
        t_tensor = torch.full((x.shape[0],), t, device=device)
        return model(t_tensor, x)
    
    # 解ODE,从t=0到t=1
    t_span = torch.tensor([0.0, 1.0], device=device)
    trajectory = odeint(ode_func, x0, t_span, atol=atol, rtol=rtol, method='dopri5')
    
    # 返回t=1时刻的样本
    return trajectory[-1]

# ===================== 4. 训练循环示例(以CIFAR10为例) =====================
def train_cifar10():
    # 超参数
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 128
    epochs = 100
    lr = 1e-4
    sigma_min = 0.001
    data_shape = (3, 32, 32)
    
    # 数据加载
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 归一化到[-1,1]
    ])
    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    
    # 模型、优化器初始化
    model = VectorFieldUNet(in_channels=3, base_channels=64).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    
    # 训练循环
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        
        for batch, _ in pbar:
            batch = batch.to(device)
            optimizer.zero_grad()
            
            # 计算CFM损失
            loss = cfm_ot_loss(model, batch, sigma_min)
            
            # 反向传播
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pbar.set_postfix({"loss": loss.item()})
        
        print(f"Epoch {epoch+1} 平均损失: {total_loss / len(dataloader):.4f}")
        
        # 每10个epoch采样一次
        if (epoch + 1) % 10 == 0:
            model.eval()
            samples = sample(model, 8, data_shape, device)
            # 这里可以添加保存样本的代码
            print(f"已生成{len(samples)}个样本")
    
    # 保存模型
    torch.save(model.state_dict(), "fm_ot_cifar10.pth")

if __name__ == "__main__":
    train_cifar10()

代码说明

  1. 核心损失函数cfm_ot_loss完全对齐论文里的OT路径CFM损失公式,计算简单高效,无任何额外开销;
  2. 向量场网络:用UNet实现,和扩散模型的网络结构兼容,方便迁移现有扩散模型的优化经验;
  3. 采样函数:用torchdiffeq的dopri5求解器解ODE,支持调整误差容忍度,平衡采样速度和质量;
  4. 训练循环:以CIFAR10为例,开箱即用,可直接迁移到ImageNet等其他数据集。

总结与展望

这篇论文提出的Flow Matching框架,彻底解决了连续归一化流几十年来难以训练、无法扩展到高维图像数据的难题,给生成建模领域开辟了一条全新的道路。

核心贡献回顾

  1. 提出了Flow Matching训练框架:无仿真、无偏、可无限扩展,完美解决了传统CNF的训练痛点;
  2. 证明了CFM和FM的梯度等价性:把不可计算的全局优化问题,转化为可批量计算的逐样本优化问题,让训练变得简单可行;
  3. 给出了通用高斯路径的闭式解:把扩散模型纳入统一框架,同时提出了更高效的OT路径,实现了训练和采样的双重提速;
  4. 大量实验验证:在多个数据集和任务上,全面超越了同期的扩散模型SOTA,用更少的计算量、更少的采样步数,实现了更好的生成效果。

未来展望

FM框架打开了生成建模的全新空间,后续有大量可以探索的方向:

  1. 更丰富的概率路径:除了高斯路径,还可以探索非各向同性高斯、非高斯的概率路径,进一步提升生成效率和效果;
  2. 端到端的OT路径优化:结合可微的OT求解器,学习数据驱动的最优传输路径,而不是固定的线性路径;
  3. 和大语言模型的结合:把FM框架用到文本生成、多模态生成领域,替代现有的扩散模型和自回归模型;
  4. 更快的采样算法:针对OT路径的直线特性,设计专门的ODE求解器,实现1-2步的超快速采样。

总的来说,Flow Matching不仅是对扩散模型的一次全面升级,更是生成建模领域的一次范式革新。它让连续归一化流真正实现了工业化落地的可能,后续的Stable Diffusion 3、Midjourney等模型,也都纷纷借鉴了FM和OT路径的核心思想,它正在成为新一代生成模型的核心基础。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐