MIT 6.S184 | 流匹配与扩散模型导论 | 2026 | Course Notes | 翻译 | Chapter 7: Discrete Diffusion Models
目录
前言
MIT6.S184: Generative AI with Stochastic Differential Equations 最新 2026 年课程笔记 An Introduction to Flow Matching and Diffusion Models 翻译,本篇文章翻译第七章节 Discrete Diffusion Models: Building Language Models with Diffusion 相关内容🤗。
Course Notes:https://diffusion.csail.mit.edu/2026/docs/lecture_notes.pdf
Course Website:https://diffusion.csail.mit.edu/2026/index.html
7. Discrete Diffusion Models: Building Language Models with Diffusion
在前面的章节中,我们研究了 flow 与 diffusion models,它们作为定义在欧几里得空间 R d \mathbb{R}^d Rd 上的生成模型,能够生成表示为向量 z ∈ R d z \in \mathbb{R}^d z∈Rd 的数据点。然而,并不是所有数据都天然适合建模为欧几里得空间 R d \mathbb{R}^d Rd 中的点。许多数据类型,例如 text 或 DNA,更自然地应被视为离散状态空间 S S S 中的元素。最重要的是,language 由离散 token 序列构成,而这正是我们希望建模的对象。
那么,我们该如何将 flow 与 diffusion models 应用于这种数据类型呢?事实证明,我们在前面章节中学习的原理,同样能够扩展到这些数据类型。由此得到的模型,在机器学习文献中被称为 discrete diffusion models(离散扩散模型)[5, 16]。
不过,需要牢记的一点是:这里并不存在数学意义上的 diffusion process(因为在离散状态空间中不存在 SDEs)。因此,我们不再使用 ODEs / SDEs,而是使用:连续时间马尔可夫链(continuous-time Markov chain, CTMCs)。
在接下来的内容中,我们将介绍 CTMC models(见 Section 7.1)以及如何学习这些模型(见 Section 7.2),并展示如何利用 flow 与 diffusion models 的原理构建 large language models(LLMs)。
7.1 Continuous-Time Markov chain (CTMC) models
在本节中,我们将介绍 continuous-time Markov chains(CTMCs)。你可以将 CTMCs 看作是 SDEs 的离散版本,我们可以利用它来构建生成离散空间的神经网络模型。此外,我们还将介绍 CTMC models,即利用 CTMCs 来生成离散序列(例如文本)的神经网络模型。
Figure 17:一个 CTMC 轨迹的示意图,其状态空间为 S = { S 1 , S 2 , S 3 } S=\{S_1,S_2,S_3\} S={S1,S2,S3} ,并且序列长度 d = 1 d=1 d=1 。图片改编自文献 [5]。
现在,我们首先来刻画状态空间 S S S 。设 V = { v 1 , ⋯ , v V } \mathcal{V} = \{v_1, \cdots, v_V\} V={v1,⋯,vV} 为我们的 vocabulary(词汇表),则状态空间定义为 S = V d S = \mathcal{V}^d S=Vd ,其中 d ∈ N d \in \mathbb{N} d∈N 表示序列长度(sequence length), V ∈ N V \in \mathbb{N} V∈N 表示词汇表大小(vocabulary size)。
对于 language, { v 1 , ⋯ , v V } \{v_1, \cdots, v_V\} {v1,⋯,vV} 可以表示 alphabet 或一组离散 tokens。而 S S S 则表示所有长度为 d d d 的 sequences(或 sentences)的集合。对于 DNA, { v 1 , ⋯ , v V } \{v_1, \cdots, v_V\} {v1,⋯,vV} 则可以表示全部 4 种 DNA bases,而 S S S 表示所有长度为 d d d 的 DNA sequences。
接下来,令 X t X_t Xt 为定义在 S S S 上的随机过程,即一个随机轨迹: X : [ 0 , 1 ] → S , t ↦ X t ∈ S X : [0,1] \to S,\, t \mapsto X_t \in S X:[0,1]→S,t↦Xt∈S 。我们要求 X t X_t Xt 是一个马尔可夫过程,即一个 “无记忆(no memory)” 的过程。具体而言,这意味着下面的条件成立:
p ( X t + h ∣ X t , X t 1 , ⋯ , X t k ) ⏟ prob. of future given present and past = p ( X t + h ∣ X t ) ⏟ prob. of future given present ( for all 0 < h , 0 ≤ t 1 < t 2 < ⋯ < t k < t ) \underbrace{ p(X_{t+h}\mid X_t,X_{t_1},\cdots,X_{t_k}) }_{\text{prob.\ of future given present and past}} \qquad = \qquad \underbrace{ p(X_{t+h}\mid X_t) }_{\text{prob.\ of future given present}} \qquad (\text{for all } 0<h,\ 0\le t_1<t_2<\cdots<t_k<t) prob. of future given present and past p(Xt+h∣Xt,Xt1,⋯,Xtk)=prob. of future given present p(Xt+h∣Xt)(for all 0<h, 0≤t1<t2<⋯<tk<t)
换句话说,未来事件的概率只依赖于当前状态,过去对于未来已经不再重要。注意,虽然 ODEs/SDEs 并不是定义在离散状态空间上,但它们同样也是马尔可夫过程。
这里,由于 X t X_t Xt 定义在离散空间上,因此它被称为 Markov chain(马尔可夫链),更具体地说,是 Continuous-time Markov chain(CTMC)。其中 quantity p t + h ∣ t ( X t + h ∣ X t ) p_{t+h\mid t}(X_{t+h}\mid X_t) pt+h∣t(Xt+h∣Xt) 被称为 transition probabilities(转移概率)。它们与马尔可夫链的初始分布 X 0 ∼ p 0 X_0 \sim p_0 X0∼p0 一起完全决定了整个 CTMC。因此,当我们说 CTMC 时,你也可以简单理解为转移概率 p t + h ∣ t ( X t + h ∣ X t ) p_{t+h\mid t}(X_{t+h}\mid X_t) pt+h∣t(Xt+h∣Xt) 。
接下来,我们来推导离散场景下 vector field(向量场)的对应形式。由于现在处于离散场景中,我们只能在 states(状态)之间进行 jump(或 switch),而无法像 ODE 中那样沿着某个连续方向移动。因此,我们定义一个 rate matrix(速率矩阵) Q t ( y ∣ x ) Q_t(y\mid x) Qt(y∣x) ,它用于刻画从状态 x ∈ S x \in S x∈S 跳转到状态 y ∈ S y \in S y∈S 的速率。
形式化地,rate matrix Q t Q_t Qt 是如下有界函数(关于时间连续):
Q : S × S × [ 0 , 1 ] → R , ( x , y , t ) ↦ Q t ( y ∣ x ) (84) Q : S \times S \times [0,1] \to \mathbb{R}, \quad (x,y,t) \mapsto Q_t(y\mid x) \tag{84} Q:S×S×[0,1]→R,(x,y,t)↦Qt(y∣x)(84)
其中, Q t ( y ∣ x ) Q_t(y\mid x) Qt(y∣x) 描述从 x x x 切换到 y y y 的速率,并满足:
( 1 ) Outgoing rates are positives: Q t ( y ∣ x ) ≥ 0 whenever x ≠ y ( 2 ) Rate staying equals negative outgoing rate: Q t ( x ∣ x ) = − ∑ y ≠ x Q t ( y ∣ x ) for all x \begin{align} (1)\ \text{Outgoing rates are positives:}\quad Q_t(y\mid x) &\ge 0 \qquad \text{whenever } x\ne y \tag{85} \\[14pt] (2)\ \text{Rate staying equals negative outgoing rate:}\quad Q_t(x\mid x) &= -\sum_{y\ne x}Q_t(y\mid x) \qquad \text{for all } x \tag{86} \end{align} (1) Outgoing rates are positives:Qt(y∣x)(2) Rate staying equals negative outgoing rate:Qt(x∣x)≥0whenever x=y=−y=x∑Qt(y∣x)for all x(85)(86)
这两个条件是直观的:第一个条件表示从 x x x 切换到不同状态 y ≠ x y \ne x y=x 的速率只能是非负的(不发生切换对应于 0,因此速率小于 0 是没有意义的)。第二个条件表示停留在 x x x 的速率 Q t ( x ∣ x ) Q_t(x\mid x) Qt(x∣x) 必须与离开 x x x 的总速率相抵消。它本质上是一个一致性条件,表示你要么停留在 x x x ,要么离开它,不存在第三种情况。
注意,这些条件特别意味着 Q t ( x ∣ x ) ≤ 0 Q_t(x\mid x) \le 0 Qt(x∣x)≤0 。因此, Q t ( y ∣ x ) Q_t(y\mid x) Qt(y∣x) 可以看作一个矩阵,其中对角线元素全部非正,非对角线元素全部非负。
现在,我们可以定义微分方程在离散场景中的对应形式,即要求一个 CTMC “follow” 某个 rate matrix 的条件。其核心思想是 X X X 的分布或过程应当遵循 rate matrix Q t Q_t Qt 。换句话说,我们要求转移概率满足:
d d h p t + h ∣ t ( X t + h = y ∣ X t = x ) ∣ h = 0 = Q t ( y ∣ x ) for all x , y ∈ S , 0 ≤ t (87) \left. \frac{d}{dh} p_{t+h\mid t}(X_{t+h}=y \mid X_t=x) \right|_{h=0} = Q_t(y\mid x) \quad \text{for all } x,y\in S,\; 0\le t \tag{87} dhdpt+h∣t(Xt+h=y∣Xt=x) h=0=Qt(y∣x)for all x,y∈S,0≤t(87)
左边表示从 x x x 切换到 y y y 的概率的无穷小的变化率。我们要求这些概率按照 rate matrix 所指定的方式变化。
现在,我们简单检查一下这样的条件是否合理。也就是说,如果我们像 公式 (87) 中那样定义 Q t ( y ∣ x ) Q_t(y\mid x) Qt(y∣x) ,它是否真的是一个合法的 rate matrix?
当 h = 0 h=0 h=0 时,由于时间尚未流逝,从 x x x 切换到 y ≠ x y\ne x y=x 的概率为 0,即 p t ∣ t ( y ∣ x ) = 0 for all y ≠ x p_{t\mid t}(y\mid x)=0\, \text{for all } y\ne x pt∣t(y∣x)=0for all y=x 。因此,我们知道其导数必须是非负的,从而 Q t ( y ∣ x ) ≥ 0 whenever y ≠ x Q_t(y\mid x)\ge0 \, \text{whenever } y\ne x Qt(y∣x)≥0whenever y=x ,这验证了 公式 (85) 中的第一个条件。
进一步地,我们有:
∑ y ≠ x Q t ( y ∣ x ) = ∑ y ≠ x d d h p ( X t + h = y ∣ X t = x ) ∣ h = 0 = d d h ∑ y ≠ x p ( X t + h = y ∣ X t = x ) ∣ h = 0 = d d h ( 1 − p ( X t + h = x ∣ X t = x ) ) = − Q t ( x ∣ x ) \begin{align*} \sum_{y\ne x} Q_t(y\mid x) = \sum_{y\ne x} \left. \frac{d}{dh} p(X_{t+h}=y\mid X_t=x) \right|_{h=0} = \left. \frac{d}{dh} \sum_{y\ne x} p(X_{t+h}=y\mid X_t=x) \right|_{h=0} &= \frac{d}{dh} \bigl( 1-p(X_{t+h}=x\mid X_t=x) \bigr) \\[14pt] &= - Q_t(x\mid x) \end{align*} y=x∑Qt(y∣x)=y=x∑dhdp(Xt+h=y∣Xt=x) h=0=dhdy=x∑p(Xt+h=y∣Xt=x) h=0=dhd(1−p(Xt+h=x∣Xt=x))=−Qt(x∣x)
这里我们使用了概率之和等于 1。这说明 公式 (86) 成立。因此,我们验证了每一个 CTMC 至少都存在一个满足 公式 (87) 的 rate matrix。但如果反过来呢?也就是说如果我们给定一个 Q t Q_t Qt ,是否一定存在对应的 CTMC?如果存在,它是否唯一?答案是:确实如此。
Theorem 33 (CTMC existence and uniqueness)
对于任意的 rate matrix Q t Q_t Qt(关于时间 t t t 有界且连续),都存在唯一的马尔可夫链 X t X_t Xt(即唯一的一组转移概率 p t + h ∣ t ( y ∣ x ) p_{t+h\mid t}(y\mid x) pt+h∣t(y∣x))使得 公式 (87) 成立。
对于感兴趣的读者,我们在 Section C 中提供了一个自洽的证明。这个定理最重要的结论是,对于机器学习的目的,我们可以直接构造一个 rate matrix Q t Q_t Qt(例如通过神经网络),并假设存在唯一一个与 Q t Q_t Qt 对应的马尔可夫链。
Example 34 (Two-state CTMC with equal jump rates)
设 S = { a , b } S=\{a,b\} S={a,b} ,并考虑一个时间齐次 CTMC ( X t ) t ≥ 0 (X_t)_{t\ge0} (Xt)t≥0 ,它以常数速率 λ > 0 \lambda>0 λ>0 在两个状态之间切换:
Q = a b a − λ λ b λ − λ . Q= \begin{array}{c|cc} & a & b\\ \hline a & -\lambda & \lambda\\ b & \lambda & -\lambda \end{array}. Q=aba−λλbλ−λ.
则,在时间增量 h ≥ 0 h\ge0 h≥0 上的转移概率同样与时间 t t t 无关,并由下式给出:
( p ( X t + h = a ∣ X t = a ) p ( X t + h = a ∣ X t = b ) p ( X t + h = b ∣ X t = a ) p ( X t + h = b ∣ X t = b ) ) = 1 2 ( 1 + e − 2 λ h 1 − e − 2 λ h 1 − e − 2 λ h 1 + e − 2 λ h ) . \begin{pmatrix} p(X_{t+h}=a\mid X_t=a) & p(X_{t+h}=a\mid X_t=b) \\ p(X_{t+h}=b\mid X_t=a) & p(X_{t+h}=b\mid X_t=b) \end{pmatrix} = \frac12 \begin{pmatrix} 1+e^{-2\lambda h} & 1-e^{-2\lambda h} \\ 1-e^{-2\lambda h} & 1+e^{-2\lambda h} \end{pmatrix}. (p(Xt+h=a∣Xt=a)p(Xt+h=b∣Xt=a)p(Xt+h=a∣Xt=b)p(Xt+h=b∣Xt=b))=21(1+e−2λh1−e−2λh1−e−2λh1+e−2λh).
可以手动验证 公式 (87) 成立,即这些转移概率的确对应于该 rate matrix。实际上,这些速率是非常直观的:该 chain 会以瞬时速率 λ \lambda λ 不断在两个状态之间翻转。指数项 e − 2 λ h e^{-2\lambda h} e−2λh 描述了初始状态记忆的衰减过程。随着时间趋于无穷即 h → ∞ h\to\infty h→∞ ,有:
P ( h ) → ( 1 2 1 2 1 2 1 2 ) , P(h) \to \begin{pmatrix} \frac12 & \frac12\\[4pt] \frac12 & \frac12 \end{pmatrix}, P(h)→(21212121),
因此,该 chain 最终会忘记它最初从哪里开始,并且以概率 1 2 \frac12 21 处于 a a a 或 b b b 。并且,切换速率 λ > 0 \lambda>0 λ>0 越大,这种收敛发生得越快。
Simulation of CTMC.
接下来,我们考虑如何模拟一个 CTMC 的轨迹。设 h > 0 h>0 h>0 为步长, p i n i t p_{\mathrm{init}} pinit 为定义在 S S S 上的初始分布。例如 p i n i t = U n i f S p_{\mathrm{init}}=\mathrm{Unif}_S pinit=UnifS 表示 S S S 上的均匀分布。
随后,我们可以通过如下方式迭代地进行模拟:首先采样 X 0 ∼ p i n i t X_0\sim p_{\mathrm{init}} X0∼pinit ,然后令:
X t + h ∼ p t + h ∣ t ( ⋅ ∣ X t ) . X_{t+h}\sim p_{t+h\mid t}(\cdot\mid X_t). Xt+h∼pt+h∣t(⋅∣Xt).
现在,如果我们知道 p t + h ∣ t ( ⋅ ∣ X t ) p_{t+h\mid t}(\cdot\mid X_t) pt+h∣t(⋅∣Xt) ,那么上述过程当然是可行的。然而,除了最简单的 CTMC 之外,我们通常并不知道封闭式的转移核,而只能访问 rate matrix Q t Q_t Qt 。不过,根据 公式 (87):
p t + h ∣ t ( X t + h = y ∣ X t = x ) = p t ∣ t ( X t = y ∣ X t = x ) + h Q t ( y ∣ x ) + R t ( h ) = 1 y = x + h Q t ( y ∣ x ) + R t ( h ) p_{t+h\mid t}(X_{t+h}=y\mid X_t=x) = p_{t\mid t}(X_t=y\mid X_t=x) + hQ_t(y\mid x) + R_t(h) = 1_{y=x} + hQ_t(y\mid x) + R_t(h) pt+h∣t(Xt+h=y∣Xt=x)=pt∣t(Xt=y∣Xt=x)+hQt(y∣x)+Rt(h)=1y=x+hQt(y∣x)+Rt(h)
其中 R t ( h ) R_t(h) Rt(h) 是一个误差项,当 h h h 足够小时可以忽略。因此,对于较小的 h h h ,我们可以令:
p t + h ∣ t ( X t + h = y ∣ X t = x ) ≈ 1 y = x + h Q t ( y ∣ x ) = : p ~ t + h ∣ t ( y ∣ x ) p_{t+h\mid t}(X_{t+h}=y\mid X_t=x) \approx 1_{y=x}+hQ_t(y\mid x) =: \tilde p_{t+h\mid t}(y\mid x) pt+h∣t(Xt+h=y∣Xt=x)≈1y=x+hQt(y∣x)=:p~t+h∣t(y∣x)
可以验证:由于我们对 rate matrix 施加的条件,当 h h h 足够小时, p ~ t + h ∣ t ( y ∣ x ) \tilde p_{t+h\mid t}(y\mid x) p~t+h∣t(y∣x) 确实是一个合法的概率分布。因此,我们可以近似地通过如下方式采样下一个点:
X t + h ∼ p ~ t + h ∣ t ( ⋅ ∣ x ) = ( 1 y = x + h Q t ( y ∣ x ) ) y ∈ S (88) X_{t+h} \sim \tilde p_{t+h\mid t}(\cdot\mid x) = \bigl( 1_{y=x}+hQ_t(y\mid x) \bigr)_{y\in S} \tag{88} Xt+h∼p~t+h∣t(⋅∣x)=(1y=x+hQt(y∣x))y∈S(88)
由于上述只是一个离散分布,因此我们可以使用标准方法轻松进行采样。这提供了一种简单的 CTMC 模拟方法。
CTMC model.
接下来,我们定义如何用神经网络来参数化一个 CTMC。一个 CTMC model(或 discrete diffusion model)由以下部分给出: S S S 上的初始分布 p i n i t p_{\mathrm{init}} pinit ,一个带参数 θ \theta θ 的神经网络 Q t θ Q_t^\theta Qtθ ,使得对于每个输入 x ∈ S x\in S x∈S 模型返回 rate matrix 的单独一列:
x ↦ { Q t θ ( y ∣ x ) } y ∈ S x \mapsto \{Q_t^\theta(y|x)\}_{y\in S} x↦{Qtθ(y∣x)}y∈S
我们希望模型返回完整的一列,因为在 CTMC 的模拟中需要它(公式 (88)),即采样下一个状态。
上述模型的一个复杂之处在于空间 S S S 可能非常大。特别地 ∣ S ∣ = V d |S|=V^d ∣S∣=Vd ,其中 V V V 是词汇表大小, d d d 是序列长度。这种指数增长使得在内存中存储 rate matrix 的完整一列基本不可能,也就是说 { Q t θ ( y ∣ x ) } y ∈ S \{Q_t^\theta(y|x)\}_{y\in S} {Qtθ(y∣x)}y∈S 不可能在计算机中表示。
因此,我们必须对模型施加约束。具体来说,几乎所有 CTMC models 都是 factorized(因式化)的(见 Figure 18),这实际上是一种稀疏性约束。具体而言,一个 factorized CTMC model 由一个 CTMC model Q t θ Q_t^\theta Qtθ 给出,使得对于所有 y = ( y 1 , ⋯ , y d ) , x = ( x 1 , ⋯ , x d ) ∈ S = V d y=(y_1,\cdots,y_d), \, x=(x_1,\cdots,x_d) \in S=\mathcal V^d y=(y1,⋯,yd),x=(x1,⋯,xd)∈S=Vd 都有:
Q t θ ( y ∣ x ) = 0 whenever y i ≠ x i for more than one position i Q_t^\theta(y\mid x) = 0 \qquad \text{whenever } y_i\ne x_i \text{ for more than one position } i Qtθ(y∣x)=0whenever yi=xi for more than one position i
我们将所有与 x x x 最多只相差一个 token 的 y y y 称为 x x x 的邻居 N ( x ) N(x) N(x) 。我们可以将这样的 factorized CTMC model 写为:
x ↦ { Q t θ ( y ∣ x ) } y ∈ N ( x ) = ( Q t θ ( v 1 , 1 ∣ x ) ⋯ Q t θ ( v V , 1 ∣ x ) ⋮ ⋮ Q t θ ( v 1 , d ∣ x ) ⋯ Q t θ ( v V , d ∣ x ) ) x \mapsto \{Q_t^\theta(y|x)\}_{y\in N(x)} = \begin{pmatrix} Q_t^\theta(v_1,1|x) & \cdots & Q_t^\theta(v_V,1|x)\\ \vdots & & \vdots\\ Q_t^\theta(v_1,d|x) & \cdots & Q_t^\theta(v_V,d|x) \end{pmatrix} x↦{Qtθ(y∣x)}y∈N(x)= Qtθ(v1,1∣x)⋮Qtθ(v1,d∣x)⋯⋯Qtθ(vV,1∣x)⋮Qtθ(vV,d∣x)
其中 Q t ( y ∣ x ) = Q t θ ( v i , j ∣ x ) Q_t(y|x)=Q_t^\theta(v_i,j|x) Qt(y∣x)=Qtθ(vi,j∣x) 现在表示从 x = ( x 1 , ⋯ , x d ) x=(x_1,\cdots,x_d) x=(x1,⋯,xd) 转移到 x x x 的某个 neighbor 的速率。该 neighbor 是通过将第 j j j 个元素替换为 v i v_i vi 得到的 y = ( x 1 , ⋯ , x j − 1 , v i , x j + 1 , ⋯ , x d ) y=(x_1,\cdots,x_{j-1},v_i,x_{j+1},\cdots,x_d) y=(x1,⋯,xj−1,vi,xj+1,⋯,xd) 。每一行对应于每个位置 i = 1 , ⋯ , d i=1,\cdots,d i=1,⋯,d 上的一个 rate matrix,即我们要求:
Q t θ ( v , i ∣ x ) ≥ 0 if v ≠ x i , Q t ( x i , i ∣ x ) = − ∑ v ≠ x i Q t θ ( v , i ∣ x ) Q_t^\theta(v,i|x)\ge0 \quad \text{if } v\ne x_i, \quad Q_t(x_i,i|x) = - \sum_{v\ne x_i} Q_t^\theta(v,i|x) Qtθ(v,i∣x)≥0if v=xi,Qt(xi,i∣x)=−v=xi∑Qtθ(v,i∣x)
我们可以很容易地对神经网络的输出施加这些条件。例如,可以使用一个作用在序列长度 d d d 上的 transformer model,其输出维度为 V V V 。还需要注意 factorized rate matrix 使输出形状变为 d × V d\times V d×V ,该大小随着维度线性增长,而不是指数增长。
Figure 18:factorized CTMC model 的示意图。Factorized CTMC 只有在起点与终点仅相差一个维度时,其 rate 才是非零的即 Q t ( y ∣ x ) ≠ 0 Q_t(y|x)\neq0 Qt(y∣x)=0 ,这里的示例中 d = 2 d=2 d=2 。图片取自文献 [26]。
Simulating a CTMC model.
为了从一个 CTMC model 中采样,我们先采样 X 0 ∼ p i n i t X_0 \sim p_{\mathrm{init}} X0∼pinit ,然后执行迭代,在每一步根据 公式 (88) 对下一个状态进行采样。我们在 Algorithm 7 中给出了该算法。
如上所示,对于 factorized CTMC models,可以使用一种并行的 per-token 欧拉近似,其中每个 token 都会在一个较小步长 h > 0 h>0 h>0 下被独立更新。这种近似在关于 h h h 的一阶项上与完整的 CTMC 欧拉步一致,但它允许出现一个 O ( h 2 ) O(h^2) O(h2) 概率的事件,即多个 token 被同时更新。

7.2 Training CTMC models
接下来,我们讨论如何学习 CTMC models。其核心原则与 flow matching 相同:
- (1). 我们构造一条在 noise 与 data 之间进行插值的概率路径。
- (2). 我们推导条件 rate matrix 与边缘 rate matrix。
- (3). 我们以 simulation-free 的方式学习边缘 rate matrix。
下面我们将逐步解释这一训练流程。
在本节中,数据分布 p d a t a p_{\mathrm{data}} pdata 是在状态空间 S S S 上的一个分布,并由概率质量函数表示。即: p d a t a : S → R ≥ 0 , z ↦ p d a t a ( z ) p_{\mathrm{data}}: S \to \mathbb{R}_{\ge 0},\, z\mapsto p_{\mathrm{data}}(z) pdata:S→R≥0,z↦pdata(z) 满足 ∑ z ∈ S p d a t a ( z ) = 1 \sum_{z\in S} p_{\mathrm{data}}(z)=1 ∑z∈Spdata(z)=1 。我们并不知道 p d a t a p_{\mathrm{data}} pdata 的具体形式,但在训练过程中可以访问其样本 z ∼ p d a t a z\sim p_{\mathrm{data}} z∼pdata ,这些样本以数据集的形式给出。例如,互联网上的所有文本数据。我们的目标是学习生成样本 z ∼ p d a t a z\sim p_{\mathrm{data}} z∼pdata 。我们的目标是训练 CTMC model Q t θ Q_t^\theta Qtθ 使得:
X 0 ∼ p i n i t , X t CTMC of Q t θ ⇒ X 1 ∼ p d a t a X_0\sim p_{\mathrm{init}},\qquad X_t\ \text{CTMC of }Q_t^\theta \quad\Rightarrow\quad X_1\sim p_{\mathrm{data}} X0∼pinit,Xt CTMC of Qtθ⇒X1∼pdata
因此你会发现,这与欧氏空间 R d \mathbb{R}^d Rd 中的情形(见 Sections 2 与 3)并没有本质区别,只不过这里我们使用的是 CTMC model,而不是 flow/diffusion model。
7.2.1 Conditional and Marginal Probability Path
我们定义 δ z ( x ) \delta_z(x) δz(x) 为如下函数 δ z ( x ) = 0 if x ≠ z \delta_z(x)=0 \quad \text{if } x\neq z δz(x)=0if x=z 以及 δ z ( x ) = 1 if x = z \delta_z(x)=1 \quad \text{if } x=z δz(x)=1if x=z 。一个(离散的)条件概率路径 由一组分布 p t ( x ∣ z ) p_t(x|z) pt(x∣z) 给出,其中 x , z ∈ S , 0 ≤ t ≤ 1 x,z\in S,\quad 0\le t\le1 x,z∈S,0≤t≤1 并满足:
p 0 ( ⋅ ∣ z ) = p i n i t , p 1 ( ⋅ ∣ z ) = δ z p_0(\cdot|z)=p_{\mathrm{init}}, \quad p_1(\cdot|z)=\delta_z p0(⋅∣z)=pinit,p1(⋅∣z)=δz
因此,与欧氏空间中的情形类似,一个离散条件概率路径会在一个与 z z z 无关的分布和一个所有概率质量都集中在 z z z 上的分布之间进行插值。随后,(离散的)边缘概率路径定义为:
p t ( x ) = ∑ z ∈ S p t ( x ∣ z ) p d a t a ( z ) p_t(x)=\sum_{z\in S} p_t(x|z)p_{\mathrm{data}}(z) pt(x)=z∈S∑pt(x∣z)pdata(z)
可以很容易验证,边缘概率路径在 “noise” 与 data 之间进行插值:
p 0 = p i n i t , p 1 = p d a t a (89) p_0=p_{\mathrm{init}}, \quad p_1=p_{\mathrm{data}} \tag{89} p0=pinit,p1=pdata(89)
Example 35 (Factorized mixture path (independent noising per token))
设 S = V d S=\mathcal{V}^d S=Vd ,并令 p i n i t ( x ) = ∏ j = 1 d p i n i t ( j ) ( x j ) p_{\mathrm{init}}(x)=\prod_{j=1}^{d} p_{\mathrm{init}}^{(j)}(x_j) pinit(x)=∏j=1dpinit(j)(xj) 为一个因子化初始分布。固定一个 scheduler 0 ≤ κ t ≤ 1 0\le\kappa_t\le1 0≤κt≤1 满足 κ 0 = 0 , κ 1 = 1 , d d t κ t ≥ 0 \kappa_0=0,\, \kappa_1=1,\, \frac{d}{dt}\kappa_t\ge0 κ0=0,κ1=1,dtdκt≥0 。定义条件路径为:
p t ( x ∣ z ) = ∏ j = 1 d [ ( 1 − κ t ) p i n i t ( j ) ( x j ) + κ t δ z j ( x j ) ] . p_t(x|z) = \prod_{j=1}^{d} \Big[ (1-\kappa_t)p_{\mathrm{init}}^{(j)}(x_j) + \kappa_t\delta_{z_j}(x_j) \Big]. pt(x∣z)=j=1∏d[(1−κt)pinit(j)(xj)+κtδzj(xj)].
等价地,我们可以通过如下方式采样 x ∼ p t ( ⋅ ∣ z ) x\sim p_t(\cdot|z) x∼pt(⋅∣z) :首先采样 i.i.d. masks m j ∈ { 0 , 1 } m_j\in\{0,1\} mj∈{0,1} 以及 noise ξ j ∼ p i n i t ( j ) \xi_j\sim p_{\mathrm{init}}^{(j)} ξj∼pinit(j) ,然后设定:
m j ∼ B e r n o u l l i ( κ t ) , ξ j ∼ p i n i t ( j ) x j = m j z j + ( 1 − m j ) ξ j , j = 1 , … , d x = ( x 1 , … , x d ) . \begin{align*} m_j &\sim \mathrm{Bernoulli}(\kappa_t), \qquad \xi_j\sim p_{\mathrm{init}}^{(j)} \\[8pt] x_j &= m_j z_j + (1-m_j)\xi_j, \qquad j=1,\ldots,d \\[8pt] x &= (x_1,\ldots,x_d). \end{align*} mjxjx∼Bernoulli(κt),ξj∼pinit(j)=mjzj+(1−mj)ξj,j=1,…,d=(x1,…,xd).
我们将上述路径称为 factorized mixture path(因子化混合路径)。上述过程实际上会以概率 1 − κ t 1-\kappa_t 1−κt 独立地 “破坏” 序列中每一个位置的第 j j j 个 token。即当 t = 0 t=0 t=0 时, 1 − κ t = 1 1-\kappa_t=1 1−κt=1 ,所有信息都被破坏;当 t = 1 t=1 t=1 时, 1 − κ t = 0 1-\kappa_t=0 1−κt=0 ,没有任何信息被破坏。
注意,这与高斯概率路径(Example 8)是相似的,因为信息都会按照 scheduler κ t \kappa_t κt 所决定的速度逐渐被破坏。然而,它与高斯概率路径也存在不同:因子化混合路径并不会移动 / 运输概率质量(因为在离散空间中不存在方向),它只是从一个分布逐渐淡出,再逐渐淡入另一个分布。
Figure 19:当 d = 2 d=2 d=2 时,一个离散概率路径的示意图。第一行:条件概率路径在初始分布与狄拉克分布之间进行插值。第二行:初始分布与数据分布(这里为棋盘格模式)之间的插值。注意它与 Figure 5 的相似性与不同点。这里,概率路径是被 “teleported” 的(我们减小初始分布的权重,并增大 terminal distribution 的权重)。
7.2.2 Conditional and Marginal Rate Matrix
作为下一步,我们现在构造离散 flow matching 的训练目标。首先,我们构造一个条件 rate matrix — 它对应于 flow matching 中的条件向量场。对于每一个数据点 z ∈ S z\in S z∈S ,令 Q t z ( y ∣ x ) Q_t^z(y|x) Qtz(y∣x) 为一个 rate matrix。如果满足:
X 0 ∼ p i n i t , X t CTMC of Q t z ⇒ X t ∼ p t ( ⋅ ∣ z ) X_0\sim p_{\mathrm{init}}, \qquad X_t\ \text{CTMC of }Q_t^z \quad\Rightarrow\quad X_t\sim p_t(\cdot|z) X0∼pinit,Xt CTMC of Qtz⇒Xt∼pt(⋅∣z)
则我们称其为 conditional rate matrix。
换句话说,条件 rate matrix 所对应的 CTMC 会 “follow” 条件概率路径。条件 rate matrix 作为一个基础构件,用于构造遵循边缘概率路径的边缘 rate matrix:
Theorem 36 (Discrete marginalization trick)
定义 marginal rate matrix 为:
Q t ( y ∣ x ) = ∑ z ∈ S Q t z ( y ∣ x ) p t ( x ∣ z ) p d a t a ( z ) p t ( x ) = ∑ z ∈ S Q t z ( y ∣ x ) p 1 ∣ t ( z ∣ x ) where p 1 ∣ t ( z ∣ x ) : = p t ( x ∣ z ) p d a t a ( z ) p t ( x ) (90) Q_t(y|x) = \sum_{z\in S} Q_t^z(y|x) \frac{p_t(x|z)p_{\mathrm{data}}(z)}{p_t(x)} = \sum_{z\in S} Q_t^z(y|x)p_{1|t}(z|x) \quad \text{where } p_{1|t}(z|x) := \frac{p_t(x|z)p_{\mathrm{data}}(z)}{p_t(x)} \tag{90} Qt(y∣x)=z∈S∑Qtz(y∣x)pt(x)pt(x∣z)pdata(z)=z∈S∑Qtz(y∣x)p1∣t(z∣x)where p1∣t(z∣x):=pt(x)pt(x∣z)pdata(z)(90)
则它是一个合法的 rate matrix,并满足如下条件:
X 0 ∼ p i n i t , X t CTMC of Q t ⇒ X t ∼ p t X_0\sim p_{\mathrm{init}}, \qquad X_t\ \text{CTMC of }Q_t \quad\Rightarrow\quad X_t\sim p_t X0∼pinit,Xt CTMC of Qt⇒Xt∼pt
特别地,由 公式 (89) 可知 X 1 ∼ p d a t a X_1\sim p_{\mathrm{data}} X1∼pdata ,即边缘 rate matrix 对应的 CTMC 会将 noise 转换为 data。
为了证明这一结论,我们需要一个关于 CTMC 的基础方程,即所谓的 Kolmogorov Forward equation(科尔莫戈罗夫正方程):
Proposition 2 (Kolmogorov Forward Equation)
设 p t p_t pt 是在 0 ≤ t ≤ 1 0\le t\le1 0≤t≤1 上定义于状态空间 S S S 的一组分布。进一步地,设 X t X_t Xt 是一个具有矩阵 Q t Q_t Qt 以及初始分布 p 0 p_0 p0 的 CTMC。那么, X t ∼ p t for all 0 ≤ t ≤ 1 X_t\sim p_t \,\text{for all }0\le t\le1 Xt∼ptfor all 0≤t≤1 当且仅当 Kolmogorov Forward Equation(KFE) 成立:
d d t p t ( x ) = ∑ y ∈ S Q t ( x ∣ y ) p t ( y ) \frac{d}{dt}p_t(x) = \sum_{y\in S} Q_t(x|y)p_t(y) dtdpt(x)=y∈S∑Qt(x∣y)pt(y)
Proof of KFE. 为了证明 KFE 是必要条件,假设 p t ( x ) p_t(x) pt(x) 是 CTMC 的真实 marginals,即 X t ∼ p t for every 0 ≤ t ≤ 1 X_t\sim p_t \,\text{for every }0\le t\le1 Xt∼ptfor every 0≤t≤1 。那么我们可以计算:
d d t p t ( x ) = ( i ) d d h ∣ h = 0 p t + h ( x ) = ( i i ) d d h ∣ h = 0 ∑ y p t + h ∣ t ( x ∣ y ) p t ( y ) = ( i i i ) ∑ y d d h ∣ h = 0 p t + h ∣ t ( x ∣ y ) p t ( y ) = ( i v ) ∑ y Q t ( x ∣ y ) p t ( y ) \begin{align*} \frac{d}{dt}p_t(x) &\overset{(i)}{=} \frac{d}{dh}\Big|_{h=0} p_{t+h}(x) \\[8pt] &\overset{(ii)}{=} \frac{d}{dh}\Big|_{h=0} \sum_y p_{t+h|t}(x|y)p_t(y) \\[8pt] &\overset{(iii)}{=} \sum_y \frac{d}{dh}\Big|_{h=0} p_{t+h|t}(x|y)p_t(y) \\[8pt] &\overset{(iv)}{=} \sum_y Q_t(x|y)p_t(y) \end{align*} dtdpt(x)=(i)dhd h=0pt+h(x)=(ii)dhd h=0y∑pt+h∣t(x∣y)pt(y)=(iii)y∑dhd h=0pt+h∣t(x∣y)pt(y)=(iv)y∑Qt(x∣y)pt(y)
其中:
- 在 ( i ) (i) (i) 中,我们只是使用了一个时间偏移;
- 在 ( i i ) (ii) (ii) 中,我们使用了转移概率的定义;
- 在 ( i i i ) (iii) (iii) 中,我们交换了求和与求导;
- 在 ( i v ) (iv) (iv) 中,我们使用了 rate matrix 的定义(见 公式 (87))。
接下来,为了证明 KFE 是充分条件,我们可以将 KFE 改写为矩阵形式:
d d t p t = Q t p t \frac{d}{dt}p_t = Q_t p_t dtdpt=Qtpt
其中,在这个方程中,我们将 p t = ( p t ( x ) ) x ∈ S p_t=(p_t(x))_{x\in S} pt=(pt(x))x∈S 视为一个向量,并将 Q t = ( Q t ( y ∣ x ) ) x , y ∈ S Q_t=(Q_t(y|x))_{x,y\in S} Qt=(Qt(y∣x))x,y∈S 视为一个矩阵。注意,上式是在向量空间 R S \mathbb{R}^S RS 上的一个线性 ODE。其初始条件由定理中的 p 0 p_0 p0 给定。
因此,如果任何其他一组 marginals q t q_t qt 也满足该方程,那么根据 ODE 的唯一性(见 Theorem 3),我们可以得出 q t = p t q_t=p_t qt=pt 。这表明 KFE 同样也是充分条件。
Proof of Theorem 36. 使用 KFE 后,剩下只需要证明定理中定义的 marginal rate matrix(见 公式 (90))满足 KFE:
d d t p t ( x ) = ( i ) d d t ∑ z ∈ S p t ( x ∣ z ) p d a t a ( z ) = ( i i ) ∑ z ∈ S d d t p t ( x ∣ z ) p d a t a ( z ) = ( i i i ) ∑ z ∈ S [ ∑ y ∈ S Q t z ( x ∣ y ) p t ( y ∣ z ) ] p d a t a ( z ) = ( i v ) ∑ y ∈ S p t ( y ) [ ∑ z ∈ S Q t z ( x ∣ y ) p t ( y ∣ z ) p d a t a ( z ) p t ( y ) ] = ( v ) ∑ y ∈ S p t ( y ) Q t ( x ∣ y ) \begin{align*} \frac{d}{dt}p_t(x) &\overset{(i)}{=} \frac{d}{dt} \sum_{z\in S} p_t(x|z)p_{\mathrm{data}}(z) \\[8pt] &\overset{(ii)}{=} \sum_{z\in S} \frac{d}{dt} p_t(x|z)p_{\mathrm{data}}(z) \\[8pt] &\overset{(iii)}{=} \sum_{z\in S} \left[ \sum_{y\in S} Q_t^z(x|y)p_t(y|z) \right] p_{\mathrm{data}}(z) \\[8pt] &\overset{(iv)}{=} \sum_{y\in S} p_t(y) \left[ \sum_{z\in S} Q_t^z(x|y) \frac{ p_t(y|z)p_{\mathrm{data}}(z) }{ p_t(y) } \right] \\[8pt] &\overset{(v)}{=} \sum_{y\in S} p_t(y)Q_t(x|y) \end{align*} dtdpt(x)=(i)dtdz∈S∑pt(x∣z)pdata(z)=(ii)z∈S∑dtdpt(x∣z)pdata(z)=(iii)z∈S∑ y∈S∑Qtz(x∣y)pt(y∣z) pdata(z)=(iv)y∈S∑pt(y)[z∈S∑Qtz(x∣y)pt(y)pt(y∣z)pdata(z)]=(v)y∈S∑pt(y)Qt(x∣y)
其中:
- 在 ( i ) (i) (i) 中,使用了边缘概率路径的定义;
- 在 ( i i ) (ii) (ii) 中,交换了求和与求导;
- 在 ( i i i ) (iii) (iii) 中,对条件 rate matrix 使用了 KFE;
- 在 ( i v ) (iv) (iv) 中,同时乘以并除以 p t ( y ) p_t(y) pt(y) ;
- 在 ( v ) (v) (v) 中,使用了边缘 rate matrix Q t ( y ∣ x ) Q_t(y|x) Qt(y∣x) 的定义。
这说明 KFE 成立。该结论由 Proposition 2 得出。
现在,我们来推导因子化混合路径的条件 rate matrix 的一个具体例子。
Example 37 (Conditional rate matrix for factorized mixture path)
设 d d t κ t = κ ˙ t \frac{d}{dt}\kappa_t=\dot{\kappa}_t dtdκt=κ˙t ,因子化混合路径具有如下因子化条件 rate matrix:
Q t z ( y ∣ x ) = ( Q t z ( v i , j ∣ x j ) ) v i , j Q t z ( v i , j ∣ x j ) = κ ˙ t 1 − κ t ( δ z j ( v i ) − δ x j ( v i ) ) = κ ˙ t 1 − κ t { 0 if x j = z j 1 if v i = z j , x j ≠ z j 0 if v i ≠ z j , x j ≠ z j − 1 if v i = x j , x j ≠ z j \begin{align*} Q_t^z(y|x) &= \left( Q_t^z(v_i,j|x_j) \right)_{v_i,j} \\[8pt] Q_t^z(v_i,j|x_j) &= \frac{\dot{\kappa}_t}{1-\kappa_t} \left( \delta_{z_j}(v_i)-\delta_{x_j}(v_i) \right) \\[8pt] &= \frac{\dot{\kappa}_t}{1-\kappa_t} \begin{cases} 0 & \text{if } x_j=z_j\\ 1 & \text{if } v_i=z_j,\ x_j\ne z_j\\ 0 & \text{if } v_i\ne z_j,\ x_j\ne z_j\\ -1 & \text{if } v_i=x_j,\ x_j\ne z_j \end{cases} \end{align*} Qtz(y∣x)Qtz(vi,j∣xj)=(Qtz(vi,j∣xj))vi,j=1−κtκ˙t(δzj(vi)−δxj(vi))=1−κtκ˙t⎩ ⎨ ⎧010−1if xj=zjif vi=zj, xj=zjif vi=zj, xj=zjif vi=xj, xj=zj
注意,这是一个非常简单的 rate matrix:它只允许跳转到 z j z^j zj ,也就是说,如果任意 token j j j 被更新,它必须跳转到终点数据点 z = ( z 1 , ⋯ , z d ) z=(z_1,\cdots,z_d) z=(z1,⋯,zd) 的 token value;并且只有在当前还没有到达该 token 时,它才会跳转到 z j z^j zj 。
Proof. 注意,因子化混合路径完全分解为独立分量,所提出的条件 rate matrix 也是如此。因此,不失一般性,我们可以假设 d = 1 d=1 d=1 。也就是说,我们只需要逐维进行计算。于是可以推导:
d d t p t ( x ∣ z ) = ( i ) d d t [ ( 1 − κ t ) p i n i t ( x ) + κ t δ z ( x ) ] = ( i i ) κ ˙ t δ z ( x ) − κ ˙ t p i n i t ( x ) = ( i i i ) κ ˙ t 1 − κ t ( δ z ( x ) − [ ( 1 − κ t ) p i n i t ( x ) + κ t δ z ( x ) ] ) = ( i v ) κ ˙ t 1 − κ t ( δ z ( x ) − p t ( x ∣ z ) ) = ( v ) κ ˙ t 1 − κ t δ z ( x ) ( 1 − p t ( x ∣ z ) ) + κ ˙ t 1 − κ t ( δ z ( x ) − 1 ) p t ( x ∣ z ) = ( v i ) ∑ y ≠ x κ ˙ t 1 − κ t δ z ( x ) p t ( y ∣ z ) + κ ˙ t 1 − κ t ( δ z ( x ) − 1 ) p t ( x ∣ z ) = ( v i i ) ∑ y ≠ x Q t z ( x ∣ y ) p t ( y ∣ z ) + Q t z ( x ∣ x ) p t ( x ∣ z ) = ( v i i i ) ∑ y ∈ S Q t z ( x ∣ y ) p t ( y ∣ z ) \begin{align*} \frac{d}{dt}p_t(x|z) &\overset{(i)}{=} \frac{d}{dt} \left[ (1-\kappa_t)p_{\mathrm{init}}(x) + \kappa_t\delta_z(x) \right] \\[8pt] &\overset{(ii)}{=} \dot{\kappa}_t\delta_z(x) - \dot{\kappa}_t p_{\mathrm{init}}(x) \\[8pt] &\overset{(iii)}{=} \frac{\dot{\kappa}_t}{1-\kappa_t} \left( \delta_z(x) - \left[ (1-\kappa_t)p_{\mathrm{init}}(x) + \kappa_t\delta_z(x) \right] \right) \\[8pt] &\overset{(iv)}{=} \frac{\dot{\kappa}_t}{1-\kappa_t} \left( \delta_z(x)-p_t(x|z) \right) \\[8pt] &\overset{(v)}{=} \frac{\dot{\kappa}_t}{1-\kappa_t} \delta_z(x) \left( 1-p_t(x|z) \right) + \frac{\dot{\kappa}_t}{1-\kappa_t} \left( \delta_z(x)-1 \right) p_t(x|z) \\[8pt] &\overset{(vi)}{=} \sum_{y\ne x} \frac{\dot{\kappa}_t}{1-\kappa_t} \delta_z(x)p_t(y|z) + \frac{\dot{\kappa}_t}{1-\kappa_t} \left( \delta_z(x)-1 \right) p_t(x|z) \\[8pt] &\overset{(vii)}{=} \sum_{y\ne x} Q_t^z(x|y)p_t(y|z) + Q_t^z(x|x)p_t(x|z) \\[8pt] &\overset{(viii)}{=} \sum_{y\in S} Q_t^z(x|y)p_t(y|z) \end{align*} dtdpt(x∣z)=(i)dtd[(1−κt)pinit(x)+κtδz(x)]=(ii)κ˙tδz(x)−κ˙tpinit(x)=(iii)1−κtκ˙t(δz(x)−[(1−κt)pinit(x)+κtδz(x)])=(iv)1−κtκ˙t(δz(x)−pt(x∣z))=(v)1−κtκ˙tδz(x)(1−pt(x∣z))+1−κtκ˙t(δz(x)−1)pt(x∣z)=(vi)y=x∑1−κtκ˙tδz(x)pt(y∣z)+1−κtκ˙t(δz(x)−1)pt(x∣z)=(vii)y=x∑Qtz(x∣y)pt(y∣z)+Qtz(x∣x)pt(x∣z)=(viii)y∈S∑Qtz(x∣y)pt(y∣z)
其中:
- ( i ) (i) (i) 使用了 d = 1 d=1 d=1 时因子化混合路径的定义;
- ( i i ) (ii) (ii) 通过求导并设 d d t κ t = κ ˙ t \frac{d}{dt}\kappa_t=\dot{\kappa}_t dtdκt=κ˙t 得到;
- ( i i i ) (iii) (iii) 由简单代数变形得到;
- ( i v ) (iv) (iv) 使用了因子化混合路径的定义;
- ( v ) (v) (v) 由简单代数变形得到;
- ( v i ) (vi) (vi) 使用了 ∑ y ∈ S p t ( y ∣ z ) = 1 \sum_{y\in S}p_t(y|z)=1 ∑y∈Spt(y∣z)=1 这一事实;
- ( v i i ) (vii) (vii) 使用了 rate matrix 的定义;
- ( v i i i ) (viii) (viii) 由简单代数变形得到。
上述推导说明 KFE 成立,因此命题得证。
7.2.3 Learning the Marginal Rate Matrix
在本节中,我们推导用于训练 CTMC models 的基本算法。根据 Theorem 36,训练一个 CTMC model Q t θ ( y ∣ x ) Q_t^\theta(y|x) Qtθ(y∣x) 可以通过学习 marginal rate matrix 来实现。
在本节中,我们现在只考虑因子化混合路径(见 Example 35),因为这是目前大多数 discrete diffusion / flow matching models 所使用的路径。在这种情况下,marginal rate matrix 具有非常直观的形式:
Theorem 38 (Marginalization trick for factorized mixture path)
因子化混合路径的 marginal rate matrix 是 factorized 的,并且具有如下形式:
Q t ( v i , j ∣ x ) = κ ˙ t 1 − κ t ( p 1 ∣ t ( z j = v i ∣ x ) − δ x j ( v i ) ) Q_t(v_i,j|x) = \frac{\dot{\kappa}_t}{1-\kappa_t} \left( p_{1|t}(z_j=v_i|x) - \delta_{x_j}(v_i) \right) Qt(vi,j∣x)=1−κtκ˙t(p1∣t(zj=vi∣x)−δxj(vi))
其中, p 1 ∣ t ( z j = v i ∣ x ) p_{1|t}(z_j=v_i|x) p1∣t(zj=vi∣x) 是在给定完整 noisy sequence x x x 的情况下,第 j j j 个位置(序列中的第 j j j 个 token)等于 v i v_i vi 的条件概率。
Proof. marginal rate matrix 定义为:
Q t ( y ∣ x ) = ∑ z ∈ S Q t z ( y ∣ x ) p 1 ∣ t ( z ∣ x ) (91) Q_t(y|x) = \sum_{z\in S} Q_t^z(y|x)p_{1|t}(z|x) \tag{91} Qt(y∣x)=z∈S∑Qtz(y∣x)p1∣t(z∣x)(91)
现在,当 y y y 与 x x x 不是 neighbors 时(即超过一个 token 不同),对于任意 z z z 都有 Q t z ( y ∣ x ) = 0 Q_t^z(y|x)=0 Qtz(y∣x)=0 。因此,在这种情况下也有 Q t ( y ∣ x ) = 0 Q_t(y|x)=0 Qt(y∣x)=0 。这说明 marginal rate matrix 也是因子化的。于是有:
Q t ( v i , j ∣ x ) = ∑ z ∈ S Q t z ( v i , j ∣ x ) p 1 ∣ t ( z ∣ x ) = ( i ) ∑ z ∈ S κ ˙ t 1 − κ t ( δ z j ( v i ) − δ x j ( v i ) ) p 1 ∣ t ( z ∣ x ) = ( i i ) κ ˙ t 1 − κ t ( ∑ z ∈ S δ z j ( v i ) p 1 ∣ t ( z ∣ x ) − δ x j ( v i ) ) = ( i i i ) κ ˙ t 1 − κ t ( p 1 ∣ t ( z j = v i ∣ x ) − δ x j ( v i ) ) \begin{align} Q_t(v_i,j|x) &= \sum_{z\in S} Q_t^z(v_i,j|x)p_{1|t}(z|x) \tag{92} \\[8pt] &\overset{(i)}{=} \sum_{z\in S} \frac{\dot{\kappa}_t}{1-\kappa_t} \left( \delta_{z_j}(v_i)-\delta_{x_j}(v_i) \right) p_{1|t}(z|x) \tag{93} \\[8pt] &\overset{(ii)}{=} \frac{\dot{\kappa}_t}{1-\kappa_t} \left( \sum_{z\in S} \delta_{z_j}(v_i)p_{1|t}(z|x) - \delta_{x_j}(v_i) \right) \tag{94} \\[8pt] &\overset{(iii)}{=} \frac{\dot{\kappa}_t}{1-\kappa_t} \left( p_{1|t}(z_j=v_i|x) - \delta_{x_j}(v_i) \right) \tag{95} \end{align} Qt(vi,j∣x)=z∈S∑Qtz(vi,j∣x)p1∣t(z∣x)=(i)z∈S∑1−κtκ˙t(δzj(vi)−δxj(vi))p1∣t(z∣x)=(ii)1−κtκ˙t(z∈S∑δzj(vi)p1∣t(z∣x)−δxj(vi))=(iii)1−κtκ˙t(p1∣t(zj=vi∣x)−δxj(vi))(92)(93)(94)(95)
其中:
- ( i ) (i) (i) 来自条件 rate matrix 的公式(见 Example 37);
- ( i i ) (ii) (ii) 来自 ∑ z ∈ S p 1 ∣ t ( z ∣ x ) = 1 \sum_{z\in S} p_{1|t}(z|x)=1 ∑z∈Sp1∣t(z∣x)=1 这一事实;
- ( i i i ) (iii) (iii) 来自边缘化。
证明结束。
前面的定理非常值得注意:marginal rate matrix 本质上是对概率 p 1 ∣ t ( z j = v i ∣ x ) p_{1|t}(z_j=v_i|x) p1∣t(zj=vi∣x) 的一种重新参数化。这实际上无非就是对每一个 token 位置 j = 1 , … , d j=1,\ldots,d j=1,…,d 学习一个分类器。换句话说,我们可以简单地定义一个 denoising probabilities network(去噪概率网络):
p 1 ∣ t θ : x ⏟ network input ↦ ( p 1 ∣ t θ ( z j = v i ∣ x ) ) j = 1 , … , d , v i ∈ V ⏟ network output p_{1\mid t}^\theta: \qquad \underbrace{x}_{\text{network input}} \quad\mapsto\quad \underbrace{ \left( p_{1\mid t}^\theta(z_j=v_i\mid x) \right)_{j=1,\ldots,d,\ v_i\in\mathcal{V}} }_{\text{network output}} p1∣tθ:network input x↦network output (p1∣tθ(zj=vi∣x))j=1,…,d, vi∈V
注意,网络输出的 shape 为 d × V d\times V d×V 。我们可以通过简单的 softmax layer 得到每个 token 位置上的概率。网络本身可以是一个标准 sequence-to-sequence network,例如 transformer 就可以工作(见 Section 6.1.2)。
由于这本质上只是对每个位置 j j j 进行分类,因此我们可以通过每个位置 j = 1 , … , d j=1,\ldots,d j=1,…,d 上的交叉熵损失来训练该网络,这便得到如下的 Discrete Flow Matching loss:
L D F M ( θ ) = E z ∼ p d a t a , t ∼ U n i f [ 0 , 1 ] , x ∼ p t ( ⋅ ∣ z ) [ ∑ j = 1 d − log p 1 ∣ t θ ( z j ∣ x ) ] \mathcal{L}_{\mathrm{DFM}}(\theta) = \mathbb{E}_{z\sim p_{\mathrm{data}}, t\sim \mathrm{Unif}_{[0,1]}, x\sim p_t(\cdot|z)} \left[ \sum_{j=1}^{d} -\log p_{1|t}^{\theta}(z_j|x) \right] LDFM(θ)=Ez∼pdata,t∼Unif[0,1],x∼pt(⋅∣z)[j=1∑d−logp1∣tθ(zj∣x)]
这一点非常值得注意:为了训练一个生成模型,我们真正需要做的,只是对每个位置 j j j 训练一个分类模型。就像连续 flow matching 被简化成简单回归问题(见 Section 3)一样,discrete flow matching 以及 discrete diffusion models 也被简化成了简单的分类训练。
在 Algorithm 8 中,我们总结了训练算法。训练完成后,我们可以通过 Algorithm 7 进行采样。
Example 39 (Masked Diffusion Language Model)
上述方法的一个特殊情况是 masked diffusion language models(MDLMs)。MDLM 的核心思想是:我们可以将 token 词汇表 V = { v 1 , … , v V } \mathcal{V}=\{v_1,\ldots,v_V\} V={v1,…,vV} 扩展一个新的 token [ m a s k ] [\mathrm{mask}] [mask] ,它表示该 token 缺失(或者被 mask 掉)。
具体而言,我们设置 V = { v 1 , … , v V , [ m a s k ] } \mathcal{V}=\{v_1,\ldots,v_V,[\mathrm{mask}]\} V={v1,…,vV,[mask]} 并将初始点简单设置为 [ m a s k ] d [\mathrm{mask}]^d [mask]d ,即整个序列全部由 mask token 构成。形式化地,这意味着在上述框架中设置 p i n i t = δ [ m a s k ] d p_{\mathrm{init}} = \delta_{[\mathrm{mask}]^d} pinit=δ[mask]d ,采样过程如 Figure 20 所示。

Figure 20:MDLM 轨迹示意图
至此,我们已经完成了一个完整的 CTMC 模型训练与采样 pipeline,它能够用于生成诸如文本这样的离散序列。当前最先进的 discrete diffusion models [4] 采用的正是本文中描述的方法:使用神经网络(通常是 transformers)并在 web-scale 数据上进行训练。
Remark 40 (Generator Matching)
你可能会好奇为什么 flow/diffusion models 的原理能够如此自然地推广到离散状态空间?事实证明,flow matching 的原理并不局限于 flow,甚至也不局限于 CTMC。更准确地说,这些实际上是利用马尔可夫过程构建生成模型的一般性学习原理。
这便引出了 Generator Matching framework [19],这是一个能够统一并扩展离散与连续 flow/diffusion models 的框架。generator 可以被看作是向量场 u t u_t ut 以及 rate matrix Q t Q_t Qt 的一种泛化。
马尔可夫过程与 generators 可以针对任意数据模态与状态空间进行构建。例如,你可以为 smooth manifolds 构建模型 [8, 10](例如几何数据);也可以针对混合状态空间构建模型(例如联合文本与图像生成)[6];还可以构建其它类型的马尔可夫过程,例如 jump processes [19, 7]。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)