前言

MIT6.S184: Generative AI with Stochastic Differential Equations 最新 2026 年课程笔记 An Introduction to Flow Matching and Diffusion Models 翻译,本篇文章翻译第六章节 Building Large-Scale Image or Video Generators 相关内容🤗。

Course Noteshttps://diffusion.csail.mit.edu/2026/docs/lecture_notes.pdf

Course Websitehttps://diffusion.csail.mit.edu/2026/index.html

6 Building Large-Scale Image or Video Generators

在前面的章节中我们已经学习了如何训练 flow matching model 和 diffusion model 来从分布 p d a t a ( x ∣ y ) p_{\mathrm{data}}(x|y) pdata(xy) 中进行采样。这一整套方法其实是通用的,并且可以应用于各种数据类型和实际应用。

在本节中,我们将深入研究规模图像与视频生成的实际构建方式。其中会涉及一些知名模型,例如:FLUX 2.0Stable Diffusion 3Nano BananaVEO-3 以及 Meta Movie Gen Video。最后,我们还会在 lab 中从零开始构建属于我们自己的版本。

本节大致分为以下几个部分:

  1. Neural Network Architectures(神经网络架构):首先,我们会讨论原始条件输入如何被模型使用。包括时间 t t t 和引导变量 y r a w y_{\mathrm{raw}} yraw(例如离散类别标签、原始文本)。这些输入会被转换或嵌入为模型 u t θ ( x ∣ y ) u_t^\theta(x|y) utθ(xy) 能够处理的向量形式。随后,我们会讨论一些主流架构设计,包括 U-Netdiffusion transformer

  2. Latent Space(潜空间):接下来,我们会讨论变分自编码器 Variational Autoencoders(VAE),它们能够在更低维的 latent space 中进行生成建模。这样做的好处是可以支持超高分辨率图像生成。

  3. Case Studies(案例分析):最后,我们会深入分析前面提到的两个 SOTA 模型 Stable DiffusionMeta MovieGen,目的是让大家真正理解大规模生成模型是如何构建与训练的。

6.1 Neural Network Architectures

现在让我们将注意力转向面向 flow / diffusion models 的可扩展神经网络架构设计,尤其是针对 image-like modalities(例如图像与视频)的情况。更具体地说,我们将探索带参数 θ \theta θ 的(guided)vector field u t θ ( x ∣ y ) u_t^\theta(x|y) utθ(xy) 在实际工程中是如何实现的。

请注意,该神经网络具有三个输入:数据向量 x ∈ R d x\in\mathbb{R}^d xRd ,条件变量 y ∈ Y y\in\mathcal{Y} yY ,时间变量 t ∈ [ 0 , 1 ] t\in[0,1] t[0,1] ,并输出一个向量 u t θ ( x ∣ y ) ∈ R d u_t^\theta(x|y)\in\mathbb{R}^d utθ(xy)Rd

对于低维分布,例如之前章节中的 toy distributions,我们只需要使用多层感知机(MLP)来参数化 u t θ ( x ∣ y ) u_t^\theta(x|y) utθ(xy) 即可。MLP 也就是 Fully Connected Neural Network(全连接神经网络)。在这种简单情况下,forward pass 的过程大概是将输入 x x x ,条件 y y y ,时间 t t t 拼接起来,然后输入到 MLP 中。

然而,对于复杂的高维分布,例如图像、视频、蛋白质,MLP 通常是不够的。因此,实际中通常会使用特殊的面向具体任务的架构。

在本小节剩余部分,我们主要讨论图像(以及扩展到视频)的情况。首先,我们会研究原始条件信息如何被模型使用。包括时间 t t t ,条件变量 y y y 。这些信息会被 embedding(嵌入)为模型能够处理的向量形式。随后,我们会介绍两种最主流的架构选择:U-Net [38, 17, 22, 11] 和 Diffusion Transformer(DiT)[12, 30, 28]。

6.1.1 Embedding the Conditioning Variables

Embedding Time.

对于简单的 toy models 来说,直接把时间变量 t t t 拼接到输入中通常已经足够训练出性能不错的网络。但在实际工程中,这个标量时间 t t t 通常会被 embedding 到更高维空间中。最常见的方法是 Fourier Features,这样做的好处是模型能够更好地捕获高频时间依赖 [46]

更具体地,时间 embedding 的形式为:

T i m e E m b ( t ) = 2 d [ cos ⁡ ( 2 π w 1 t ) ⋯ cos ⁡ ( 2 π w d / 2 t ) sin ⁡ ( 2 π w 1 t ) ⋯ sin ⁡ ( 2 π w d / 2 t ) ] T , (68) \mathrm{TimeEmb}(t)= \sqrt{\frac{2}{d}} \left[\cos(2\pi w_1 t) \quad \cdots \quad \cos(2\pi w_{d/2} t) \quad \sin(2\pi w_1 t) \quad \cdots \quad \sin(2\pi w_{d/2} t) \right]^T, \tag{68} TimeEmb(t)=d2 [cos(2πw1t)cos(2πwd/2t)sin(2πw1t)sin(2πwd/2t)]T,(68)

其中,频率 w i w_i wi 按照下面方式设置:

w i = w min ⁡ ( w max ⁡ w min ⁡ ) i − 1 d / 2 − 1 , i = 1 , … , d / 2 (69) w_i = w_{\min} \left( \frac{w_{\max}}{w_{\min}} \right)^{\frac{i-1}{d/2-1}}, \qquad i=1,\dots,d/2 \tag{69} wi=wmin(wminwmax)d/21i1,i=1,,d/2(69)

这个 T i m e E m b \mathrm{TimeEmb} TimeEmb 的选择虽然是标准做法,但不是必须严格采用的唯一形式。它本质上只是一种方便的方式,用于构造维度为 d 的 normalized embedding。因为,对于任意频率 sin ⁡ 2 ( ⋅ ) + cos ⁡ 2 ( ⋅ ) = 1 \sin^2(\cdot)+\cos^2(\cdot)=1 sin2()+cos2()=1 ,因此整个 embedding 满足 ∥ T i m e E m b ( t ) ∥ = 1 \|\mathrm{TimeEmb}(t)\|=1 TimeEmb(t)=1 ,也就是说所有时间 embedding 的模长恒定。

Embedding Class Labels.

当原始条件变量 y raw ∈ Y ≜ { 0 , … , N } y_{\text{raw}} \in \mathcal{Y} \triangleq \{0,\dots,N\} yrawY{0,,N} 只是一个离散类别标签时,最简单的方法通常是为每个类别学习一个独立 embedding 向量。也就是说,对于 N + 1 N+1 N+1 个可能类别,我们直接学习 N + 1 N+1 N+1 个 embedding vectors。然后将 y y y 设置成对应类别的 embedding。这些 embedding 参数会被视为模型 u t θ ( x ∣ y ) u_t^\theta(x|y) utθ(xy) 参数的一部分,因此,它们会在训练过程中与整个 diffusion / flow model 一起训练。

Embedding Textual Input.

y raw y_{\text{raw}} yraw 是一个文本 prompt 时,情况会更加复杂,并且相关方法在很大程度上依赖于冻结的预训练模型。这类模型会被训练来将离散文本输入 embedding 到一个连续向量中,该向量能够捕获相关信息。其中一种这样的模型被称为 CLIP(Contrastive Language-Image Pre-training)。

CLIP 被训练用于学习图像与文本 prompt 的共享 embedding 空间。其训练损失会鼓励图像 embedding 对应 prompt 的 embedding 彼此接近,同时与其他图像和 prompt 的 embedding 保持更远距离 [34]。因此,我们可以取 y = C L I P ( y raw ) ∈ R d CLIP y = \mathrm{CLIP}(y_{\text{raw}}) \in \mathbb{R}^{d_{\text{CLIP}}} y=CLIP(yraw)RdCLIP 作为由冻结的预训练 CLIP 模型生成的 embedding。

在某些情况下,将整个序列压缩成单一表示可能并不理想。此时还可以进一步考虑使用预训练 Transformer 对 prompt 进行 embedding,从而得到 embedding 序列。在条件生成中,也经常会组合多个这样的预训练 embeddings,以同时获得各个模型的优势 [14, 33]。

对于本文而言,我们可以简单假设在应用这样的模型之后,prompt embedding 的形状为:

P r o m p t E m b e d ( y raw ) ∈ R S × k \mathrm{PromptEmbed}(y_{\text{raw}}) \in \mathbb{R}^{S \times k} PromptEmbed(yraw)RS×k

6.1.2 Diffusion Transformers

在深入讨论这些架构的细节之前,让我们先回顾一下:图像本质上只是一个向量 x ∈ R C image × H × W x \in \mathbb{R}^{C_{\text{image}} \times H \times W} xRCimage×H×W 。其中 C image C_{\text{image}} Cimage 表示通道数(channels)(例如 RGB 图像通常具有 C input = 3 C_{\text{input}} = 3 Cinput=3 个颜色通道),而 H H H W W W 分别表示图像在像素上的高度和宽度。

一种特别重要的架构类别是 diffusion transformers(DiTs)以及它们的变体,这些模型使用 attention 机制来构建网络 [49, 30, 28]。Diffusion transformer 存在不同的变体。这里我们解释一种通用设计,并说明 DiT 的具体实现可能会根据模型和应用场景而有所不同。

在本节剩余部分中,我们使用:

  • d d d 表示 hidden dimension
  • L L L 表示 transformer layers 的数量
  • h h h 表示每层中的 head 数量

Diffusion transformers 基于 vision transformers(ViTs),其核心思想本质上是将图像划分为 patches,并将这些 patch embedding 成 token sequence,然后通过标准 attention 处理这些 token [13]。最后会应用一个 depatchification operation 以恢复正确形状的图像。

最初的 patchification 操作本质上只是对图像 tensor x ∈ R C × H × W x \in \mathbb{R}^{C \times H \times W} xRC×H×W 进行重新组织。具体来说:

P a t c h i f y ( x ) ∈ R N × C ′ \mathrm{Patchify}(x) \in \mathbb{R}^{N \times C'} Patchify(x)RN×C

其中 C ′ = C P 2 , N = ( H / P ) ⋅ ( W / P ) C' = CP^2,N = (H/P)\cdot(W/P) C=CP2,N=(H/P)(W/P) ,这里 P P P 是 patch size。

接下来,我们对输出应用一个线性变换,得到最终的 patch embedding:

P a t c h E m b ( x ) = P a t c h i f y ( x ) W ∈ R N × d \mathrm{PatchEmb}(x) = \mathrm{Patchify}(x)W \in \mathbb{R}^{N \times d} PatchEmb(x)=Patchify(x)WRN×d

其中 W ∈ R C ′ × d W \in \mathbb{R}^{C' \times d} WRC×d 是一个可学习权重矩阵。Diffusion transformer 的输入随后包括时间 embedding、prompt embedding 以及 patchified image tensor,具体为(见 Section 6.1.1):

t ~ = T i m e E m b ( t ) ∈ R d y ~ = P r o m p t E m b ( y ) ∈ R S × d x ~ 0 = P a t c h E m b ( x ) ∈ R N × d \begin{align*} \tilde{t} &= \mathrm{TimeEmb}(t) \in \mathbb{R}^{d} \\[8pt] \tilde{y} &= \mathrm{PromptEmb}(y) \in \mathbb{R}^{S \times d} \\[8pt] \tilde{x}_0 &= \mathrm{PatchEmb}(x) \in \mathbb{R}^{N \times d} \end{align*} t~y~x~0=TimeEmb(t)Rd=PromptEmb(y)RS×d=PatchEmb(x)RN×d

注意,现在所有元素都具有 transformer 所需的 hidden dimension。随后,diffusion transformer 会通过 DiTBlock(细节请参考 Remark 29) 中的 transformer layers,迭代更新 x ~ i \tilde{x}_i x~i 对于 i = 0 , … , L − 1 i = 0,\dots,L-1 i=0,,L1 有:

x ~ i + 1 = D i T B l o c k ( x ~ i , t ~ , y ~ ) ∈ R N × d ( i = 0 , … , L − 1 ) . (70) \tilde{x}_{i+1} = \mathrm{DiTBlock}(\tilde{x}_i,\tilde{t},\tilde{y})\in \mathbb{R}^{N \times d} \qquad (i = 0,\dots,L-1). \tag{70} x~i+1=DiTBlock(x~i,t~,y~)RN×d(i=0,,L1).(70)

其中 N N N 是 layer 的数量。

最后,一个最终操作会应用 depatchification operation 将 DiT 输出映射回目标输出形状:

u = D e p a t c h i f y ( x ~ N W ~ ) ∈ R C × H × W u=\mathrm{Depatchify}(\tilde{x}_N\tilde{W}) \in \mathbb{R}^{C \times H \times W} u=Depatchify(x~NW~)RC×H×W

其中 W ~ ∈ R d × C ′ \tilde{W} \in \mathbb{R}^{d \times C'} W~Rd×C 。最终 tensor u u u 作为模型输出以及预测的 velocity u t θ ( x ∣ y ) u_t^\theta(x|y) utθ(xy)

Figure 14:左图为 diffusion transformer 架构概述,引自 [30];右图为对比 CLIP 损失示意图,其中学习了一个共享的图像-文本嵌入空间,引自 [34]


Remark 29 (DiT Block)

为了完整起见,我们给出单个 DiT layer 的简要数学描述。虽然我们尝试包含足够多的细节,以帮助理解 DiT 模型家族,但我们提醒读者这里更强调关键算法设计而不是架构细节。

现在,令 x ∈ R N × d x \in \mathbb{R}^{N \times d} xRN×d 表示当前的 patch token 序列(这里 x = x ~ i x = \tilde{x}_i x=x~i),并令 y ∈ R S × d y \in \mathbb{R}^{S \times d} yRS×d 表示 embedding 后的 guidance variable(这里 y = y ~ y = \tilde{y} y=y~)。随后,一个典型的 DiT block 会通过 (i) patch 上的 self-attention,(ii) 对 prompt 的 cross-attention,(iii )通过 adaptive normalization(AdaLN)进行 time conditioning 来更新 x x x

Scaled Dot Product Attention.

给定 queries Q ∈ R N × d h Q \in \mathbb{R}^{N \times d_h} QRN×dh ,keys K ∈ R M × d h K \in \mathbb{R}^{M \times d_h} KRM×dh ,values V ∈ R M × d h V \in \mathbb{R}^{M \times d_h} VRM×dh ,定义:

A t t n ( Q , K , V ) = s o f t m a x ( Q K ⊤ d h ) V ∈ R N × d h , \mathrm{Attn}(Q,K,V) = \mathrm{softmax} \left(\frac{QK^\top}{\sqrt{d_h}}\right)V \in \mathbb{R}^{N \times d_h}, Attn(Q,K,V)=softmax(dh QK)VRN×dh,

其中,softmax 按行(row-wise)应用。

Multi-Head Attention.

h h h 表示 head 的数量,并令 d h = d h d_h = \frac{d}{h} dh=hd 表示每个 head 的维度。对于每个 head h ∈ { 1 , … , n heads } h \in \{1,\dots,n_{\text{heads}}\} h{1,,nheads} 学习投影矩阵 W Q ( h ) , W K ( h ) , W V ( h ) ∈ R k × d h W_Q^{(h)},W_K^{(h)},W_V^{(h)} \in \mathbb{R}^{k \times d_h} WQ(h),WK(h),WV(h)Rk×dh 。定义:

h e a d h ( x , z ) = A t t n ( x W Q ( h ) , z W K ( h ) , z W V ( h ) ) , \mathrm{head}_h(x,z) = \mathrm{Attn} (xW_Q^{(h)},zW_K^{(h)},zW_V^{(h)}), headh(x,z)=Attn(xWQ(h),zWK(h),zWV(h)),

其中 source sequence z z z 可以是:

z = x ( self-attention on patches ) , z = y ( cross-attention to the prompt ) . z=x \quad (\text{self-attention on patches}), \qquad z=y \quad (\text{cross-attention to the prompt}). z=x(self-attention on patches),z=y(cross-attention to the prompt).

随后拼接所有 heads,并应用输出投影 W O ∈ R d × d W_O \in \mathbb{R}^{d \times d} WORd×d 得到:

M u l t i H e a d a t t e n t i o n ( x , z ) = C o n c a t ( h e a d 1 ( x , z ) , … , h e a d h ( x , z ) ) W O ∈ R N × d . \mathrm{MultiHeadattention}(x,z) = \mathrm{Concat} (\mathrm{head}_1(x,z), \dots, \mathrm{head}_h(x,z)) W_O \in \mathbb{R}^{N \times d}. MultiHeadattention(x,z)=Concat(head1(x,z),,headh(x,z))WORN×d.

Time Conditioning via Adaptive Normalization.

t ~ ∈ R d \tilde{t} \in \mathbb{R}^{d} t~Rd 表示 timestep embedding。DiT 中的一种标准做法是使用 t ~ \tilde{t} t~ 生成 per-channel scale/shift 参数以调制归一化后的激活值 [31]。具体来说,令 g : R d → R 2 d g:\mathbb{R}^d \to \mathbb{R}^{2d} g:RdR2d 为一个 MLP,并设:

( γ , β ) = g ( t ~ ) , (\gamma,\beta)=g(\tilde{t}), (γ,β)=g(t~),

其中 γ , β ∈ R d \gamma,\beta \in \mathbb{R}^{d} γ,βRd(或者根据实现不同,attention 与 MLP 等不同子层可能使用独立的 ( γ , β ) (\gamma,\beta) (γ,β))。给定 token matrix x ∈ R N × d x \in \mathbb{R}^{N \times d} xRN×d 以及归一化算子 N o r m ( ⋅ ) \mathrm{Norm}(\cdot) Norm()(例如 LayerNorm),定义 modulated normalization 为:

A d a N o r m t ~ ( x ) = ( 1 + γ ) ⊙ N o r m ( H ) + β , \mathrm{AdaNorm}_{\tilde{t}}(x) = (1+\gamma)\odot\mathrm{Norm}(H)+\beta, AdaNormt~(x)=(1+γ)Norm(H)+β,

其中 ⊙ \odot 表示带有 token 维度 broadcasting 的逐元素乘法。

Putting It Together.

组合后的操作,也即 DiTBlock 定义为:

x ← x + g self ( t ~ ) ⊙ M u l t i H e a d a t t e n t i o n ( A d a N o r m t ~ ( x ) , A d a N o r m t ~ ( x ) ) x ← x + g cross ( t ~ ) M u l t i H e a d a t t e n t i o n ( A d a N o r m t ~ ( x ) , y ) x ← x + g MLP ( t ~ ) M L P ( A d a N o r m t ~ ( x ) ) \begin{align*} x & \leftarrow x + g_{\text{self}}(\tilde{t}) \odot \mathrm{MultiHeadattention} (\mathrm{AdaNorm}_{\tilde{t}}(x), \mathrm{AdaNorm}_{\tilde{t}}(x) ) \\[8pt] x & \leftarrow x + g_{\text{cross}}(\tilde{t}) \mathrm{MultiHeadattention} (\mathrm{AdaNorm}_{\tilde{t}}(x), y) \\[8pt] x & \leftarrow x + g_{\text{MLP}}(\tilde{t}) \mathrm{MLP} (\mathrm{AdaNorm}_{\tilde{t}}(x)) \end{align*} xxxx+gself(t~)MultiHeadattention(AdaNormt~(x),AdaNormt~(x))x+gcross(t~)MultiHeadattention(AdaNormt~(x),y)x+gMLP(t~)MLP(AdaNormt~(x))

其中,MLP 是 position-wise feed-forward network,并且 g ⋅ g_\cdot g 是可学习 gating 参数。输出 x ∈ R N × d x \in \mathbb{R}^{N \times d} xRN×d 会成为下一层的 patch-token sequence(在本文记号中 x ~ i + 1 \tilde{x}_{i+1} x~i+1)。

最后,我们注意到 class-conditioned DiT(例如实验中实现的版本)通常会更加简单,并且会省略 cross-attention layer,转而使用基于 time 和 class 的 AdaNorm conditioning。


6.1.3 U-Net

U-Net 架构 [38] 是 DiT 架构的一种替代架构,并且属于一种特殊的卷积神经网络。它最初是为图像分割设计的。其关键特性在于输入与输出都具有图像的形状(尽管 channel 数量可能不同)。这使得它非常适合参数化向量场 x ↦ u t θ ( x ∣ y ) x \mapsto u_t^\theta(x|y) xutθ(xy) ,因为对于固定的 y , t y,t y,t 其输入具有图像形状,输出也同样如此。因此,U-Net 在早期 diffusion models 文献中被广泛使用 [17, 22, 11]。

一个 U-Net 包含一系列 encoders E i \mathcal{E}_i Ei ,对应的一系列 decoders D i \mathcal{D}_i Di ,以及位于中间的 latent processing block,我们称之为 midcoder。作为一个缩放示例,考虑一张图像 x t ∈ R 3 × 256 × 256 x_t \in \mathbb{R}^{3 \times 256 \times 256} xtR3×256×256(这里 ( C input , H , W ) = ( 3 , 256 , 256 ) (C_{\text{input}},H,W) =(3,256,256) (Cinput,H,W)=(3,256,256))经过 U-Net 时的路径:

x t i n p u t ∈ R 3 × 256 × 256 ▶  Input to the U-Net. x t l a t e n t = E ( x t i n p u t ) ∈ R 512 × 32 × 32 ▶  Pass through encoders to obtain latent. x t l a t e n t = M ( x t l a t e n t ) ∈ R 512 × 32 × 32 ▶  Pass latent through midcoder. x t o u t p u t = D ( x t l a t e n t ) ∈ R 3 × 256 × 256 ▶  Pass through decoders to obtain output. \begin{align*} x_t^{\mathrm{input}} &\in \mathbb{R}^{3\times256\times256} && \blacktriangleright\ \text{Input to the U-Net.} \\[14pt] x_t^{\mathrm{latent}} = \mathcal{E}(x_t^{\mathrm{input}}) &\in \mathbb{R}^{512\times32\times32} && \blacktriangleright\ \text{Pass through encoders to obtain latent.} \\[14pt] x_t^{\mathrm{latent}} = \mathcal{M}(x_t^{\mathrm{latent}}) &\in \mathbb{R}^{512\times32\times32} && \blacktriangleright\ \text{Pass latent through midcoder.} \\[14pt] x_t^{\mathrm{output}} = \mathcal{D}(x_t^{\mathrm{latent}}) &\in \mathbb{R}^{3\times256\times256} && \blacktriangleright\ \text{Pass through decoders to obtain output.} \end{align*} xtinputxtlatent=E(xtinput)xtlatent=M(xtlatent)xtoutput=D(xtlatent)R3×256×256R512×32×32R512×32×32R3×256×256 Input to the U-Net. Pass through encoders to obtain latent. Pass latent through midcoder. Pass through decoders to obtain output.

Note:Midcoder 是本文中使用的一个完全非标准术语,用于表示 U-Net stack 最底部部分,并与 encoder 与 decoder 相对应。

注意,当输入经过 encoders 时,其表示中的 channel 数量增加,而图像的高度和宽度会减小。encoder 与 decoder 通常都由一系列卷积层组成,其中包含 activation functions、pooling operations 等等。上面未展示的还有两点:第一,输入 x t input ∈ R 3 × 256 × 256 x_t^{\text{input}} \in \mathbb{R}^{3 \times 256 \times 256} xtinputR3×256×256 通常会先经过初始 pre-encoding block,以增加 channel 数量,然后再输入第一个 encoder block。第二,encoders 与 decoders 通常通过残差连接。完整结构如 Figure 15 所示。

Figure 15:一种简化的 U-Net 架构(本课程 2025 年版本的实验 03 中使用了类似架构)。

从高层角度来看,大多数 U-Net 都包含上面描述结构的某种变体。然而,上述某些设计选择在实际实现中可能会有所不同。特别地,这里我们采用的是纯卷积架构,而在实践中通常也会在 encoders 和 decoders 中加入 attention layers。U-Net 的名字来源于其 encoder 与 decoder 共同形成的 “U” 形结构(见 Figure 15)。

6.2 Working in Latent Space: (Variational) Autoencoders

到目前为止,我们一直在数据空间 R d \mathbb{R}^d Rd 中进行操作。然而,随着图像分辨率不断提高,直接在这样的空间中建模的代价会迅速变得难以承受。例如,一个 1024 × 1024 1024 \times 1024 1024×1024 且具有三个 RGB 颜色通道的图像,对应的总维度为 d = H ⋅ W ⋅ 3 ≈ 3 ∗ 10 6 d = H \cdot W \cdot 3 \approx 3 * 10^6 d=HW33106 。注意,对于视频而言,由于所有内容都会随着帧数 T T T 增长,维度还会进一步增加。

正如你可以想象的那样,在这样的空间上训练很快就会变得不可行。不同于图像分类任务,其低维输出允许使用更窄的 convolutional stacks,我们的基于 flow 的建模方法要求输出 u t θ ( x ) ∈ R d u_t^\theta(x) \in \mathbb{R}^d utθ(x)Rd 必须与输入一样大。因此,一个重要问题变成了:如何在合理的内存与计算预算下,对高维图像进行建模?

6.2.1 Standard Autoencoders

对于这个问题,一个自然的答案是 compression(压缩)。例如,图像的真实空间可能位于高维图像空间中的更低维流形(lower-dimensional manifold)附近。更具体地,我们可以考虑一个 encoder μ ϕ : R d → R k \mu_\phi :\mathbb{R}^d\to\mathbb{R}^k μϕ:RdRk 以及一个 decoder μ θ : R k → R d \mu_\theta :\mathbb{R}^k\to\mathbb{R}^d μθ:RkRd ,它们共同将原始图像 x ∈ R d x \in \mathbb{R}^d xRd 映射到 latent z ∈ R k z \in \mathbb{R}^k zRk ,并从 latent 中恢复回来。维度 k k k 通常会远小于 d d d

对于图像,例如 d = 3 × 1024 × 1024 d = 3 \times 1024 \times 1024 d=3×1024×1024 时,通常会进行 downsample,得到例如 k = 3 × 1024 16 × 1024 16 k=3\times\frac{1024}{16}\times\frac{1024}{16} k=3×161024×161024 。共同地 μ ϕ \mu_\phi μϕ μ θ \mu_\theta μθ 被称为 autoencoder。理想情况下 μ ϕ \mu_\phi μϕ μ θ \mu_\theta μθ 会被选择为实现高重建质量。换句话说,使得 μ θ ( μ ϕ ( x ) ) \mu_\theta(\mu_\phi(x)) μθ(μϕ(x)) 平均而言能够近似 x x x 。因此,autoencoders 通常通过 reconstruction loss 进行训练:

L Recon ( ϕ , θ ) = E x ∼ p data [ ∥ μ θ ( μ ϕ ( x ) ) − x ∥ 2 ] . \mathcal{L}_{\text{Recon}}(\phi,\theta) = \mathbb{E}_{x \sim p_{\text{data}}} \left[ \| \mu_\theta(\mu_\phi(x)) - x \|^2 \right]. LRecon(ϕ,θ)=Expdata[μθ(μϕ(x))x2].

该损失衡量的是原始数据点 x x x 与重建结果 μ θ ( μ ϕ ( x ) ) \mu_\theta(\mu_\phi(x)) μθ(μϕ(x)) 之间的平方误差。

Amenability to Generative Modeling.

遗憾的是,上面的 reconstruction loss 并不足以训练出一个 “好的” autoencoder。回忆一下,我们的最终目标是在 latent space 中训练一个 generative model,并建模 latent distribution p latent ( z ) p_{\text{latent}}(z) platent(z) 。其中 z = μ ϕ ( x ) ,   x ∼ p data z = \mu_\phi(x), \, x \sim p_{\text{data}} z=μϕ(x),xpdata 。随后,针对 p data ( x ) p_{\text{data}}(x) pdata(x) 的 generative model 可以通过将 latent generative model 的输出送入 decoder μ θ \mu_\theta μθ 来实现。

当前 autoencoder 的公式存在一个微妙的问题:我们几乎无法控制 p latent ( z ) p_{\text{latent}}(z) platent(z) ,因此,也几乎无法保证 p latent ( z ) p_{\text{latent}}(z) platent(z) 本身足够 “well-behaved”,从而适合训练 generative model。

例如,我们希望 latent distribution 是 nice,simple,Gaussian-like 这样的分布。虽然把数据变换到 latent space 可能已经完成了压缩,但与此同时,我们也可能将 p data p_{\text{data}} pdata 变换成了 一个非常难学习的 latent distribution p latent p_{\text{latent}} platent 。因此,问题变成了:如何保证 latent distribution p latent p_{\text{latent}} platent 仍然是 well-behaved 且 easy-to-learn 的?

为了能够对 latent distribution 进行更加显式的正则化,我们现在会在一个更一般的概率框架下重新表述 autoencoder 的概念,并由此引出 variational autoencoder(VAE)。

6.2.2 Variational Autoencoders

variational autoencoder(VAE)是在我们之前的(deterministic)standard autoencoder 公式基础上得到的,方法是放松 encoder 与 decoder 必须是 deterministic function 的约束。具体来说,我们考虑一个 encoder q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(zx) 其参数为 ϕ \phi ϕ ,一个 decoder p θ ( x ∣ z ) p_\theta(x|z) pθ(xz) 其参数为 θ \theta θ ,最常见的选择是:

q ϕ ( z ∣ x ) = N ( z ; μ ϕ ( x ) , diag ⁡ ( σ ϕ 2 ( x ) ) ) , p θ ( x ∣ z ) = N ( x ; μ θ ( z ) , σ θ 2 ( z ) I d ) (71) q_\phi(z|x) = \mathcal N \left( z; \mu_\phi(x), \operatorname{diag}(\sigma_\phi^2(x)) \right), \qquad p_\theta(x|z) = \mathcal N \left( x; \mu_\theta(z), \sigma_\theta^2(z)I_d \right) \tag{71} qϕ(zx)=N(z;μϕ(x),diag(σϕ2(x))),pθ(xz)=N(x;μθ(z),σθ2(z)Id)(71)

其中 μ ϕ ( x ) ∈ R k ,   σ ϕ 2 ( x ) ∈ R ≥ 0 k ,   μ θ ( z ) ∈ R d ,   σ θ 2 ( z ) ∈ R > 0 \mu_\phi(x)\in\mathbb R^k, \, \sigma_\phi^2(x)\in\mathbb R_{\ge0}^k, \, \mu_\theta(z)\in\mathbb R^d , \, \sigma_\theta^2(z)\in\mathbb R_{>0} μϕ(x)Rk,σϕ2(x)R0k,μθ(z)Rd,σθ2(z)R>0 都由 neural network 参数化,而 diag ⁡ \operatorname{diag} diag 表示对角矩阵。为了 encode 或 decode 一个变量,我们进行采样:

z ∼ q ϕ ( ⋅ ∣ x ) ( encode ) x ∼ p θ ( ⋅ ∣ z ) ( decode ) z\sim q_\phi(\cdot|x) \qquad \qquad \qquad \qquad (\text{encode}) \\[8pt] x\sim p_\theta(\cdot|z) \qquad \qquad \qquad \qquad (\text{decode}) zqϕ(x)(encode)xpθ(z)(decode)

最后需要注意当 σ ϕ ( x ) = 0 \sigma_\phi(x)=0 σϕ(x)=0 σ θ ( x ) = 0 \sigma_\theta(x)=0 σθ(x)=0 始终成立时,我们就恢复到了 standard autoencoder。

现在我们来看看 reconstruction loss。一个自然的目标函数是:

L VAE-Recon ( ϕ , θ ) = − E x ∼ p data ( x ) , z ∼ q ϕ ( z ∣ x ) [ log ⁡ p θ ( x ∣ z ) ] (72) \mathcal L_{\text{VAE-Recon}}(\phi,\theta) = - \mathbb E_{x\sim p_{\text{data}}(x), z\sim q_\phi(z|x)} \left[ \log p_\theta(x|z) \right] \tag{72} LVAE-Recon(ϕ,θ)=Expdata(x),zqϕ(zx)[logpθ(xz)](72)

注意这里有两个变化:第一,现在不再是 deterministic encoding,而是 z ∼ q ϕ ( z ∣ x ) z\sim q_\phi(z|x) zqϕ(zx) 的随机采样。第二,现在我们取的是在 decoding 下, x x x 的负对数似然。也就是说,这个 loss 实际上在问:原始数据点 x x x 在 encode 与 decode 后,出现的可能性有多大。

由于现在一切都变成了随机过程,因此我们需要考虑所有可能的 encodings / decodings。对于 Gaussian 情况,这个 reconstruction loss 变为:

L VAE-Recon ( ϕ , θ ) = E x ∼ p data ( x ) , z ∼ q ϕ ( z ∣ x ) [ 1 2 σ θ 2 ( z ) ∥ x − μ θ ( z ) ∥ 2 + d 2 log ⁡ σ θ 2 ( z ) ] + const (73) \mathcal L_{\text{VAE-Recon}}(\phi,\theta) = \mathbb E_{x\sim p_{\text{data}}(x), z\sim q_\phi(z|x)} \left[ \frac{1}{2\sigma_\theta^2(z)} \|x-\mu_\theta(z)\|^2 + \frac d2\log\sigma_\theta^2(z) \right] + \text{const} \tag{73} LVAE-Recon(ϕ,θ)=Expdata(x),zqϕ(zx)[2σθ2(z)1xμθ(z)2+2dlogσθ2(z)]+const(73)

其中,我们使用了正态分布的密度函数(见 公式 (97))。因此,VAE reconstruction loss 与 standard AE reconstruction loss 并没有太大区别。我们只是需要考虑所有可能的 z ∼ q ϕ ( ⋅ ∣ x ) z\sim q_\phi(\cdot|x) zqϕ(x) 编码结果。

第二项 d 2 log ⁡ σ θ 2 ( z ) \frac d2\log\sigma_\theta^2(z) 2dlogσθ2(z) 依赖于 decoder variance,它控制重建精度和预测不确定性之间的权衡。很多实现(包括 lab 中的实现)会固定 σ ϕ ( x ) \sigma_\phi(x) σϕ(x) σ θ ( z ) \sigma_\theta(z) σθ(z) 为可学习标量常数。也就是说,它们分别独立于 x x x z z z 。这样可以避免在学习方差时出现病态行为以及数值不稳定性。

因此,在这种情况下,VAE reconstruction loss 基本上会退化为带随机 encoding 的 standard autoencoder reconstruction loss:

L VAE-Recon ( ϕ , θ ) = E x ∼ p data ( x ) , z ∼ q ϕ ( z ∣ x ) [ 1 2 σ θ 2 ∥ x − μ θ ( z ) ∥ 2 ] + const (74) \mathcal L_{\text{VAE-Recon}}(\phi,\theta) = \mathbb E_{x\sim p_{\text{data}}(x), z\sim q_\phi(z|x)} \left[ \frac{1}{2\sigma_\theta^2} \|x-\mu_\theta(z)\|^2 \right] + \text{const} \tag{74} LVAE-Recon(ϕ,θ)=Expdata(x),zqϕ(zx)[2σθ21xμθ(z)2]+const(74)

现在让我们重新回到目标:我们希望构造一种针对 p data ( x ) p_{\text{data}}(x) pdata(x) 的 encoding,使得在映射到 latent space 后,得到的 distribution 是 nice 或 easy-to-learn 的。为此,我们现在引入 一个 latent prior distribution p prior ( z ) p_{\text{prior}}(z) pprior(z)

在本文中,我们取 p prior = N ( 0 , I k ) p_{\text{prior}}=\mathcal N(0,I_k) pprior=N(0,Ik) 即 isotropic Gaussian。这种 prior distribution 的选择实际上表示了 latent distribution 理想情况下应当具有的形式。

正态分布通常非常容易学习,因此它满足我们对于 “trainable latent distribution” 的目标。核心思想是 regularize encoder 使得编码后的数据分布尽可能接近 p prior p_{\text{prior}} pprior 。我们通过下面这个辅助 loss 来实现:

L VAE-Prior ( ϕ ) = E x ∼ p data ( x ) [ D KL ( q ϕ ( ⋅ ∣ x )    ∥    p prior ) ] , (75) \mathcal L_{\text{VAE-Prior}}(\phi)=\mathbb E_{x\sim p_{\text{data}}(x)}\left[D_{\text{KL}}\left(q_\phi(\cdot|x)\;\|\;p_{\text{prior}}\right)\right],\tag{75} LVAE-Prior(ϕ)=Expdata(x)[DKL(qϕ(x)pprior)],(75)

其中 D KL D_{\text{KL}} DKL 表示 KL 散度。KL-divergence 是一种衡量两个概率分布差异程度的重要方法。完整解释超出了本文范围,但我们会在 Remark 30 中给出简要背景。

这里定义的 L VAE-Prior \mathcal L_{\text{VAE-Prior}} LVAE-Prior 非常直观:我们希望对于任意数据点 x x x 其 encoding distribution 都尽可能像 Gaussian distribution。如果对所有 x x x 都成立,那么自然可以期待最终的 latent distribution 也会近似 Gaussian。


Remark 30 (Background on KL-divergence)

对于两个概率密度 q , p q,p q,p ,KL 散度定义为:

D K L ( q ( x )   ∥   p ( x ) ) = ∫ q ( x ) log ⁡ q ( x ) p ( x ) = E X ∼ q [ log ⁡ q ( X ) p ( X ) ] . D_{\mathrm{KL}}(q(x)\,\|\,p(x)) = \int q(x)\log\frac{q(x)}{p(x)} = \mathbb E_{X\sim q} \left[ \log\frac{q(X)}{p(X)} \right]. DKL(q(x)p(x))=q(x)logp(x)q(x)=EXq[logp(X)q(X)].

KL divergence 是衡量分布之间不相似程度的一种标准方法。特别地,KL divergence 满足下面这些有用性质:

D K L ( q ( x )   ∥   p ( x ) ) ≥ 0 , D K L ( q ( x )   ∥   p ( x ) ) = 0 ⇔ q = p . \begin{align} &D_{\mathrm{KL}}(q(x)\,\|\,p(x)) \ge 0, \tag{76} \\[6pt] &D_{\mathrm{KL}}(q(x)\,\|\,p(x)) =0 \quad \Leftrightarrow \quad q=p. \tag{77} \end{align} DKL(q(x)p(x))0,DKL(q(x)p(x))=0q=p.(76)(77)

也就是说,KL divergence 永远非负;当且仅当两个 probability distribution 完全一致时,KL divergence 才等于 0。


为了定义 variational autoencoder 的损失函数,我们现在可以将 reconstruction loss 与 prior loss 组合起来,并使用一个参数权重 β ≥ 0 \beta\ge0 β0 从而得到 VAE training objective

L V A E ( ϕ , θ ) = L V A E - R e c o n ( ϕ , θ ) + β L V A E - P r i o r ( ϕ ) = − E x ∼ p d a t a ( x ) , z ∼ q ϕ ( z ∣ x ) [ log ⁡ p θ ( x ∣ z ) ] + β E x ∼ p d a t a ( x ) [ D K L ( q ϕ ( ⋅ ∣ x ) ∥ p p r i o r ) ] \begin{align}\mathcal L_{\mathrm{VAE}}(\phi,\theta) &= \mathcal L_{\mathrm{VAE\text{-}Recon}}(\phi,\theta)+ \beta \mathcal L_{\mathrm{VAE\text{-}Prior}}(\phi) \tag{78} \\[6pt]&= -\mathbb E_{x\sim p_{\mathrm{data}}(x),z\sim q_\phi(z|x)}\left[\log p_\theta(x|z)\right]+\beta\mathbb E_{x\sim p_{\mathrm{data}}(x)}\left[D_{\mathrm{KL}}\left(q_\phi(\cdot|x)\|p_{\mathrm{prior}}\right)\right]\tag{79}\end{align} LVAE(ϕ,θ)=LVAE-Recon(ϕ,θ)+βLVAE-Prior(ϕ)=Expdata(x),zqϕ(zx)[logpθ(xz)]+βExpdata(x)[DKL(qϕ(x)pprior)](78)(79)

其中,第一项保证 latent variables 能够被有效地 decode 回数据空间。第二项保证我们的 latent distribution 接近 Gaussian distribution。参数 β \beta β 用于控制这两部分约束的强度。

为了让这个 loss 更具体,下面我们来推导 Gaussian 情况下的 KL divergence:


Example 31 (KL Divergence Between Isotropic Gaussians)

q ( x ) = N ( x ; μ q , diag ⁡ ( σ q 2 ) ) q(x)=\mathcal N\left(x;\mu_q,\operatorname{diag}(\sigma_q^2)\right) q(x)=N(x;μq,diag(σq2)) 以及 p ( x ) = N ( x ; μ p , diag ⁡ ( σ p 2 ) ) p(x) = \mathcal N \left( x; \mu_p, \operatorname{diag}(\sigma_p^2) \right) p(x)=N(x;μp,diag(σp2)) 是具有对角协方差矩阵的高斯分布,其中 σ q , σ p ∈ R ≥ 0 d \sigma_q,\sigma_p\in\mathbb R_{\ge0}^d σq,σpR0d 并且 x ∈ R d x\in\mathbb R^d xRd 。则:

D K L ( q ∥ p ) = 1 2 ( K  ⁣ ( σ q 2 σ p 2 ) + ∥ μ q − μ p ∥ 2 σ p 2 ) , w h e r e   K ( α ) = ∑ i = 1 d α i − log ⁡ α i − 1. (80) D_{\mathrm{KL}}(q\|p) = \frac12 \left( \mathcal K\!\left( \frac{\sigma_q^2}{\sigma_p^2} \right) + \frac{\|\mu_q-\mu_p\|^2}{\sigma_p^2} \right), \qquad \mathrm{where} \, \mathcal K(\alpha) = \sum_{i=1}^d \alpha_i-\log\alpha_i-1. \tag{80} DKL(qp)=21(K(σp2σq2)+σp2μqμp2),whereK(α)=i=1dαilogαi1.(80)

上面的表达式具有非常直观的含义:如果 mean 与 variance 完全一致,则 D K L ( q ∥ p ) = 0 D_{\mathrm{KL}}(q\|p)=0 DKL(qp)=0 。进一步,它会随着 mean vector 之间的平方误差 ∥ μ q − μ p ∥ 2 \|\mu_q-\mu_p\|^2 μqμp2 增大而增大。最后,函数 K ( α ) \mathcal K(\alpha) K(α) α = 1 \alpha=1 α=1 时具有唯一最小值,因此,当 σ q = σ p \sigma_q=\sigma_p σq=σp 时, D K L ( q ∥ p ) D_{\mathrm{KL}}(q\|p) DKL(qp) 最小。

Proof. 我们先证明 d = 1 d=1 d=1 的情况(对于 d > 1 d>1 d>1 只需要对每个维度求和即可,推导完全类似)。根据正态分布的密度函数,我们知道(见 公式 (97)):

log ⁡ q ( x ) = − 1 2 log ⁡ ( 2 π σ q 2 ) − 1 2 σ q 2 ∥ x − μ q ∥ 2 , log ⁡ p ( x ) = − 1 2 log ⁡ ( 2 π σ p 2 ) − 1 2 σ p 2 ∥ x − μ p ∥ 2 \log q(x) = -\frac12\log(2\pi\sigma_q^2) - \frac{1}{2\sigma_q^2} \|x-\mu_q\|^2, \qquad \log p(x) = -\frac12\log(2\pi\sigma_p^2) - \frac{1}{2\sigma_p^2} \|x-\mu_p\|^2 logq(x)=21log(2πσq2)2σq21xμq2,logp(x)=21log(2πσp2)2σp21xμp2

因此:

D K L ( q ∥ p ) = E x ∼ q [ log ⁡ q ( x ) − log ⁡ p ( x ) ] = 1 2 log ⁡ σ p 2 σ q 2 + 1 2 σ p 2 E q [ ∥ x − μ p ∥ 2 ] − 1 2 σ q 2 E q [ ∥ x − μ q ∥ 2 ] . (81) D_{\mathrm{KL}}(q\|p) = \mathbb E_{x\sim q} \left[ \log q(x)-\log p(x) \right] = \frac12\log\frac{\sigma_p^2}{\sigma_q^2} + \frac{1}{2\sigma_p^2} \mathbb E_q \left[ \|x-\mu_p\|^2 \right] - \frac{1}{2\sigma_q^2} \mathbb E_q \left[ \|x-\mu_q\|^2 \right]. \tag{81} DKL(qp)=Exq[logq(x)logp(x)]=21logσq2σp2+2σp21Eq[xμp2]2σq21Eq[xμq2].(81)

对于 x ∼ N ( μ q , σ q 2 I ) x\sim\mathcal N(\mu_q,\sigma_q^2 I) xN(μq,σq2I) 我们有:

E q [ ∥ x − μ q ∥ 2 ] = tr ⁡ ( σ q 2 I ) = σ q 2 . \mathbb E_q \left[ \|x-\mu_q\|^2 \right] = \operatorname{tr}(\sigma_q^2 I) = \sigma_q^2. Eq[xμq2]=tr(σq2I)=σq2.

再结合 x − μ p = ( x − μ q ) + ( μ q − μ p ) x-\mu_p=(x-\mu_q)+(\mu_q-\mu_p) xμp=(xμq)+(μqμp) 以及 E q [ x − μ q ] = 0 \mathbb E_q[x-\mu_q]=0 Eq[xμq]=0 可以得到:

E q [ ∥ x − μ p ∥ 2 ] = E q [ ∥ x − μ q ∥ 2 ] + ∥ μ q − μ p ∥ 2 = σ q 2 + ∥ μ q − μ p ∥ 2 . \mathbb E_q \left[ \|x-\mu_p\|^2 \right] = \mathbb E_q \left[ \|x-\mu_q\|^2 \right] + \|\mu_q-\mu_p\|^2 = \sigma_q^2 + \|\mu_q-\mu_p\|^2. Eq[xμp2]=Eq[xμq2]+μqμp2=σq2+μqμp2.

将这些结果代入 公式 (81),即可得到 公式 (80)


现在我们假设 encoder 具有 Gaussian 形式。那么我们得到:

L V A E - P r i o r ( ϕ ) = E x ∼ p d a t a ( x ) [ D K L ( q ϕ ( ⋅ ∣ x )    ∥    N ( 0 , I k ) ) ] = E [ 1 2 K ( σ ϕ 2 ( x ) ) + 1 2 ∥ μ ϕ ( x ) ∥ 2 ] (82) \mathcal L_{\mathrm{VAE\text{-}Prior}}(\phi) = \mathbb E_{x\sim p_{\mathrm{data}}(x)} \left[ D_{\mathrm{KL}} \left( q_\phi(\cdot|x) \;\|\; \mathcal N(0,I_k) \right) \right] = \mathbb E \left[ \frac12 \mathcal K \left( \sigma_\phi^2(x) \right) + \frac12 \|\mu_\phi(x)\|^2 \right] \tag{82} LVAE-Prior(ϕ)=Expdata(x)[DKL(qϕ(x)N(0,Ik))]=E[21K(σϕ2(x))+21μϕ(x)2](82)

这个 loss 非常直观:均值 μ ϕ ( x ) \mu_\phi(x) μϕ(x) 会因为偏离 0 而受到惩罚;方差则会因为偏离 1 而受到惩罚。因此,作为 VAE 的总 loss,我们得到:

L V A E ( ϕ , θ ) = L V A E - R e c o n ( ϕ , θ ) + β L V A E - P r i o r ( ϕ ) = E x ∼ p d a t a ( x ) , z ∼ q ϕ ( z ∣ x ) [ 1 2 σ θ 2 ( z ) ∥ x − μ θ ( z ) ∥ 2 ⏟ recon. error + d 2 log ⁡ σ θ 2 ( z ) ⏟ decoder confidence + β 2 K ( σ ϕ 2 ( x ) ) ⏟ make latent variance = 1 + β 2 ∥ μ ϕ ( x ) ∥ 2 ⏟ make latent mean = 0 ] \begin{align*} &\mathcal L_{\mathrm{VAE}}(\phi,\theta) \\ &= \mathcal L_{\mathrm{VAE\text{-}Recon}}(\phi,\theta) + \beta \mathcal L_{\mathrm{VAE\text{-}Prior}}(\phi) \\ &= \mathbb E_{x\sim p_{\mathrm{data}}(x), z\sim q_\phi(z|x)} \Bigg[ \underbrace{ \frac{1}{2\sigma_\theta^2(z)} \|x-\mu_\theta(z)\|^2 }_{\text{recon. error}} + \underbrace{ \frac d2\log\sigma_\theta^2(z) }_{\text{decoder confidence}} + \underbrace{ \frac\beta2 \mathcal K \left( \sigma_\phi^2(x) \right) }_{\text{make latent variance}=1} + \underbrace{ \frac\beta2 \|\mu_\phi(x)\|^2 }_{\text{make latent mean}=0} \Bigg] \tag{83} \end{align*} LVAE(ϕ,θ)=LVAE-Recon(ϕ,θ)+βLVAE-Prior(ϕ)=Expdata(x),zqϕ(zx)[recon. error 2σθ2(z)1xμθ(z)2+decoder confidence 2dlogσθ2(z)+make latent variance=1 2βK(σϕ2(x))+make latent mean=0 2βμϕ(x)2](83)

上面这个 loss function 的四项都非常直观:第一项:只是 reconstruction error;第二项:描述 decoder 的不确定性,更小的 variance 会让 decoder 更 “confident”,但同时也会更强地惩罚 reconstruction error;此外,我们希望 latent variance 为 1,并且 latent mean 为 0,从而保证 latent space 中的分布尽可能接近 Gaussian。

Training a VAE.

现在还剩下一个问题:我们该如何最小化 L V A E ( ϕ , θ ) \mathcal L_{\mathrm{VAE}}(\phi,\theta) LVAE(ϕ,θ) 这个 VAE loss?这个 loss 的问题在于,目前我们对其求期望的分布(即 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(zx))本身仍然依赖于参数 ϕ \phi ϕ 。不过,我们可以应用所谓的 reparameterization trick(重参数化技巧)来重写它。具体来说,对于:

q ϕ ( z ∣ x ) = N ( z ; μ ϕ ( x ) , σ ϕ 2 ( x ) I k ) q_\phi(z|x) = \mathcal N \left( z; \mu_\phi(x), \sigma_\phi^2(x)I_k \right) qϕ(zx)=N(z;μϕ(x),σϕ2(x)Ik)

我们可以通过下面的方法得到 samples:

ϵ ∼ N ( 0 , I k ) , z = μ ϕ ( x ) + σ ϕ ( x ) ϵ ⇒ z ∼ q ϕ ( ⋅ ∣ x ) \epsilon\sim\mathcal N(0,I_k), \quad z=\mu_\phi(x)+\sigma_\phi(x)\epsilon \quad\Rightarrow\quad z\sim q_\phi(\cdot|x) ϵN(0,Ik),z=μϕ(x)+σϕ(x)ϵzqϕ(x)

注意,在这个方程中,唯一的 noise / stochasticity 来源是 ϵ \epsilon ϵ ,而 ϵ \epsilon ϵ 的分布与 ϕ \phi ϕ 无关。因此,我们可以将 loss 重写为:

L V A E ( ϕ , θ ) = E x ∼ p d a t a ( x ) , ϵ ∼ N ( 0 , I k ) [ 1 2 σ θ 2 ( z ) ∥ x − μ θ ( μ ϕ ( x ) + σ ϕ ( x ) ϵ ) ∥ 2 + d 2 log ⁡ σ θ 2 ( z ) + β 2 K ( σ ϕ 2 ( x ) ) + β 2 ∥ μ ϕ ( x ) ∥ 2 ] \mathcal L_{\mathrm{VAE}}(\phi,\theta) = \mathbb E_{ x\sim p_{\mathrm{data}}(x), \epsilon\sim\mathcal N(0,I_k) } \Bigg[ \frac{1}{2\sigma_\theta^2(z)} \left\| x- \mu_\theta \left( \mu_\phi(x)+\sigma_\phi(x)\epsilon \right) \right\|^2 + \frac d2\log\sigma_\theta^2(z) + \frac\beta2 \mathcal K \left( \sigma_\phi^2(x) \right) + \frac\beta2 \|\mu_\phi(x)\|^2 \Bigg] LVAE(ϕ,θ)=Expdata(x),ϵN(0,Ik)[2σθ2(z)1xμθ(μϕ(x)+σϕ(x)ϵ)2+2dlogσθ2(z)+2βK(σϕ2(x))+2βμϕ(x)2]

在重参数化之后,随机性只来自于 ϵ ∼ N ( 0 , I k ) \epsilon\sim\mathcal N(0,I_k) ϵN(0,Ik) ,而它的分布不依赖于 ϕ \phi ϕ 。因此,我们现在可以使用深度学习的标准工具来最小化这个 loss。为了进一步简化,我们可以再次令 σ θ 2 ( z ) = σ 2 \sigma_\theta^2(z)=\sigma^2 σθ2(z)=σ2 为常数,从而得到:

L V A E ( ϕ , θ ) = E x ∼ p d a t a ( x ) , ϵ ∼ N ( 0 , I k ) [ 1 2 σ 2 ∥ x − μ θ ( μ ϕ ( x ) + σ ϕ ( x ) ϵ ) ∥ 2 + β 2 K ( σ ϕ 2 ( x ) ) + β 2 ∥ μ ϕ ( x ) ∥ 2 ] \mathcal L_{\mathrm{VAE}}(\phi,\theta) = \mathbb E_{ x\sim p_{\mathrm{data}}(x), \epsilon\sim\mathcal N(0,I_k) } \Bigg[ \frac{1}{2\sigma^2} \left\| x- \mu_\theta \left( \mu_\phi(x)+\sigma_\phi(x)\epsilon \right) \right\|^2 + \frac\beta2 \mathcal K \left( \sigma_\phi^2(x) \right) + \frac\beta2 \|\mu_\phi(x)\|^2 \Bigg] LVAE(ϕ,θ)=Expdata(x),ϵN(0,Ik)[2σ21xμθ(μϕ(x)+σϕ(x)ϵ)2+2βK(σϕ2(x))+2βμϕ(x)2]

Algorithm 6 中,我们会总结 VAE 的训练流程。

在这里插入图片描述

Practical remarks.

我们这里给出的构造,展示了 autoencoder 设计的基本原理。当然,在实际应用中,人们可能会加入更多的 loss term 或者其他约束。因此,最后我们给出一些关于 autoencoder 的 practical remarks:

1. Choosing β \beta β(以及 KL warm-up).

较大的 β \beta β 会使 latent 更接近先验分布,但同时也可能损害重建,并可能触发 posterior collapse。也就是说 encoder 会忽略 x x x 并输出 q ϕ ( z ∣ x ) ≈ N ( 0 , I k ) q_\phi(z|x)\approx\mathcal N(0,I_k) qϕ(zx)N(0,Ik) 。一种常见的稳定化技巧是 KL warm-up:一开始令 β = 0 \beta=0 β=0 然后在前几个 epochs 中,逐渐将其增大到目标值。不过,在现代 autoencoder 中 β \beta β 通常都非常小,即 β ≪ 1 \beta\ll1 β1

2. Decoder variance.

学习 Gaussian decoder variance σ θ 2 \sigma_\theta^2 σθ2 在数值上可能比较 delicate,并且如果不加正则化,可能会导致 degenerate solution。为了保证稳定性,很多实现会固定 p θ ( x ∣ z ) = N ( x ; μ θ ( z ) , σ 2 I d ) p_\theta(x|z) = \mathcal N \left( x; \mu_\theta(z), \sigma^2 I_d \right) pθ(xz)=N(x;μθ(z),σ2Id) ,其中 σ 2 \sigma^2 σ2 为常数。这样 reconstruction term 就会与 mean-squared error 成正比(忽略 constant 项)。

3. 超越 pixel MSE 的 reconstruction loss.

对于图像,像素级高斯似然(即 mean squared error)通常会产生过于平滑的 reconstruction。因此,在实践中,人们会加入 perceptual losses(即使用预训练网络的特征空间损失)来提升清晰度和语义保真度。

4. Adversarial 与 hybrid objective.

为了进一步提升视觉真实感,可以将 VAE objective 与对抗损失(VAE-GAN 风格)结合起来。做法是在 decoded samples 上使用判别器。这种方法通常会使输出更加清晰,但同时也会引入额外的优化不稳定性以及更多的超参数。


Remark 32 (Working in Latent Space)

为了训练一个 latent generative model,我们只需要遵循已有的训练方案,但直接在 latent space 中工作。在训练阶段,我们使用 x ∼ p d a t a x\sim p_{\mathrm{data}} xpdata 并从 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(zx) 中采样。在推理阶段,我们从 latent diffusion 或 latent flow model 中采样 z z z ,随后通过 x = μ m e a n ( z ) x=\mu_{\mathrm{mean}}(z) x=μmean(z) 进行 decode(注意,这里我们取的是 mean,而不是随机 sample,这样可以避免噪声引起的伪影)。

从直觉上来说,一个训练良好的 autoencoder 可以被视为过滤掉高频或者其他在语义上无意义的细节。这样,generative model 便可以更加专注于重要的、在感知上相关的特征 [36]

在本文写作时,几乎所有 image / video generation 的 state-of-the-art 方法,都采用了所谓的 latent diffusion 范式。也就是说,在 autoencoder 的 latent space 中训练 flow 或 diffusion model [36, 48]。

不过需要注意的是,在训练 diffusion model 之前,还必须先训练 autoencoder。并且,最终性能同样依赖于 autoencoder 是否能够很好地将 image 压缩到 latent space 以及恢复出美观的 image。


我们会在 Section D 中,进一步讨论 VAE。

6.3 Case Study: Stable Diffusion 3 and Meta Movie Gen

我们通过简要分析两个大规模生成模型来结束本节:用于 image generation 的 Stable Diffusion 3;用于 video generation 的 Meta Movie Gen [14, 33]。正如你将会看到的,这些模型使用了我们在本文中介绍的技术,同时还加入了额外的架构增强以便扩展模型规模和适配结构更加复杂的调节方式,例如基于文本的输入。

6.3.1 Stable Diffusion 3

Stable Diffusion 是一系列 state-of-the-art 的图像生成模型。这些模型是最早将 large-scale latent diffusion models 用于图像生成的工作之一。如果你还没有体验过,我们非常推荐你亲自在线测试一下:https://stability.ai/news/stable-diffusion-3

Stable Diffusion 3 使用了与本文研究相同的 conditional flow matching objective(见 Algorithm 4)。注意,在他们的工作中,他们使用了一种不同的约定来对噪声施加条件。但这只是符号上的差异,算法本身是相同的。正如其论文中所描述的那样,他们广泛测试了各种 flow 与 diffusion 替代方案,并发现 flow matching 的效果最好。在训练方面,它使用了 classifier-free guidance training(通过随机丢弃 class labels),与上文描述一致。此外,Stable Diffusion 3 遵循了 Section 6.1 中介绍的方法:在一个预训练 autoencoder 的 latent space 中进行训练。训练一个高质量 autoencoder 是早期 stable diffusion 工作中的重要贡献之一。

为了增强 text conditioning,Stable Diffusion 3 同时使用了 3 种不同类型的 text embeddings,包括 CLIP embeddings;以及由 Google T5-XXL [35] encoder 的一个 pretrained 实例产生的 sequential outputs;并采用了与文献 [3, 39] 类似的方法。其中,CLIP embeddings 提供的是一种粗略、整体性的文本表示;而 T5 embeddings 则提供了更加细粒度的上下文信息,从而使模型能够关注 conditioning text 中的特定元素。

为了适配这种 sequential context embeddings,作者进一步提出扩展 diffusion transformer,使其不仅能够 attention 到 image patches,还能够 attention 到 text embeddings。这相当于将 DiT 最初面向 class-based conditioning 的机制,扩展到了 sequential context embeddings。这种修改后的 DiT 被称为 multi-modal DiT(MM-DiT),并在 Figure 16 中进行了展示。

他们最终最大的模型拥有 8 billion parameters。在 sampling 时,他们使用 50 个 steps(即网络需要被 evaluate 50 次);采用 Euler simulation scheme;并且 classifier-free guidance weight 位于 2.0–5.0 之间。

Figure 16:文献 [14] 中提出的多模态 diffusion transformer(MM-DiT)的架构。该图同样引自 [14]

6.3.2 Meta Movie Gen Video

接下来,我们讨论 Meta 的视频生成器 Movie Gen Video https://ai.meta.com/research/movie-gen/。由于这里的数据不再是 images,而是 videos,因此数据 x x x 位于空间 R T × C × H × W \mathbb{R}^{T \times C \times H \times W} RT×C×H×W ,其中 T T T 表示新增的 temporal dimension(即 frame 的数量)。正如我们将看到的那样,许多在视频场景中的设计选择,本质上都可以看作是将已有的图像生成技术(例如 autoencoders、diffusion transformers 等)扩展到额外 temporal dimension 的结果。

Movie Gen Video 使用 conditional flow matching objective,并采用如下直线调度器 α t = t , σ t = 1 − t \alpha_t = t, \sigma_t = 1 - t αt=t,σt=1t 。与 Stable Diffusion 3 类似,Movie Gen Video 也运行在冻结的、预训练 autoencoder 的 latent space 中。需要注意的是,用于降低内存消耗的 autoencoder 在视频场景中甚至比图像场景更加重要。这也是为什么当前大多数 video generators 在生成视频长度方面仍然比较受限。

具体而言,作者提出通过引入 temporal autoencoder(TAE)来处理新增的时间维度。该模型将原始视频 x t ′ ∈ R T ′ × 3 × H × W x_t' \in \mathbb{R}^{T' \times 3 \times H \times W} xtRT×3×H×W 映射到 latent x t ∈ R T × C × H × W x_t \in \mathbb{R}^{T \times C \times H \times W} xtRT×C×H×W ,其中 T ′ T = H ′ H = W ′ W = 8 \frac{T'}{T}=\frac{H'}{H}=\frac{W'}{W}=8 TT=HH=WW=8 [33]

为了适配长视频,作者提出了一种 temporal tiling procedure,将视频切分为多个 pieces,每个 piece 单独进行 encoder,随后再将 latents 拼接起来。模型本身(即 u t θ ( x t ) u_t^\theta(x_t) utθ(xt))采用了类似 DiT 的 backbone,其中 x t x_t xt 会同时沿着时间维度与空间维度进行 patchification。随后,这些 image patches 会被送入 transformer,并同时使用 image patches 之间的 self-attention 与语言模型 embeddings 的 cross-attention,其方式与 Stable Diffusion 3 中使用的 MM-DiT 类似。

在 text conditioning 方面,Movie Gen Video 使用了三种 text embeddings:

  • UL2 embeddings:用于更加细粒度的 text-based reasoning [47]
  • ByT5 embeddings:用于关注字符级细节(例如 prompt 中明确要求出现某些特定文本)[50]
  • MetaCLIP embeddings:在 shared text-image embedding space 中训练得到 [24, 33]。

其最终最大的模型拥有 30 billion parameters。如果希望获得更加详细且完整的介绍,建议直接阅读 Movie Gen 技术报告 [33]

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐