环境声明

  • Python版本:Python 3.10+
  • PyTorch版本:PyTorch 2.0+
  • 开发工具:PyCharm 或 VS Code
  • 操作系统:Windows / macOS / Linux (通用)
  • 可选依赖:numpy, matplotlib, torchvision

学习目标

本章将带你深入探索深度生成模型的另外两大重要范式:归一化流(Normalizing Flows)和能量模型(Energy-Based Models)。通过本章学习,你将掌握:

  1. 理解归一化流的数学原理和可逆变换机制
  2. 掌握RealNVP、Glow等经典流模型的架构设计
  3. 理解能量模型的核心思想和玻尔兹曼分布
  4. 学习得分匹配(Score Matching)和去噪得分匹配(DSM)的理论基础
  5. 了解2024-2025年最新进展:Flow Matching和Consistency Models
  6. 能够使用PyTorch实现基础的流模型

摘要:除了VAE和GAN之外,归一化流通过可逆神经网络实现了精确的密度估计;能量模型通过定义能量函数来刻画数据分布;而得分匹配方法则为扩散模型奠定了理论基础。近年来,Flow Matching和Consistency Models作为新兴的生成范式,在生成质量和采样效率上取得了突破性进展。


1. 归一化流(Normalizing Flows)原理

1.1 什么是归一化流

归一化流是一种生成模型,它通过一系列可逆变换将简单的先验分布(如标准高斯分布)转换为复杂的数据分布。与VAE和GAN不同,归一化流可以精确计算数据的对数似然,这是其最大的优势。

核心思想:想象你有一瓶清水(简单分布),通过一系列精心设计的操作(可逆变换),你可以将它变成任意复杂的形状(数据分布),而且这个过程是完全可逆的。

1.2 数学基础:变量变换公式

归一化流的理论基础是概率论中的变量变换公式。假设我们有一个可逆变换 z = f(x),其中 z 服从简单的先验分布 p_Z(z),那么 x 的分布可以通过以下公式计算:

pX(x)=pZ(f(x))⋅∣det⁡∂f(x)∂x∣p_X(x) = p_Z(f(x)) \cdot \left| \det \frac{\partial f(x)}{\partial x} \right|pX(x)=pZ(f(x)) detxf(x)

其中:

  • pX(x)p_X(x)pX(x) 是数据分布
  • pZ(z)p_Z(z)pZ(z) 是先验分布(通常是标准高斯分布)
  • ∣det⁡∂f(x)∂x∣\left| \det \frac{\partial f(x)}{\partial x} \right| detxf(x) 是变换的雅可比行列式的绝对值

取对数后,我们得到对数似然:

log⁡pX(x)=log⁡pZ(f(x))+log⁡∣det⁡∂f(x)∂x∣\log p_X(x) = \log p_Z(f(x)) + \log \left| \det \frac{\partial f(x)}{\partial x} \right|logpX(x)=logpZ(f(x))+log detxf(x)

1.3 可逆变换的构建

为了构建可逆神经网络,我们需要设计特殊的网络结构。关键要求是:

  1. 可逆性:变换必须是双射(一一对应)
  2. 雅可比行列式易计算:行列式的计算复杂度不能太高
  3. 表达能力:变换需要足够复杂以建模真实数据分布

常用的可逆变换包括:

  • 仿射耦合层(Affine Coupling Layer)
  • 加性耦合层(Additive Coupling Layer)
  • 可逆1x1卷积
  • 激活归一化(ActNorm)

1.4 仿射耦合层详解

仿射耦合层是RealNVP等流模型的核心组件。其工作原理如下:

给定输入 x∈RDx \in \mathbb{R}^DxRD,将其分成两部分:

  • x1:dx_{1:d}x1:d:前d维(保持不变)
  • xd+1:Dx_{d+1:D}xd+1:D:剩余维度(进行变换)

变换公式为:

y1:d=x1:dy_{1:d} = x_{1:d}y1:d=x1:d
yd+1:D=xd+1:D⊙exp⁡(s(x1:d))+t(x1:d)y_{d+1:D} = x_{d+1:D} \odot \exp(s(x_{1:d})) + t(x_{1:d})yd+1:D=xd+1:Dexp(s(x1:d))+t(x1:d)

其中:

  • sssttt 是任意的神经网络(通常是MLP或CNN)
  • ⊙\odot 表示逐元素乘法
  • exp⁡\expexp 确保尺度因子为正

逆变换为:

x1:d=y1:dx_{1:d} = y_{1:d}x1:d=y1:d
xd+1:D=(yd+1:D−t(y1:d))⊙exp⁡(−s(y1:d))x_{d+1:D} = (y_{d+1:D} - t(y_{1:d})) \odot \exp(-s(y_{1:d}))xd+1:D=(yd+1:Dt(y1:d))exp(s(y1:d))

雅可比行列式:由于 y1:d=x1:dy_{1:d} = x_{1:d}y1:d=x1:d,雅可比矩阵是三角矩阵,其行列式为对角元素的乘积:

det⁡∂y∂x=∏i=d+1Dexp⁡(s(x1:d)i)=exp⁡(∑i=d+1Ds(x1:d)i)\det \frac{\partial y}{\partial x} = \prod_{i=d+1}^{D} \exp(s(x_{1:d})_i) = \exp\left(\sum_{i=d+1}^{D} s(x_{1:d})_i\right)detxy=i=d+1Dexp(s(x1:d)i)=exp(i=d+1Ds(x1:d)i)


2. 经典流模型架构

2.1 RealNVP(Real-valued Non-Volume Preserving)

RealNVP是2017年提出的开创性工作,首次将仿射耦合层应用于图像生成。

架构特点

  • 使用仿射耦合层作为基本构建块
  • 采用通道维度分割(channel-wise masking)
  • 引入多尺度架构(multi-scale architecture)

多尺度架构:为了降低计算复杂度,RealNVP在每一层后将部分维度直接输出,不再参与后续变换。这类似于小波分解,可以在多个尺度上建模数据分布。

2.2 Glow

Glow是OpenAI在2018年提出的改进版本,在RealNVP的基础上引入了两个重要创新:

1. 可逆1x1卷积

在RealNVP中,通道之间的信息混合需要通过耦合层逐步完成。Glow引入了可逆1x1卷积来加速这一过程:

yi,j=W⋅xi,jy_{i,j} = W \cdot x_{i,j}yi,j=Wxi,j

其中 WWW 是可学习的可逆矩阵。雅可比行列式为:

log⁡∣det⁡(J)∣=h⋅w⋅log⁡∣det⁡(W)∣\log |\det(J)| = h \cdot w \cdot \log |\det(W)|logdet(J)=hwlogdet(W)

2. 激活归一化(ActNorm)

替代了RealNVP中的批归一化,ActNorm对每个通道进行仿射变换:

y=s⊙x+by = s \odot x + by=sx+b

其中 sssbbb 是可学习的参数,通过数据初始化使得每个通道的均值为0、方差为1。

2.3 Flow++

Flow++是2019年提出的进一步改进,主要贡献包括:

1. 自回归耦合层

使用自回归网络替代独立的MLP/CNN,可以更好地捕捉维度间的依赖关系。

2. 连续混合分布

不再使用简单的高斯分布作为条件分布,而是使用更复杂的混合分布(如逻辑混合分布)。

3. 去量化处理

图像数据是离散的(0-255整数),Flow++引入了均匀去量化(uniform dequantization)和变分去量化(variational dequantization)来处理这一问题。


3. 能量模型与玻尔兹曼机

3.1 能量模型的基本思想

能量模型(Energy-Based Models, EBM)是一类通过定义能量函数来刻画数据分布的生成模型。其核心思想是:数据点的能量越低,其出现的概率越高。

玻尔兹曼分布

p(x)=exp⁡(−E(x))Zp(x) = \frac{\exp(-E(x))}{Z}p(x)=Zexp(E(x))

其中:

  • E(x)E(x)E(x) 是能量函数(通常由神经网络参数化)
  • Z=∫exp⁡(−E(x))dxZ = \int \exp(-E(x)) dxZ=exp(E(x))dx 是配分函数(归一化常数)
  • 温度参数被设为1(可吸收进能量函数)

3.2 受限玻尔兹曼机(RBM)

RBM是最经典的能量模型之一,由Hinton等人于1986年提出。它包含两层:

  • 可见层(visible layer):vvv,表示观测数据
  • 隐藏层(hidden layer):hhh,表示潜在特征

能量函数

E(v,h)=−vTWh−bTv−cThE(v, h) = -v^T W h - b^T v - c^T hE(v,h)=vTWhbTvcTh

其中 WWWbbbccc 是模型参数。

联合分布

p(v,h)=exp⁡(−E(v,h))Zp(v, h) = \frac{\exp(-E(v, h))}{Z}p(v,h)=Zexp(E(v,h))

边缘分布

p(v)=∑hp(v,h)=∑hexp⁡(−E(v,h))Zp(v) = \sum_h p(v, h) = \frac{\sum_h \exp(-E(v, h))}{Z}p(v)=hp(v,h)=Zhexp(E(v,h))

3.3 深度玻尔兹曼机(DBM)

DBM是RBM的深层扩展,包含多个隐藏层:

E(v,h(1),h(2),...)=−vTW(1)h(1)−h(1)TW(2)h(2)−...E(v, h^{(1)}, h^{(2)}, ...) = -v^T W^{(1)} h^{(1)} - h^{(1)T} W^{(2)} h^{(2)} - ...E(v,h(1),h(2),...)=vTW(1)h(1)h(1)TW(2)h(2)...

DBM可以学习更复杂的特征层次,但训练和采样也更加困难。

3.4 能量模型的训练挑战

EBM的训练面临两个主要挑战:

1. 配分函数的计算

ZZZ 涉及高维积分,难以精确计算。常用近似方法包括:

  • 对比散度(Contrastive Divergence, CD)
  • 持续性对比散度(Persistent CD)
  • 噪声对比估计(Noise Contrastive Estimation, NCE)

2. 采样困难

p(x)∝exp⁡(−E(x))p(x) \propto \exp(-E(x))p(x)exp(E(x)) 采样需要MCMC方法,如:

  • 吉布斯采样(Gibbs Sampling)
  • 朗之万动力学(Langevin Dynamics)

4. 得分匹配与去噪得分匹配

4.1 得分匹配(Score Matching)

得分匹配是Hyvarinen于2005年提出的一种训练方法,它避免了直接计算配分函数。

得分函数(Score Function)

s(x)=∇xlog⁡p(x)s(x) = \nabla_x \log p(x)s(x)=xlogp(x)

得分函数指向对数概率密度增长最快的方向,与配分函数无关:

∇xlog⁡p(x)=−∇xE(x)−∇xlog⁡Z=−∇xE(x)\nabla_x \log p(x) = -\nabla_x E(x) - \nabla_x \log Z = -\nabla_x E(x)xlogp(x)=xE(x)xlogZ=xE(x)

得分匹配目标

我们希望模型得分 sθ(x)s_\theta(x)sθ(x) 接近真实数据得分 sdata(x)s_{data}(x)sdata(x)。由于我们不知道真实数据分布,Hyvarinen提出了以下目标函数:

JSM(θ)=12Epdata[∥sθ(x)−∇xlog⁡pdata(x)∥2]J_{SM}(\theta) = \frac{1}{2} \mathbb{E}_{p_{data}} \left[ \| s_\theta(x) - \nabla_x \log p_{data}(x) \|^2 \right]JSM(θ)=21Epdata[sθ(x)xlogpdata(x)2]

经过推导,可以得到不需要真实得分的等价形式:

JSM(θ)=Epdata[tr(∇xsθ(x))+12∥sθ(x)∥2]J_{SM}(\theta) = \mathbb{E}_{p_{data}} \left[ \text{tr}(\nabla_x s_\theta(x)) + \frac{1}{2} \| s_\theta(x) \|^2 \right]JSM(θ)=Epdata[tr(xsθ(x))+21sθ(x)2]

其中 tr(∇xsθ(x))\text{tr}(\nabla_x s_\theta(x))tr(xsθ(x)) 是得分函数雅可比矩阵的迹。

4.2 去噪得分匹配(Denoising Score Matching, DSM)

原始得分匹配需要计算Hessian矩阵的迹,计算成本较高。Vincent于2011年提出了去噪得分匹配作为替代方案。

核心思想

对数据添加少量噪声,然后训练模型估计噪声的分布。具体来说,设 qσ(x~∣x)q_\sigma(\tilde{x} | x)qσ(x~x) 是给定 xxx 时噪声数据 x~\tilde{x}x~ 的条件分布(通常是高斯分布),则DSM目标为:

JDSM(θ)=12Epdata(x)Eqσ(x~∣x)[∥sθ(x~)−∇x~log⁡qσ(x~∣x)∥2]J_{DSM}(\theta) = \frac{1}{2} \mathbb{E}_{p_{data}(x)} \mathbb{E}_{q_\sigma(\tilde{x}|x)} \left[ \| s_\theta(\tilde{x}) - \nabla_{\tilde{x}} \log q_\sigma(\tilde{x}|x) \|^2 \right]JDSM(θ)=21Epdata(x)Eqσ(x~x)[sθ(x~)x~logqσ(x~x)2]

对于高斯噪声 qσ(x~∣x)=N(x~;x,σ2I)q_\sigma(\tilde{x}|x) = \mathcal{N}(\tilde{x}; x, \sigma^2 I)qσ(x~x)=N(x~;x,σ2I),有:

∇x~log⁡qσ(x~∣x)=−x~−xσ2\nabla_{\tilde{x}} \log q_\sigma(\tilde{x}|x) = -\frac{\tilde{x} - x}{\sigma^2}x~logqσ(x~x)=σ2x~x

因此,DSM目标简化为:

JDSM(θ)=12E[∥sθ(x~)+x~−xσ2∥2]J_{DSM}(\theta) = \frac{1}{2} \mathbb{E} \left[ \left\| s_\theta(\tilde{x}) + \frac{\tilde{x} - x}{\sigma^2} \right\|^2 \right]JDSM(θ)=21E[ sθ(x~)+σ2x~x 2]

这意味着模型 sθ(x~)s_\theta(\tilde{x})sθ(x~) 实际上在学习估计噪声的方向和大小。

4.3 噪声条件得分网络(NCSN)

2019年,Song和Ermon提出了NCSN,使用多尺度噪声训练得分网络:

L(θ)=1L∑i=1Lλ(σi)Epdata(x)Eqσi(x~∣x)[∥sθ(x~,σi)−∇x~log⁡qσi(x~∣x)∥2]L(\theta) = \frac{1}{L} \sum_{i=1}^{L} \lambda(\sigma_i) \mathbb{E}_{p_{data}(x)} \mathbb{E}_{q_{\sigma_i}(\tilde{x}|x)} \left[ \| s_\theta(\tilde{x}, \sigma_i) - \nabla_{\tilde{x}} \log q_{\sigma_i}(\tilde{x}|x) \|^2 \right]L(θ)=L1i=1Lλ(σi)Epdata(x)Eqσi(x~x)[sθ(x~,σi)x~logqσi(x~x)2]

训练完成后,可以使用退火朗之万动力学(annealed Langevin dynamics)进行采样。


5. 生成模型对比

下表对比了VAE、GAN、流模型、扩散模型和能量模型的主要特点:

特性 VAE GAN 流模型 扩散模型 能量模型
训练目标 ELBO 对抗损失 对数似然 去噪得分匹配 得分匹配/NCE
隐变量 无(可逆) 有(时间序列)
精确似然 下界 有(变分下界) 难计算
采样速度 慢(多步) 慢(MCMC)
模式覆盖 差(模式崩溃) 极好 取决于采样
训练稳定性 稳定 不稳定 稳定 稳定 较难稳定
图像质量 中等 极高 中等
主要应用 表示学习 图像生成 密度估计 图像生成 密度估计

一句话总结

  • VAE:学习数据的压缩表示,适合表示学习和半监督学习
  • GAN:生成逼真样本,但训练不稳定且难以评估
  • 流模型:精确密度估计,适合需要概率建模的任务
  • 扩散模型:当前图像生成的SOTA,质量最高但采样慢
  • 能量模型:理论优雅但训练困难,是理解其他模型的基础

6. 2024-2025年最新进展

6.1 Flow Matching

Flow Matching是2022-2023年兴起的新型生成建模框架,在ICML 2025成为热门主题。它与扩散模型密切相关,但提供了更简洁的数学形式和更高效的训练。

核心思想

Flow Matching直接学习一个向量场 vt(x)v_t(x)vt(x),它定义了从先验分布 p0p_0p0 到数据分布 p1p_1p1 的概率流。这个向量场满足连续性方程:

∂pt∂t+∇⋅(ptvt)=0\frac{\partial p_t}{\partial t} + \nabla \cdot (p_t v_t) = 0tpt+(ptvt)=0

条件流匹配

给定数据点 x1x_1x1,定义条件概率路径 pt(x∣x1)p_t(x|x_1)pt(xx1),通常选择高斯:

pt(x∣x1)=N(x;t⋅x1,(1−t)2⋅I)p_t(x|x_1) = \mathcal{N}(x; t \cdot x_1, (1-t)^2 \cdot I)pt(xx1)=N(x;tx1,(1t)2I)

对应的条件向量场为:

vt(x∣x1)=x1−x1−tv_t(x|x_1) = \frac{x_1 - x}{1-t}vt(xx1)=1tx1x

训练目标

LFM(θ)=Et,x1,x∼pt(⋅∣x1)[∥vθ(t,x)−vt(x∣x1)∥2]L_{FM}(\theta) = \mathbb{E}_{t, x_1, x \sim p_t(\cdot|x_1)} \left[ \| v_\theta(t, x) - v_t(x|x_1) \|^2 \right]LFM(θ)=Et,x1,xpt(x1)[vθ(t,x)vt(xx1)2]

与扩散模型的关系

Flow Matching和扩散模型本质上是同一枚硬币的两面。扩散模型学习得分函数,Flow Matching学习向量场,两者通过以下关系联系:

vt(x)=12∇log⁡pt(x)v_t(x) = \frac{1}{2} \nabla \log p_t(x)vt(x)=21logpt(x)

优势

  • 训练更稳定,不需要复杂的噪声调度
  • 采样可以使用更高效的ODE求解器
  • 理论框架更统一,便于扩展

6.2 Consistency Models

Consistency Models是OpenAI于2023年提出的扩散模型加速方法,可以实现单步或少步生成。

核心思想

一致性模型学习一个函数 fff,它将任意时间步 ttt 的噪声数据 xtx_txt 直接映射到时间步0的干净数据:

f:(xt,t)↦x0f: (x_t, t) \mapsto x_0f:(xt,t)x0

并且满足一致性属性:对于同一轨迹上的任意两点,映射结果应该相同:

f(xt,t)=f(xt′,t′)当 xt,xt′ 属于同一PF ODE轨迹f(x_t, t) = f(x_{t'}, t') \quad \text{当 } x_t, x_{t'} \text{ 属于同一PF ODE轨迹}f(xt,t)=f(xt,t) xt,xt 属于同一PF ODE轨迹

训练目标

使用自举(bootstrapping)方法训练:

LCM(θ)=Ex,t[λ(t)⋅d(fθ(xt,t),fθ−(x^t−Δt,t−Δt))]L_{CM}(\theta) = \mathbb{E}_{x, t} \left[ \lambda(t) \cdot d(f_\theta(x_t, t), f_{\theta^-}(\hat{x}_{t-\Delta t}, t-\Delta t)) \right]LCM(θ)=Ex,t[λ(t)d(fθ(xt,t),fθ(x^tΔt,tΔt))]

其中:

  • fθ−f_{\theta^-}fθ 是目标网络(EMA更新)
  • x^t−Δt\hat{x}_{t-\Delta t}x^tΔt 是通过ODE求解器得到的一步预测
  • ddd 是距离度量(如L2距离或LPIPS)

Latent Consistency Models (LCM)

将一致性模型应用到潜空间(如Stable Diffusion的VAE编码空间),可以进一步加速生成:

  • 生图速度提升5-10倍
  • 实现秒级甚至实时生成
  • 在保持质量的同时大幅降低计算成本

应用

  • 实时图像生成
  • 视频生成加速
  • 交互式创作工具

7. 流模型PyTorch实现

下面是一个完整的RealNVP风格流模型的PyTorch实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# 设置随机种子
torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class CouplingLayer(nn.Module):
    """
    仿射耦合层(Affine Coupling Layer)
    这是RealNVP的核心组件
    """
    def __init__(self, input_dim, hidden_dim=256, mask_type='channel'):
        super().__init__()
        self.input_dim = input_dim
        self.mask_type = mask_type
        
        # 分割维度:一半保持不变,一半进行变换
        self.split_dim = input_dim // 2
        
        # 尺度网络 s(·)
        self.scale_net = nn.Sequential(
            nn.Linear(self.split_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim - self.split_dim),
            nn.Tanh()  # 限制尺度范围,增加稳定性
        )
        
        # 平移网络 t(·)
        self.translate_net = nn.Sequential(
            nn.Linear(self.split_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim - self.split_dim)
        )
    
    def forward(self, x, reverse=False):
        """
        前向/逆向变换
        
        Args:
            x: 输入张量,形状 (batch_size, input_dim)
            reverse: 是否进行逆变换
        
        Returns:
            y: 变换后的张量
            log_det: 对数雅可比行列式
        """
        # 分割输入
        x1, x2 = x[:, :self.split_dim], x[:, self.split_dim:]
        
        if not reverse:
            # 前向变换: y1 = x1, y2 = x2 * exp(s(x1)) + t(x1)
            s = self.scale_net(x1)
            t = self.translate_net(x1)
            y1 = x1
            y2 = x2 * torch.exp(s) + t
            log_det = s.sum(dim=1)  # 对数雅可比行列式
        else:
            # 逆变换: x1 = y1, x2 = (y2 - t(y1)) * exp(-s(y1))
            s = self.scale_net(x1)
            t = self.translate_net(x1)
            y1 = x1
            y2 = (x2 - t) * torch.exp(-s)
            log_det = -s.sum(dim=1)
        
        y = torch.cat([y1, y2], dim=1)
        return y, log_det


class PermutationLayer(nn.Module):
    """
    置换层:用于在耦合层之间混合维度信息
    使用固定的随机置换矩阵
    """
    def __init__(self, input_dim):
        super().__init__()
        # 创建随机置换
        permutation = torch.randperm(input_dim)
        self.register_buffer('permutation', permutation)
        # 逆置换
        self.register_buffer('inverse_permutation', torch.argsort(permutation))
    
    def forward(self, x, reverse=False):
        if not reverse:
            return x[:, self.permutation], torch.zeros(x.size(0), device=x.device)
        else:
            return x[:, self.inverse_permutation], torch.zeros(x.size(0), device=x.device)


class NormalizingFlow(nn.Module):
    """
    归一化流模型
    堆叠多个耦合层和置换层
    """
    def __init__(self, input_dim, num_flows=4, hidden_dim=256):
        super().__init__()
        self.input_dim = input_dim
        self.num_flows = num_flows
        
        # 构建流层
        self.flows = nn.ModuleList()
        for i in range(num_flows):
            self.flows.append(CouplingLayer(input_dim, hidden_dim))
            if i < num_flows - 1:  # 最后一层不需要置换
                self.flows.append(PermutationLayer(input_dim))
    
    def forward(self, x, reverse=False):
        """
        通过整个流进行前向/逆向变换
        
        Args:
            x: 输入张量
            reverse: 是否进行逆变换(采样时使用)
        
        Returns:
            z: 变换后的潜变量(或重构数据)
            log_det_total: 总的对数雅可比行列式
        """
        log_det_total = torch.zeros(x.size(0), device=x.device)
        
        if not reverse:
            # 编码: x -> z
            for flow in self.flows:
                x, log_det = flow(x, reverse=False)
                log_det_total += log_det
            return x, log_det_total
        else:
            # 解码/采样: z -> x
            for flow in reversed(self.flows):
                x, log_det = flow(x, reverse=True)
                log_det_total += log_det
            return x, log_det_total
    
    def log_prob(self, x):
        """
        计算数据的对数概率
        
        Args:
            x: 数据样本
        
        Returns:
            log_p: 对数概率
        """
        z, log_det = self.forward(x, reverse=False)
        
        # 先验分布: 标准高斯
        log_pz = -0.5 * (z ** 2 + np.log(2 * np.pi)).sum(dim=1)
        
        # 通过变量变换公式
        log_px = log_pz + log_det
        return log_px
    
    def sample(self, num_samples):
        """
        从模型中采样
        
        Args:
            num_samples: 采样数量
        
        Returns:
            samples: 生成的样本
        """
        # 从先验分布采样
        z = torch.randn(num_samples, self.input_dim, device=device)
        
        # 通过逆变换生成数据
        x, _ = self.forward(z, reverse=True)
        return x


class MNISTFlowModel(nn.Module):
    """
    针对MNIST数据集的流模型包装器
    """
    def __init__(self, image_size=28*28, num_flows=8, hidden_dim=512):
        super().__init__()
        self.flow = NormalizingFlow(image_size, num_flows, hidden_dim)
    
    def forward(self, x):
        """
        训练前向传播
        
        Args:
            x: 图像张量,形状 (batch_size, 1, 28, 28)
        
        Returns:
            loss: 负对数似然
        """
        # 展平图像
        batch_size = x.size(0)
        x_flat = x.view(batch_size, -1)
        
        # 添加微小噪声进行去量化(将离散数据转为连续)
        x_flat = x_flat + torch.rand_like(x_flat) / 256.0
        
        # 计算对数概率
        log_prob = self.flow.log_prob(x_flat)
        
        # 返回负对数似然作为损失
        return -log_prob.mean()
    
    def sample(self, num_samples):
        """
        生成样本
        """
        with torch.no_grad():
            samples = self.flow.sample(num_samples)
            # 重塑为图像形状
            samples = samples.view(num_samples, 1, 28, 28)
            # 裁剪到有效范围
            samples = torch.clamp(samples, 0, 1)
            return samples


def train_epoch(model, dataloader, optimizer):
    """
    训练一个epoch
    """
    model.train()
    total_loss = 0
    num_batches = 0
    
    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.to(device)
        
        optimizer.zero_grad()
        loss = model(data)
        loss.backward()
        
        # 梯度裁剪,增加稳定性
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}')
    
    return total_loss / num_batches


def visualize_samples(model, epoch, num_samples=16):
    """
    可视化生成的样本
    """
    model.eval()
    samples = model.sample(num_samples).cpu()
    
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))
    for i, ax in enumerate(axes.flat):
        ax.imshow(samples[i].squeeze(), cmap='gray')
        ax.axis('off')
    
    plt.suptitle(f'Generated Samples - Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'samples_epoch_{epoch}.png')
    plt.close()


def main():
    """
    主训练函数
    """
    # 超参数
    batch_size = 128
    epochs = 10
    learning_rate = 1e-4
    num_flows = 8
    hidden_dim = 512
    
    # 数据加载
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    
    train_dataset = datasets.MNIST(
        root='./data', 
        train=True, 
        download=True, 
        transform=transform
    )
    
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True,
        num_workers=2
    )
    
    # 创建模型
    model = MNISTFlowModel(
        image_size=28*28,
        num_flows=num_flows,
        hidden_dim=hidden_dim
    ).to(device)
    
    print(f"模型参数数量: {sum(p.numel() for p in model.parameters()):,}")
    
    # 优化器
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # 学习率调度器
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    
    # 训练循环
    for epoch in range(1, epochs + 1):
        print(f'\n=== Epoch {epoch}/{epochs} ===')
        avg_loss = train_epoch(model, train_loader, optimizer)
        print(f'Average Loss: {avg_loss:.4f}')
        
        # 更新学习率
        scheduler.step()
        
        # 每几个epoch可视化一次
        if epoch % 2 == 0:
            visualize_samples(model, epoch)
    
    # 最终采样
    print('\n生成最终样本...')
    visualize_samples(model, epochs, num_samples=16)
    
    # 保存模型
    torch.save(model.state_dict(), 'flow_model_mnist.pth')
    print('模型已保存')


if __name__ == '__main__':
    main()

代码说明

1. CouplingLayer(耦合层)

  • 实现仿射耦合变换
  • 包含尺度网络和平移网络
  • 支持前向和逆向变换

2. PermutationLayer(置换层)

  • 在耦合层之间混合维度信息
  • 使用固定的随机置换

3. NormalizingFlow(归一化流)

  • 堆叠多个耦合层和置换层
  • 实现编码(x -> z)和解码/采样(z -> x)
  • 计算精确的对数似然

4. MNISTFlowModel(MNIST专用模型)

  • 包装器类,处理图像展平和去量化
  • 提供训练前向和采样接口

8. 避坑小贴士

8.1 数值稳定性问题

问题:流模型中涉及指数运算(exp(s)),容易导致数值溢出或下溢。

解决方案

  • 在尺度网络输出使用Tanh激活,限制范围
  • 使用对数空间计算,避免直接exp
  • 梯度裁剪防止训练不稳定
# 推荐:限制尺度范围
s = torch.tanh(self.scale_net(x1))  # 限制在[-1, 1]

8.2 维度分割策略

问题:简单的前后分割可能导致信息混合不充分。

解决方案

  • 使用棋盘掩码(checkerboard masking)处理图像
  • 使用通道掩码(channel-wise masking)
  • 在层之间加入置换操作

8.3 去量化的重要性

问题:图像数据是离散的(0-255整数),但流模型假设连续分布,直接建模会导致退化分布。

解决方案

  • 添加均匀噪声进行去量化:x = x + u/256, u ~ Uniform(0, 1)
  • 使用变分去量化(Variational Dequantization)
# 简单去量化
x_dequantized = x + torch.rand_like(x) / 256.0

8.4 流模型深度与表达能力

问题:流层数太少导致表达能力不足,太多导致训练困难。

建议

  • 对于MNIST级别数据:4-8层
  • 对于CIFAR-10/ImageNet:需要更复杂的架构(如Glow)
  • 使用多尺度架构降低计算复杂度

8.5 能量模型的采样问题

问题:MCMC采样在EBM训练中可能无法收敛到真实分布。

解决方案

  • 使用持续性对比散度(PCD)
  • 增加采样步数
  • 考虑使用短期MCMC(Short-run MCMC)

9. 本章小结

核心知识点回顾

  1. 归一化流

    • 通过可逆变换将简单分布映射到复杂分布
    • 可以精确计算对数似然
    • 仿射耦合层是核心构建块
  2. 经典架构

    • RealNVP:引入仿射耦合层和多尺度架构
    • Glow:添加可逆1x1卷积和ActNorm
    • Flow++:使用自回归耦合和连续混合分布
  3. 能量模型

    • 通过能量函数定义概率分布
    • RBM和DBM是经典架构
    • 训练需要处理配分函数和采样问题
  4. 得分匹配

    • 避免计算配分函数的训练方法
    • DSM通过去噪目标简化训练
    • 是扩散模型的理论基础
  5. 最新进展

    • Flow Matching:统一框架,训练更稳定
    • Consistency Models:实现快速采样

关键公式总结

变量变换
pX(x)=pZ(f(x))⋅∣det⁡∂f(x)∂x∣p_X(x) = p_Z(f(x)) \cdot \left| \det \frac{\partial f(x)}{\partial x} \right|pX(x)=pZ(f(x)) detxf(x)

仿射耦合层
yd+1:D=xd+1:D⊙exp⁡(s(x1:d))+t(x1:d)y_{d+1:D} = x_{d+1:D} \odot \exp(s(x_{1:d})) + t(x_{1:d})yd+1:D=xd+1:Dexp(s(x1:d))+t(x1:d)

玻尔兹曼分布
p(x)=exp⁡(−E(x))Zp(x) = \frac{\exp(-E(x))}{Z}p(x)=Zexp(E(x))

去噪得分匹配
JDSM(θ)=12E[∥sθ(x~)+x~−xσ2∥2]J_{DSM}(\theta) = \frac{1}{2} \mathbb{E} \left[ \left\| s_\theta(\tilde{x}) + \frac{\tilde{x} - x}{\sigma^2} \right\|^2 \right]JDSM(θ)=21E[ sθ(x~)+σ2x~x 2]

下一步学习建议

  • 尝试在CIFAR-10上实现更复杂的流模型(如Glow)
  • 研究扩散模型与流模型的联系
  • 探索Consistency Models在实际项目中的应用
  • 阅读Flow Matching原始论文,理解其数学框架

本文是《深度学习精通》系列教程的第17章,专注于流模型与能量模型的理论与实践。如需转载,请注明出处。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐