目录

  1. 概述
  2. 历史发展与背景
  3. 核心思想
  4. DALL-E 1 架构详解
  5. DALL-E 2 架构详解
  6. DALL-E 3 架构详解
  7. 数学原理
  8. 关键技术深度解析
  9. 训练方法
  10. 应用场景
  11. 代码实现
  12. 对比分析
  13. 参考资料

1. 概述

DALL-E 是 OpenAI 开发的一系列文本到图像生成模型,能够根据自然语言描述生成高质量、多样化的图像。

命名由来

DALL-E 的名字来源于:

  • Salvador Dalí - 超现实主义画家
  • WALL·E - 皮克斯动画电影中的机器人

核心能力

文本输入: "一只戴着贝雷帽的柴犬在巴黎咖啡馆喝咖啡"
                    ↓
              [DALL-E模型]
                    ↓
图像输出: 生成符合描述的逼真图像

发展历程

版本 发布时间 核心技术 分辨率 主要突破
DALL-E 1 2021.1 dVAE + Transformer 256×256 首次大规模文本到图像生成
DALL-E 2 2022.4 CLIP + Diffusion 1024×1024 照片级真实感
DALL-E 3 2023.9 改进的Diffusion + ChatGPT 1024×1024 更好的文本理解

2. 历史发展与背景

2.1 图像生成的演进

早期方法 (2014-2018)

GAN (Generative Adversarial Networks)

生成器 G: 噪声 z → 图像 x
判别器 D: 图像 x → 真/假

训练目标:
min_G max_D V(D,G) = E[logD(x)] + E[log(1-D(G(z)))]

问题

  • 训练不稳定(模式崩溃)
  • 难以生成多样化图像
  • 文本控制能力弱
VAE (Variational Autoencoder)
编码器: x → μ, σ → z ~ N(μ, σ²)
解码器: z → x'

损失 = 重建损失 + KL散度

问题

  • 生成图像模糊
  • 表达能力有限

2.2 预训练语言模型的启示

GPT系列的成功

  • GPT-2 (2019): 1.5B参数
  • GPT-3 (2020): 175B参数

关键洞察

  • 大规模预训练 + 大数据 = 强大的生成能力
  • Transformer架构的可扩展性
  • 自回归建模的有效性

2.3 视觉-语言预训练

CLIP (2021.1) - Contrastive Language-Image Pre-training

  • 4亿图文对训练
  • 学习通用的视觉-语言表示
  • 零样本迁移能力

ALIGN (2021.2) - Google的类似工作

  • 18亿图文对
  • 更大规模的预训练

2.4 DALL-E的诞生背景

时机成熟

  1. 大规模图文数据集(LAION等)
  2. Transformer架构成熟
  3. 计算资源充足(数千GPU)
  4. CLIP提供强大的多模态表示

3. 核心思想

3.1 图像作为离散Token

核心创新:将图像表示为离散的token序列

原始图像 (256×256×3)
        ↓
    [dVAE编码器]
        ↓
离散token序列 (32×32 = 1024个token)
        ↓
    [Transformer处理]
        ↓
    [dVAE解码器]
        ↓
重建图像

为什么用离散token?

  1. 与文本token统一处理
  2. 便于自回归建模
  3. 离散空间更易优化
  4. 与语言模型架构兼容

3.2 两阶段训练范式

第一阶段:学习视觉词汇表

图像 → dVAE编码器 → 离散码本 → dVAE解码器 → 图像
         ↓
学习32×32=1024个视觉token

第二阶段:学习图文对应

文本token + 图像token → Transformer → 预测下一个图像token
         ↓
学习文本到图像的映射

3.3 自回归图像生成

DALL-E 1的核心思想

给定文本: "一只红色的苹果"
生成过程:
Step 1: 预测token[0,0] → 可能是"红色区域"
Step 2: 预测token[0,1] → 可能是"苹果轮廓"
Step 3: 预测token[0,2] → ...
...
Step 1024: 预测token[31,31] → 完成图像

与语言模型的类比

语言模型: P(w_t | w_1, ..., w_{t-1})
DALL-E:   P(v_t | text, v_1, ..., v_{t-1})

其中:
- w: 文本token
- v: 视觉token
- text: 文本描述

3.4 扩散模型的引入

DALL-E 2的关键转变

DALL-E 1: 自回归生成 (逐token预测)
DALL-E 2: 扩散模型 (从噪声逐步去噪)

扩散过程: x_0 → x_1 → ... → x_T (加噪)
去噪过程: x_T → x_{T-1} → ... → x_0 (生成)

为什么转向扩散?

  1. 生成质量更高
  2. 训练更稳定
  3. 采样更可控
  4. 与CLIP表示更好地结合

3.5 层次化生成

DALL-E 2的层次结构

文本 → [CLIP文本编码器] → 文本嵌入
                            ↓
                    [先验模型 P(z|y)] → 图像嵌入
                            ↓
                    [解码器 P(x|z)] → 图像

思想

  • 先生成"图像概念"(嵌入)
  • 再生成具体图像(像素)
  • 分层处理,降低复杂度

4. DALL-E 1 架构详解

4.1 整体架构

┌─────────────────────────────────────────────────────────────────┐
│                        DALL-E 1 整体架构                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  文本输入: "A red apple"                                         │
│       ↓                                                          │
│  ┌─────────────────────────────────────────┐                    │
│  │         BPE Tokenizer                    │                    │
│  │    "A red apple" → [47, 234, 1234]       │                    │
│  └─────────────────────────────────────────┘                    │
│       ↓                                                          │
│  ┌─────────────────────────────────────────┐                    │
│  │    Transformer Encoder (文本)            │                    │
│  │    输入: 256个文本token                   │                    │
│  │    输出: 文本嵌入                         │                    │
│  └─────────────────────────────────────────┘                    │
│       ↓                                                          │
│  ┌─────────────────────────────────────────┐                    │
│  │    Transformer Decoder (图像)            │                    │
│  │    输入: 文本嵌入 + 已生成的图像token      │                    │
│  │    输出: 下一个图像token的概率分布         │                    │
│  └─────────────────────────────────────────┘                    │
│       ↓                                                          │
│  ┌─────────────────────────────────────────┐                    │
│  │         dVAE Decoder                     │                    │
│  │    将1024个token解码为256×256图像         │                    │
│  └─────────────────────────────────────────┘                    │
│       ↓                                                          │
│  输出图像: 256×256 RGB                                           │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

4.2 dVAE (Discrete Variational Autoencoder)

4.2.1 架构细节

编码器:

输入: 图像 x ∈ R^{256×256×3}
    ↓
Conv2d(3, 64, 3, stride=2) → ReLU
    ↓  (128×128×64)
Conv2d(64, 128, 3, stride=2) → ReLU
    ↓  (64×64×128)
Conv2d(128, 256, 3, stride=2) → ReLU
    ↓  (32×32×256)
Conv2d(256, 512, 3, stride=1) → ReLU
    ↓  (32×32×512)
Residual Blocks × N
    ↓
Conv2d(512, 8192, 1)  # 8192 = 码本大小
    ↓
Argmax (离散化)
    ↓
输出: 离散token序列 z ∈ {0,1,...,8191}^{32×32}

码本 (Codebook):

码本大小: K = 8192
每个码本向量维度: d = 512

码本 = {e_1, e_2, ..., e_K} ∈ R^{K×512}

离散化过程:
对于编码器输出的每个位置的向量 h:
    找到最近的码本向量: k* = argmin_k ||h - e_k||²
    用码本向量替代: z = e_{k*}

解码器:

输入: 离散token z ∈ {0,...,8191}^{32×32}
    ↓
查码本: z → e_z ∈ R^{32×32×512}
    ↓
Residual Blocks × N
    ↓
ConvTranspose2d(512, 256, 3, stride=1) → ReLU
    ↓  (32×32×256)
ConvTranspose2d(256, 128, 4, stride=2) → ReLU
    ↓  (64×64×128)
ConvTranspose2d(128, 64, 4, stride=2) → ReLU
    ↓  (128×128×64)
ConvTranspose2d(64, 3, 4, stride=2) → Sigmoid
    ↓  (256×256×3)
输出: 重建图像 x' ∈ [0,1]^{256×256×3}
4.2.2 训练目标

dVAE损失函数:

L d V A E = L r e c o n + L c o m m i t \mathcal{L}_{dVAE} = \mathcal{L}_{recon} + \mathcal{L}_{commit} LdVAE=Lrecon+Lcommit

重建损失:
L r e c o n = ∣ ∣ x − D ( E ( x ) ) ∣ ∣ 2 \mathcal{L}_{recon} = ||x - D(E(x))||^2 Lrecon=∣∣xD(E(x))2

承诺损失 (让编码器输出靠近码本):
L c o m m i t = ∣ ∣ s g [ E ( x ) ] − z ∣ ∣ 2 \mathcal{L}_{commit} = ||sg[E(x)] - z||^2 Lcommit=∣∣sg[E(x)]z2

其中 s g [ ⋅ ] sg[\cdot] sg[] 是stop gradient,防止梯度流回编码器。

码本更新:
使用指数移动平均 (EMA) 更新码本:
e k n e w = α ⋅ e k o l d + ( 1 − α ) ⋅ h ˉ k e_k^{new} = \alpha \cdot e_k^{old} + (1-\alpha) \cdot \bar{h}_k eknew=αekold+(1α)hˉk

其中 h ˉ k \bar{h}_k hˉk 是分配到码本k的所有编码器输出的均值。

4.3 Transformer部分

4.3.1 输入构造

DALL-E 1的输入序列:

序列长度 = 文本token数 + 图像token数
         = 256 + 1024 = 1280

输入构造:
[文本token_1, 文本token_2, ..., 文本token_256,
 图像token_1, 图像token_2, ..., 图像token_1024]

嵌入方式:
- 文本token: 学习的嵌入 + 位置编码
- 图像token: 码本嵌入 + 位置编码
- 类型编码: 区分文本/图像token
4.3.2 Transformer架构

模型规模:

  • 参数量: 12B
  • 层数: 64
  • 隐藏维度: 6144
  • 注意力头数: 64
  • 每头维度: 96

架构细节:

输入: x ∈ R^{1280×6144}
    ↓
┌─────────────────────────────────────────┐
│         Transformer Block × 64          │
│  ┌───────────────────────────────────┐  │
│  │      Multi-Head Self-Attention     │  │
│  │  Q = xW_Q, K = xW_K, V = xW_V    │  │
│  │  Attention = softmax(QK^T/√d)V    │  │
│  │  输出 = Concat(heads)W_O          │  │
│  └───────────────────────────────────┘  │
│              ↓ + Residual               │
│           LayerNorm                     │
│              ↓                          │
│  ┌───────────────────────────────────┐  │
│  │         Feed-Forward Network       │  │
│  │  FFN = Linear → GELU → Linear     │  │
│  │  隐藏层维度: 6144 × 4 = 24576     │  │
│  └───────────────────────────────────┘  │
│              ↓ + Residual               │
│           LayerNorm                     │
└─────────────────────────────────────────┘
    ↓
输出: y ∈ R^{1280×6144}
4.3.3 因果注意力掩码

问题: 图像token应该能"看到"所有文本token,但只能看到之前的图像token

解决方案: 稀疏注意力掩码

注意力掩码矩阵 (1280×1280):

           文本token        图像token
          [0,1,...,255]    [256,...,1279]
文本 [0]   [1 1 1 ... 1]   [1 1 1 ... 1]
     [1]   [1 1 1 ... 1]   [1 1 1 ... 1]
    ...    [1 1 1 ... 1]   [1 1 1 ... 1]
   [255]   [1 1 1 ... 1]   [1 1 1 ... 1]

图像 [256]  [1 1 1 ... 1]   [1 0 0 ... 0]
     [257]  [1 1 1 ... 1]   [1 1 0 ... 0]
    ...    [1 1 1 ... 1]   [1 1 1 ... 0]
   [1279]  [1 1 1 ... 1]   [1 1 1 ... 1]

1 = 允许注意力
0 = 禁止注意力

实现:

def create_dalle_mask(text_len=256, image_len=1024):
    total_len = text_len + image_len
    mask = torch.zeros(total_len, total_len)

    # 文本token可以看到所有文本token
    mask[:text_len, :text_len] = 1

    # 图像token可以看到所有文本token
    mask[text_len:, :text_len] = 1

    # 图像token可以看到之前的图像token(因果)
    for i in range(image_len):
        mask[text_len + i, text_len:text_len + i + 1] = 1

    return mask

4.4 训练过程

4.4.1 第一阶段:训练dVAE
数据集: 约6亿图文对
训练时间: 约2周 (256 V100 GPUs)
目标: 学习图像的离散表示

训练循环:
for image, text in dataloader:
    # 编码
    logits = encoder(image)  # [batch, 32, 32, 8192]
    z = gumbel_softmax(logits, tau=1.0)  # 可微分离散化
    z_q = z @ codebook  # 查码本

    # 解码
    reconstruction = decoder(z_q)

    # 计算损失
    loss = reconstruction_loss + commitment_loss

    # 更新
    loss.backward()
    optimizer.step()
    update_codebook_ema(z, codebook)
4.4.2 第二阶段:训练Transformer
数据集: 约2.5亿图文对(经过滤)
训练时间: 约2周 (1024 A100 GPUs)
目标: 学习文本到图像token的映射

训练循环:
for image, text in dataloader:
    # tokenize
    text_tokens = bpe_tokenize(text)  # [batch, 256]
    image_tokens = dVAE_encode(image)  # [batch, 32, 32] → [batch, 1024]

    # 拼接
    tokens = concat(text_tokens, image_tokens)  # [batch, 1280]

    # 前向传播
    logits = transformer(tokens)  # [batch, 1280, 8192]

    # 计算损失(只对图像token)
    loss = cross_entropy(logits[:, 256:], image_tokens)

    # 更新
    loss.backward()
    optimizer.step()

4.5 生成过程

4.5.1 自回归采样
输入: "A red apple"

Step 1: 文本编码
    text_tokens = tokenize("A red apple")
    text_embeddings = transformer.encode(text_tokens)

Step 2: 自回归生成图像token
    image_tokens = []
    for i in range(1024):
        # 构造当前序列
        current_seq = concat(text_tokens, image_tokens)

        # 预测下一个token
        logits = transformer(current_seq)
        next_token = sample(logits[:, -1, :])  # 采样
        image_tokens.append(next_token)

Step 3: 解码图像
    image = dVAE.decode(image_tokens)
    image = image.reshape(256, 256, 3)

输出: 256×256 RGB图像
4.5.2 采样策略

Top-k采样:

def top_k_sampling(logits, k=100):
    # 只保留top-k个最高概率的token
    top_k_values, top_k_indices = torch.topk(logits, k)

    # 重新归一化
    probs = F.softmax(top_k_values, dim=-1)

    # 采样
    idx = torch.multinomial(probs, num_samples=1)

    # 返回原始token索引
    return top_k_indices[idx]

Classifier-Free Guidance:

def cfg_sampling(logits_cond, logits_uncond, scale=3.0):
    # 条件生成 - 无条件生成
    logits = logits_uncond + scale * (logits_cond - logits_uncond)
    return logits
4.5.3 CLIP重排序

问题: 自回归生成可能产生多样化的结果

解决方案: 使用CLIP对多个候选图像排序

生成K个候选图像 (K=512)
    ↓
对于每个候选图像:
    计算CLIP相似度: score = CLIP(image, text)
    ↓
按相似度排序
    ↓
返回得分最高的图像

5. DALL-E 2 架构详解

5.1 整体架构

┌─────────────────────────────────────────────────────────────────┐
│                        DALL-E 2 整体架构                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│  文本输入: "A red apple on a table"                              │
│       ↓                                                          │
│  ┌─────────────────────────────────────────┐                    │
│  │      CLIP Text Encoder (冻结)            │                    │
│  │    "A red apple on a table"              │                    │
│  │           ↓                              │                    │
│  │    文本嵌入 y ∈ R^{1024}                 │                    │
│  └─────────────────────────────────────────┘                    │
│       ↓                                                          │
│  ┌─────────────────────────────────────────┐                    │
│  │      Prior Model (扩散先验)              │                    │
│  │    输入: 文本嵌入 y                       │                    │
│  │    输出: 图像嵌入 z ∈ R^{1024}           │                    │
│  └─────────────────────────────────────────┘                    │
│       ↓                                                          │
│  ┌─────────────────────────────────────────┐                    │
│  │      Image Decoder (扩散解码器)          │                    │
│  │    输入: 图像嵌入 z                       │                    │
│  │    输出: 64×64图像                        │                    │
│  └─────────────────────────────────────────┘                    │
│       ↓                                                          │
│  ┌─────────────────────────────────────────┐                    │
│  │      Super-Resolution (超分辨率)         │                    │
│  │    64×64 → 256×256 → 1024×1024          │                    │
│  └─────────────────────────────────────────┘                    │
│       ↓                                                          │
│  输出图像: 1024×1024 RGB                                         │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

5.2 CLIP组件

5.2.1 CLIP文本编码器

架构: Transformer

  • 层数: 12
  • 隐藏维度: 1024
  • 注意力头数: 16
  • 参数量: ~100M

输入输出:

输入: "A red apple on a table"
    ↓
BPE Tokenization: [SOS, 320, 2268, 3316, 525, 2741, EOS]
    ↓
Transformer编码
    ↓
取[EOS]位置的输出作为文本嵌入
    ↓
输出: y ∈ R^{1024}
5.2.2 CLIP图像编码器

架构: Vision Transformer (ViT)

  • 层数: 24
  • 隐藏维度: 1024
  • 注意力头数: 16
  • 图像分辨率: 224×224
  • Patch大小: 14×14
  • 参数量: ~300M

输入输出:

输入: 图像 x ∈ R^{224×224×3}
    ↓
Patch Embedding: 16×16个patch → 256个token
    ↓
加上[CLS] token和位置编码
    ↓
Transformer编码
    ↓
取[CLS]位置的输出作为图像嵌入
    ↓
输出: z ∈ R^{1024}
5.2.3 CLIP训练目标

对比学习:

对于batch中的N个图文对 {(I_i, T_i)}:

图像嵌入: {z_1, z_2, ..., z_N}
文本嵌入: {y_1, y_2, ..., y_N}

相似度矩阵: S_{ij} = z_i · y_j / (||z_i|| · ||y_j||)

对比损失:
L_clip = -1/N * Σ_i [log(exp(S_{ii}/τ) / Σ_j exp(S_{ij}/τ))]
         -1/N * Σ_j [log(exp(S_{jj}/τ) / Σ_i exp(S_{ij}/τ))]

5.3 先验模型 (Prior Model)

5.3.1 作用
文本嵌入 y → [先验模型] → 图像嵌入 z

作用: 将文本嵌入空间映射到CLIP图像嵌入空间
原因: 文本嵌入和图像嵌入虽然在同一空间,但分布不同
5.3.2 扩散先验架构

输入:

  • 文本嵌入 y ∈ R^{1024}
  • 时间步 t ∈ [0, T]
  • 噪声图像嵌入 z_t ∈ R^{1024}

输出:

  • 预测的噪声 ε 或预测的干净嵌入 z_0

网络结构:

输入: [z_t, y, t]
    ↓
拼接: [z_t; y] ∈ R^{2048}
    ↓
时间步编码: t → sin/cos编码 → 线性层 → t_emb ∈ R^{512}
    ↓
Transformer Block × N
    ↓
输出: ε_θ(z_t, y, t) ∈ R^{1024}

Transformer块:

输入: h ∈ R^{1024}
    ↓
┌─────────────────────────────────────────┐
│         Transformer Block               │
│  ┌───────────────────────────────────┐  │
│  │      Self-Attention                │  │
│  │  Q = hW_Q, K = hW_K, V = hW_V    │  │
│  └───────────────────────────────────┘  │
│              ↓ + Residual               │
│           LayerNorm                     │
│              ↓                          │
│  ┌───────────────────────────────────┐  │
│  │      Cross-Attention               │  │
│  │  Q = hW_Q, K = yW_K, V = yW_V    │  │
│  │  (attend to text embedding)        │  │
│  └───────────────────────────────────┘  │
│              ↓ + Residual               │
│           LayerNorm                     │
│              ↓                          │
│  ┌───────────────────────────────────┐  │
│  │         Feed-Forward               │  │
│  │  FFN = Linear → GELU → Linear     │  │
│  └───────────────────────────────────┘  │
│              ↓ + Residual               │
│           LayerNorm                     │
└─────────────────────────────────────────┘
5.3.3 训练目标

扩散过程:

前向过程 (加噪):
z_t = √(ᾱ_t) * z_0 + √(1-ᾱ_t) * ε
其中 ε ~ N(0, I)

反向过程 (去噪):
预测噪声: ε_θ(z_t, y, t) ≈ ε

损失函数:
L_prior = E_{z_0, ε, t} [||ε - ε_θ(z_t, y, t)||²]

5.4 图像解码器 (Image Decoder)

5.4.1 架构概述
输入: 图像嵌入 z ∈ R^{1024}
    ↓
[扩散模型 UNet]
    ↓
输出: 64×64 RGB图像
5.4.2 UNet架构

编码器路径:

输入: x_t ∈ R^{64×64×3}
    ↓
Conv Block → Downsample → 32×32
    ↓
ResBlock × 2 → Attention → Downsample → 16×16
    ↓
ResBlock × 2 → Attention → Downsample → 8×8
    ↓
ResBlock × 2 → Attention → 8×8

解码器路径:

8×8 → ResBlock × 2 → Attention → Upsample → 16×16
    ↓
Concat(skip connection) → ResBlock × 2 → Attention → Upsample → 32×32
    ↓
Concat(skip connection) → ResBlock × 2 → Attention → Upsample → 64×64
    ↓
Concat(skip connection) → ResBlock × 2 → Conv → 64×64
5.4.3 条件注入

图像嵌入注入:

时间步 t → sin/cos编码 → t_emb
图像嵌入 z → 线性层 → z_emb

在每个ResBlock中:
h = h + Linear(t_emb) + Linear(z_emb)

Classifier-Free Guidance:

def cfg_sample(model, z, scale=7.5):
    # 条件预测
    eps_cond = model(x_t, t, z)

    # 无条件预测 (用零向量替代)
    eps_uncond = model(x_t, t, torch.zeros_like(z))

    # 组合
    eps = eps_uncond + scale * (eps_cond - eps_uncond)

    return eps

5.5 超分辨率模块

5.5.1 64×64 → 256×256

架构: 扩散模型 + UNet

输入: 64×64图像 + 256×256低分辨率图像
    ↓
双线性插值: 64×64 → 256×256 (作为条件)
    ↓
UNet去噪 (条件: 上采样的低分辨率图像)
    ↓
输出: 256×256图像
5.5.2 256×256 → 1024×1024

类似架构:

输入: 256×256图像 + 1024×1024低分辨率图像
    ↓
双线性插值: 256×256 → 1024×1024
    ↓
UNet去噪
    ↓
输出: 1024×1024图像

5.6 完整生成流程

文本: "A red apple on a table"
    ↓
Step 1: CLIP文本编码
    text_embedding = clip_text_encoder(text)  # [1024]
    ↓
Step 2: 扩散先验生成图像嵌入
    image_embedding = diffusion_prior.sample(text_embedding)  # [1024]
    ↓
Step 3: 图像解码器生成64×64图像
    image_64 = decoder.sample(image_embedding)  # [64, 64, 3]
    ↓
Step 4: 超分辨率
    image_256 = sr_64_256.sample(image_64)  # [256, 256, 3]
    image_1024 = sr_256_1024.sample(image_256)  # [1024, 1024, 3]
    ↓
输出: 1024×1024 RGB图像

6. DALL-E 3 架构详解

6.1 核心创新

DALL-E 3的主要改进:

  1. 更好的文本理解: 使用ChatGPT改写描述
  2. 改进的训练数据: 更高质量的图文对
  3. 更精细的控制: 支持更复杂的文本描述

6.2 训练数据优化

6.2.1 问题
问题: 现有图文数据集存在噪声
- 标题不准确
- 标题过于简单
- 标题与图像不匹配
6.2.2 解决方案: ChatGPT改写
原始标题: "red apple"
    ↓
[ChatGPT改写]
    ↓
改写标题: "A vibrant red apple sitting on a wooden table,
          with a slight reflection visible on its glossy surface,
          photographed in natural lighting"

改写流程:

1. 收集原始图文对
2. 使用ChatGPT改写每个标题
3. 评估改写质量
4. 使用改写后的标题重新训练
6.2.3 改写提示词
Prompt for ChatGPT:
"I need you to help me rewrite image captions to be more
descriptive and detailed. The rewritten caption should:
1. Describe the main subject clearly
2. Include details about colors, textures, materials
3. Mention the background and environment
4. Specify lighting and mood
5. Be in natural, flowing language

Original caption: {original_caption}
Rewritten caption:"

6.3 改进的扩散模型

6.3.1 架构改进

更大的UNet:

  • 更多的注意力层
  • 更大的隐藏维度
  • 更好的条件注入机制

改进的采样器:

  • DPM-Solver++
  • 更少的采样步数
  • 更高的采样质量
6.3.2 文本条件注入

T5文本编码器:

DALL-E 3使用T5-XXL作为文本编码器
- 参数量: 11B
- 更强的文本理解能力
- 更长的上下文窗口

交叉注意力机制:

图像特征: h ∈ R^{H×W×C}
文本特征: t ∈ R^{L×D}

交叉注意力:
Q = hW_Q
K = tW_K
V = tW_V

Attention = softmax(QK^T/√d)V

6.4 生成控制

6.4.1 精细控制
文本: "A cat sitting on a red mat, looking at a bird outside the window"

DALL-E 3能够:
- 准确定位"猫"的位置
- 正确渲染"红色垫子"
- 表现"看向窗外"的动作
- 保持场景的连贯性
6.4.2 风格控制
风格描述: "in the style of Van Gogh"
    ↓
模型能够:
- 应用特定的艺术风格
- 保持内容的准确性
- 融合风格和内容

7. 数学原理

7.1 变分自编码器 (VAE)

7.1.1 目标

学习数据的潜在表示 z z z,使得:

  1. z z z 能够重建原始数据 x x x
  2. z z z 服从简单的先验分布(通常是高斯分布)
7.1.2 变分推断

后验分布:
p ( z ∣ x ) = p ( x ∣ z ) p ( z ) p ( x ) p(z|x) = \frac{p(x|z)p(z)}{p(x)} p(zx)=p(x)p(xz)p(z)

问题: p ( x ) p(x) p(x) 难以计算(需要积分)

解决方案: 用变分分布 q ϕ ( z ∣ x ) q_\phi(z|x) qϕ(zx) 近似 p ( z ∣ x ) p(z|x) p(zx)

7.1.3 ELBO推导

目标: 最大化对数似然 log ⁡ p ( x ) \log p(x) logp(x)

log ⁡ p ( x ) = log ⁡ ∫ p ( x ∣ z ) p ( z ) d z \log p(x) = \log \int p(x|z)p(z)dz logp(x)=logp(xz)p(z)dz

= log ⁡ ∫ p ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) q ϕ ( z ∣ x ) d z = \log \int \frac{p(x|z)p(z)}{q_\phi(z|x)} q_\phi(z|x)dz =logqϕ(zx)p(xz)p(z)qϕ(zx)dz

≥ ∫ q ϕ ( z ∣ x ) log ⁡ p ( x ∣ z ) p ( z ) q ϕ ( z ∣ x ) d z \geq \int q_\phi(z|x) \log \frac{p(x|z)p(z)}{q_\phi(z|x)}dz qϕ(zx)logqϕ(zx)p(xz)p(z)dz

= E q ϕ ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] − D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) = \mathbb{E}_{q_\phi(z|x)}[\log p(x|z)] - D_{KL}(q_\phi(z|x) || p(z)) =Eqϕ(zx)[logp(xz)]DKL(qϕ(zx)∣∣p(z))

ELBO (Evidence Lower Bound):
L E L B O = E q ϕ ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] − D K L ( q ϕ ( z ∣ x ) ∣ ∣ p ( z ) ) \mathcal{L}_{ELBO} = \mathbb{E}_{q_\phi(z|x)}[\log p(x|z)] - D_{KL}(q_\phi(z|x) || p(z)) LELBO=Eqϕ(zx)[logp(xz)]DKL(qϕ(zx)∣∣p(z))

  • 第一项: 重建损失
  • 第二项: KL散度(正则化)
7.1.4 dVAE的特殊性

离散化:

标准VAE: z ~ N(μ, σ²) (连续)
dVAE: z ∈ {0, 1, ..., K-1} (离散)

离散化方法:
1. Gumbel-Softmax: 可微分的近似
2. Straight-Through Estimator: 直通估计

Gumbel-Softmax:
y i = exp ⁡ ( ( log ⁡ π i + g i ) / τ ) ∑ j exp ⁡ ( ( log ⁡ π j + g j ) / τ ) y_i = \frac{\exp((\log \pi_i + g_i)/\tau)}{\sum_j \exp((\log \pi_j + g_j)/\tau)} yi=jexp((logπj+gj)/τ)exp((logπi+gi)/τ)

其中 g i ∼ Gumbel ( 0 , 1 ) g_i \sim \text{Gumbel}(0, 1) giGumbel(0,1) τ \tau τ 是温度参数。

7.2 扩散模型

7.2.1 前向过程

定义: 逐步向数据添加高斯噪声

q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t I) q(xtxt1)=N(xt;1βt xt1,βtI)

累积形式:
q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t) I) q(xtx0)=N(xt;αˉt x0,(1αˉt)I)

其中 α t = 1 − β t \alpha_t = 1 - \beta_t αt=1βt α ˉ t = ∏ s = 1 t α s \bar{\alpha}_t = \prod_{s=1}^t \alpha_s αˉt=s=1tαs

重参数化技巧:
x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1-\bar{\alpha}_t} \epsilon xt=αˉt x0+1αˉt ϵ

其中 ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵN(0,I)

7.2.2 反向过程

目标: 学习去噪分布 p θ ( x t − 1 ∣ x t ) p_\theta(x_{t-1} | x_t) pθ(xt1xt)

p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , σ t 2 I ) p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \sigma_t^2 I) pθ(xt1xt)=N(xt1;μθ(xt,t),σt2I)

均值参数化:
μ θ ( x t , t ) = 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) μθ(xt,t)=αt 1(xt1αˉt βtϵθ(xt,t))

7.2.3 训练目标

简化损失:
L s i m p l e = E x 0 , ϵ , t [ ∣ ∣ ϵ − ϵ θ ( x t , t ) ∣ ∣ 2 ] \mathcal{L}_{simple} = \mathbb{E}_{x_0, \epsilon, t} [||\epsilon - \epsilon_\theta(x_t, t)||^2] Lsimple=Ex0,ϵ,t[∣∣ϵϵθ(xt,t)2]

推导:
从ELBO出发,经过简化得到:
L = E t [ β t 2 2 σ t 2 α t ( 1 − α ˉ t ) ∣ ∣ ϵ − ϵ θ ( x t , t ) ∣ ∣ 2 ] \mathcal{L} = \mathbb{E}_t \left[ \frac{\beta_t^2}{2\sigma_t^2 \alpha_t (1-\bar{\alpha}_t)} ||\epsilon - \epsilon_\theta(x_t, t)||^2 \right] L=Et[2σt2αt(1αˉt)βt2∣∣ϵϵθ(xt,t)2]

简化后去掉系数,得到 L s i m p l e \mathcal{L}_{simple} Lsimple

7.2.4 采样过程

DDPM采样:

for t = T, T-1, ..., 1:
    z ~ N(0, I) if t > 1 else z = 0
    x_{t-1} = μ_θ(x_t, t) + σ_t * z

DDIM采样 (确定性):

x_{t-1} = √(ᾱ_{t-1}) * predicted_x0 + √(1-ᾱ_{t-1}) * predicted_noise

7.3 CLIP对比学习

7.3.1 对比损失

InfoNCE损失:
L C L I P = − 1 N ∑ i = 1 N [ log ⁡ exp ⁡ ( s i i / τ ) ∑ j = 1 N exp ⁡ ( s i j / τ ) + log ⁡ exp ⁡ ( s i i / τ ) ∑ j = 1 N exp ⁡ ( s j i / τ ) ] \mathcal{L}_{CLIP} = -\frac{1}{N} \sum_{i=1}^N \left[ \log \frac{\exp(s_{ii}/\tau)}{\sum_{j=1}^N \exp(s_{ij}/\tau)} + \log \frac{\exp(s_{ii}/\tau)}{\sum_{j=1}^N \exp(s_{ji}/\tau)} \right] LCLIP=N1i=1N[logj=1Nexp(sij/τ)exp(sii/τ)+logj=1Nexp(sji/τ)exp(sii/τ)]

其中:

  • s i j = z i T y j ∣ ∣ z i ∣ ∣ ⋅ ∣ ∣ y j ∣ ∣ s_{ij} = \frac{z_i^T y_j}{||z_i|| \cdot ||y_j||} sij=∣∣zi∣∣∣∣yj∣∣ziTyj 是余弦相似度
  • τ \tau τ 是温度参数
  • z i z_i zi: 图像嵌入
  • y j y_j yj: 文本嵌入
7.3.2 温度参数

作用: 控制分布的锐度

p i j = exp ⁡ ( s i j / τ ) ∑ k exp ⁡ ( s i k / τ ) p_{ij} = \frac{\exp(s_{ij}/\tau)}{\sum_k \exp(s_{ik}/\tau)} pij=kexp(sik/τ)exp(sij/τ)

  • τ → 0 \tau \to 0 τ0: 分布趋向one-hot(硬分配)
  • τ → ∞ \tau \to \infty τ: 分布趋向均匀

CLIP中的温度:

  • 初始化为 τ = 0.07 \tau = 0.07 τ=0.07
  • 作为可学习参数训练

7.4 Classifier-Free Guidance

7.4.1 背景

问题: 标准条件生成可能不够"遵循"条件

解决方案: 同时训练条件和无条件模型

7.4.2 数学公式

训练:

以概率 p_uncond 将条件置空
训练目标: ε_θ(x_t, t, c) 其中c可以是真实条件或空条件

采样:
ϵ ^ θ ( x t , t , c ) = ϵ θ ( x t , t , ∅ ) + w ⋅ ( ϵ θ ( x t , t , c ) − ϵ θ ( x t , t , ∅ ) ) \hat{\epsilon}_\theta(x_t, t, c) = \epsilon_\theta(x_t, t, \emptyset) + w \cdot (\epsilon_\theta(x_t, t, c) - \epsilon_\theta(x_t, t, \emptyset)) ϵ^θ(xt,t,c)=ϵθ(xt,t,)+w(ϵθ(xt,t,c)ϵθ(xt,t,))

其中:

  • w w w 是引导强度(通常7-12)
  • ∅ \emptyset 是空条件
  • c c c 是真实条件
7.4.3 直觉
无条件生成: 倾向于生成多样但可能不相关的图像
条件生成: 生成符合条件的图像

Classifier-Free Guidance: 放大条件的影响
w=1: 标准条件生成
w>1: 更强的条件遵循
w<1: 更弱的条件遵循

7.5 交叉注意力机制

7.5.1 定义

自注意力 (Self-Attention):
Q = X W Q , K = X W K , V = X W V Q = XW_Q, K = XW_K, V = XW_V Q=XWQ,K=XWK,V=XWV

交叉注意力 (Cross-Attention):
Q = X W Q , K = C W K , V = C W V Q = XW_Q, K = CW_K, V = CW_V Q=XWQ,K=CWK,V=CWV

其中 X X X 是图像特征, C C C 是文本条件

7.5.2 作用
图像特征 X: "我需要知道如何生成这个像素"
文本条件 C: "这是条件信息"
交叉注意力: 图像特征"查询"文本条件,获取相关信息

8. 关键技术深度解析

8.1 Tokenization

8.1.1 BPE (Byte Pair Encoding)

DALL-E 1的文本Tokenization:

输入: "A red apple"
    ↓
字符级: ['A', ' ', 'r', 'e', 'd', ' ', 'a', 'p', 'p', 'l', 'e']
    ↓
BPE合并:
1. 最频繁的pair: 'p','p' → 'pp'
2. 下一个频繁pair: 'ap','ple' → 'apple'
3. ...
    ↓
输出: ['A', ' ', 'red', ' ', 'apple']
    ↓
映射到词表: [47, 234, 1234, 234, 5678]
8.1.2 图像Tokenization

dVAE Tokenization:

图像 (256×256×3)
    ↓
编码器
    ↓
特征图 (32×32×512)
    ↓
与码本比较 (8192个向量)
    ↓
每个位置选择最近的码本索引
    ↓
输出: 32×32 = 1024个token ∈ {0,1,...,8191}

8.2 位置编码

8.2.1 绝对位置编码

正弦位置编码:
P E ( p o s , 2 i ) = sin ⁡ ( p o s / 10000 2 i / d ) PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d}) PE(pos,2i)=sin(pos/100002i/d)
P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s / 10000 2 i / d ) PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d}) PE(pos,2i+1)=cos(pos/100002i/d)

DALL-E中的应用:

文本token: 0-255的位置编码
图像token: 0-1023的位置编码

注意: 文本和图像使用不同的位置编码
8.2.2 2D位置编码

图像的位置编码:

对于32×32的图像token:
- 行编码: 0-31
- 列编码: 0-31

位置 = 行编码 + 列编码
或
位置 = [行编码; 列编码] 然后投影

8.3 注意力机制优化

8.3.1 稀疏注意力

DALL-E 1的稀疏注意力:

问题: 全注意力复杂度 O(n²) 太高
解决: 使用稀疏注意力模式

注意力模式:
1. 文本token: 全注意力(互相可见)
2. 图像token: 因果注意力(只能看到之前的)
3. 跨模态: 图像可以看到所有文本
8.3.2 Flash Attention

优化注意力计算:

传统注意力:
1. 计算Q, K, V
2. 计算注意力矩阵 A = QK^T
3. 应用softmax
4. 计算输出 O = AV

问题: A需要 O(n²) 内存

Flash Attention:
1. 分块计算
2. 在线softmax
3. 避免存储完整A
4. 内存复杂度 O(n)

8.4 采样策略

8.4.1 Top-k采样
def top_k_sampling(logits, k=50):
    # 只保留top-k
    values, indices = torch.topk(logits, k)

    # 重新归一化
    probs = F.softmax(values, dim=-1)

    # 采样
    idx = torch.multinomial(probs, 1)

    return indices[idx]
8.4.2 Top-p (Nucleus)采样
def top_p_sampling(logits, p=0.9):
    # 排序
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)

    # 累积概率
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # 找到截断点
    mask = cumulative_probs - F.softmax(sorted_logits, dim=-1) >= p

    # 设置概率为0
    sorted_logits[mask] = float('-inf')

    # 重新采样
    probs = F.softmax(sorted_logits, dim=-1)
    idx = torch.multinomial(probs, 1)

    return sorted_indices[idx]
8.4.3 温度采样
def temperature_sampling(logits, temperature=1.0):
    # 调整温度
    logits = logits / temperature

    # 采样
    probs = F.softmax(logits, dim=-1)
    idx = torch.multinomial(probs, 1)

    return idx

8.5 扩散采样器

8.5.1 DDPM采样器
def ddpm_sample(model, x_T, T, betas):
    x = x_T
    for t in reversed(range(T)):
        # 预测噪声
        eps_pred = model(x, t)

        # 计算均值
        alpha_t = 1 - betas[t]
        alpha_bar_t = torch.prod(1 - betas[:t+1])

        mean = (1/np.sqrt(alpha_t)) * (
            x - (betas[t]/np.sqrt(1 - alpha_bar_t)) * eps_pred
        )

        # 添加噪声
        if t > 0:
            noise = torch.randn_like(x)
            x = mean + np.sqrt(betas[t]) * noise
        else:
            x = mean

    return x
8.5.2 DDIM采样器
def ddim_sample(model, x_T, T, steps, eta=0.0):
    # 选择时间步子集
    times = torch.linspace(0, T-1, steps).long()

    x = x_T
    for i in reversed(range(len(times))):
        t = times[i]
        t_prev = times[i-1] if i > 0 else torch.tensor(0)

        # 预测噪声
        eps_pred = model(x, t)

        # 计算x_0
        alpha_bar_t = torch.prod(1 - betas[:t+1])
        x_0_pred = (x - np.sqrt(1 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_bar_t)

        # 计算x_{t-1}
        alpha_bar_t_prev = torch.prod(1 - betas[:t_prev+1]) if t_prev > 0 else 1.0

        sigma_t = eta * np.sqrt((1 - alpha_bar_t_prev) / (1 - alpha_bar_t) *
                   (1 - alpha_bar_t / alpha_bar_t_prev))

        dir_xt = np.sqrt(1 - alpha_bar_t_prev - sigma_t**2) * eps_pred

        x = np.sqrt(alpha_bar_t_prev) * x_0_pred + dir_xt

        if i > 0:
            x = x + sigma_t * torch.randn_like(x)

    return x

9. 训练方法

9.1 DALL-E 1训练

9.1.1 数据准备

数据集: 约6亿图文对

数据来源:

  • 互联网爬取
  • 公开数据集
  • 人工标注

数据过滤:

1. 删除低质量图像
2. 过滤不相关图文对
3. 删除重复数据
4. 过滤有害内容
9.1.2 两阶段训练

第一阶段: dVAE训练

训练数据: 6亿图像
训练时间: ~2周
硬件: 256 × V100 GPUs
批大小: 2048
学习率: 1e-4
优化器: Adam (β1=0.9, β2=0.999)

第二阶段: Transformer训练

训练数据: 2.5亿图文对(过滤后)
训练时间: ~2周
硬件: 1024 × A100 GPUs
批大小: 1024
学习率: 1e-4 → 1e-5 (cosine decay)
优化器: Adam (β1=0.9, β2=0.95)
9.1.3 训练技巧

梯度累积:

# 由于批大小较大,使用梯度累积
accumulation_steps = 8
effective_batch_size = batch_size * accumulation_steps

for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

混合精度训练:

# 使用FP16加速训练
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    output = model(input)
    loss = criterion(output, target)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

9.2 DALL-E 2训练

9.2.1 CLIP预训练
数据集: 4亿图文对 (WebImageText)
训练时间: ~2周
硬件: 256 × V100 GPUs
批大小: 32768
学习率: 1e-4 → 0 (cosine decay)
优化器: AdamW (weight_decay=0.2)
9.2.2 先验模型训练
训练数据: 使用CLIP编码的图文对
训练时间: ~3天
硬件: 64 × A100 GPUs
扩散步数: 1000
学习率: 1e-4
9.2.3 解码器训练
训练数据: 2.5亿图文对
训练时间: ~1周
硬件: 256 × A100 GPUs
扩散步数: 1000
图像分辨率: 64×64
9.2.4 超分辨率训练
阶段1: 64×64 → 256×256
    训练时间: ~3天
    数据: 高分辨率图像

阶段2: 256×256 → 1024×1024
    训练时间: ~3天
    数据: 超高分辨率图像

9.3 DALL-E 3训练

9.3.1 数据增强
1. 使用ChatGPT改写标题
2. 评估改写质量
3. 过滤低质量改写
4. 使用改写后的数据重新训练
9.3.2 改进的训练策略
1. 渐进式训练: 从低分辨率开始
2. 多尺度训练: 同时训练多个分辨率
3. 数据增强: 随机裁剪、翻转等
4. 正则化: Dropout、权重衰减

10. 应用场景

10.1 创意设计

10.1.1 概念设计
应用场景: 产品概念设计
输入: "一款未来感十足的智能手表,全息投影界面"
输出: 多个设计概念图

优势:
- 快速生成多个方案
- 探索设计空间
- 激发创意灵感
10.1.2 插画创作
应用场景: 儿童图书插画
输入: "一只小兔子在森林里迷路了,周围是高大的松树"
输出: 风格一致的插画序列

优势:
- 保持风格一致性
- 快速迭代
- 降低创作成本

10.2 教育领域

10.2.1 教学素材
应用场景: 历史教学
输入: "古罗马角斗士在竞技场战斗的场景"
输出: 逼真的历史场景图

优势:
- 可视化抽象概念
- 增强学习体验
- 节省素材制作成本
10.2.2 科学可视化
应用场景: 生物教学
输入: "人体细胞内部结构,线粒体、细胞核清晰可见"
输出: 科学准确的示意图

优势:
- 准确展示微观结构
- 支持自定义视角
- 便于理解复杂概念

10.3 商业应用

10.3.1 广告设计
应用场景: 电商产品图
输入: "一瓶护肤品放在大理石台面上,周围有鲜花和蜡烛"
输出: 高质量产品展示图

优势:
- 快速生成产品图
- 降低拍摄成本
- 支持A/B测试
10.3.2 社交媒体内容
应用场景: 社交媒体运营
输入: "周末早午餐,阳光透过窗户洒在餐桌上"
输出: Instagram风格的美食照片

优势:
- 快速生成内容
- 保持视觉风格
- 提高更新频率

10.4 游戏开发

10.4.1 概念艺术
应用场景: 游戏美术设计
输入: "一个魔法森林,发光的蘑菇和漂浮的精灵"
输出: 游戏场景概念图

优势:
- 加速前期设计
- 探索多种风格
- 降低美术成本
10.4.2 纹理生成
应用场景: 游戏纹理
输入: "古老的石砖墙壁,有苔藓和裂缝"
输出: 可平铺的纹理贴图

优势:
- 快速生成纹理
- 无限变体
- 可定制化

10.5 科研领域

10.5.1 数据增强
应用场景: 计算机视觉研究
需求: 扩充训练数据集
方法: 使用DALL-E生成带标注的图像

优势:
- 解决数据稀缺问题
- 控制数据分布
- 降低标注成本
10.5.2 可控生成研究
应用场景: 生成模型研究
需求: 研究文本-图像对齐
方法: 分析DALL-E的生成结果

研究方向:
- 属性绑定
- 空间关系
- 计数能力
- 组合泛化

11. 代码实现

11.1 简化的dVAE实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, channels)
        self.norm2 = nn.GroupNorm(32, channels)

    def forward(self, x):
        residual = x
        x = F.silu(self.norm1(self.conv1(x)))
        x = self.norm2(self.conv2(x))
        return x + residual

class dVAEEncoder(nn.Module):
    def __init__(self, in_channels=3, hidden_channels=256,
                 num_embeddings=8192, num_layers=2):
        super().__init__()

        # 下采样层
        self.conv1 = nn.Conv2d(in_channels, 64, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, hidden_channels, 3, stride=2, padding=1)

        # 残差块
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(hidden_channels) for _ in range(num_layers)]
        )

        # 输出层
        self.output = nn.Conv2d(hidden_channels, num_embeddings, 1)

    def forward(self, x):
        # x: [batch, 3, 256, 256]
        x = F.silu(self.conv1(x))  # [batch, 64, 128, 128]
        x = F.silu(self.conv2(x))  # [batch, 128, 64, 64]
        x = F.silu(self.conv3(x))  # [batch, 256, 32, 32]
        x = self.res_blocks(x)
        x = self.output(x)  # [batch, 8192, 32, 32]
        return x

class dVAEDecoder(nn.Module):
    def __init__(self, num_embeddings=8192, hidden_channels=256,
                 out_channels=3, num_layers=2):
        super().__init__()

        # 输入层
        self.input = nn.Conv2d(num_embeddings, hidden_channels, 1)

        # 残差块
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(hidden_channels) for _ in range(num_layers)]
        )

        # 上采样层
        self.conv1 = nn.ConvTranspose2d(hidden_channels, 128, 4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.conv3 = nn.ConvTranspose2d(64, out_channels, 4, stride=2, padding=1)

    def forward(self, z):
        # z: [batch, 32, 32] (离散token)
        # 转换为one-hot
        z_one_hot = F.one_hot(z, num_classes=self.codebook_size)
        z_one_hot = z_one_hot.permute(0, 3, 1, 2).float()

        x = self.input(z_one_hot)  # [batch, 256, 32, 32]
        x = self.res_blocks(x)
        x = F.silu(self.conv1(x))  # [batch, 128, 64, 64]
        x = F.silu(self.conv2(x))  # [batch, 64, 128, 128]
        x = torch.sigmoid(self.conv3(x))  # [batch, 3, 256, 256]
        return x

class SimpledVAE(nn.Module):
    def __init__(self, num_embeddings=8192, embedding_dim=512):
        super().__init__()
        self.encoder = dVAEEncoder(num_embeddings=num_embeddings)
        self.decoder = dVAEDecoder(num_embeddings=num_embeddings)
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)

    def quantize(self, logits):
        # logits: [batch, num_embeddings, 32, 32]
        # 找到最近的码本向量
        indices = torch.argmax(logits, dim=1)  # [batch, 32, 32]
        return indices

    def forward(self, x):
        # 编码
        logits = self.encoder(x)  # [batch, 8192, 32, 32]

        # 量化
        z = self.quantize(logits)  # [batch, 32, 32]

        # 解码
        x_recon = self.decoder(z)  # [batch, 3, 256, 256]

        return x_recon, logits, z

11.2 简化的Transformer实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # 计算Q, K, V
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 注意力计算
        attn = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, V)

        # 合并多头
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        out = self.out_proj(out)

        return out

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        # 自注意力
        attn_out = self.attention(x, mask)
        x = self.norm1(x + attn_out)

        # FFN
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)

        return x

class DALLE1Transformer(nn.Module):
    def __init__(self, vocab_size=8192, text_len=256, image_len=1024,
                 d_model=1024, num_heads=16, num_layers=24, d_ff=4096):
        super().__init__()

        self.text_len = text_len
        self.image_len = image_len
        self.total_len = text_len + image_len

        # 嵌入层
        self.text_embedding = nn.Embedding(vocab_size, d_model)
        self.image_embedding = nn.Embedding(vocab_size, d_model)
        self.position_embedding = nn.Embedding(self.total_len, d_model)
        self.type_embedding = nn.Embedding(2, d_model)  # 0=text, 1=image

        # Transformer层
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

        # 输出层
        self.output = nn.Linear(d_model, vocab_size)

        # 创建因果掩码
        self.register_buffer('mask', self._create_mask())

    def _create_mask(self):
        mask = torch.zeros(self.total_len, self.total_len)

        # 文本token可以看到所有文本token
        mask[:self.text_len, :self.text_len] = 1

        # 图像token可以看到所有文本token
        mask[self.text_len:, :self.text_len] = 1

        # 图像token可以看到之前的图像token
        for i in range(self.image_len):
            mask[self.text_len + i, self.text_len:self.text_len + i + 1] = 1

        return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, total_len, total_len]

    def forward(self, text_tokens, image_tokens=None):
        batch_size = text_tokens.shape[0]

        # 嵌入
        text_emb = self.text_embedding(text_tokens)
        if image_tokens is not None:
            image_emb = self.image_embedding(image_tokens)
        else:
            # 生成时,初始化为零
            image_emb = torch.zeros(batch_size, self.image_len,
                                   text_emb.shape[-1], device=text_tokens.device)

        # 拼接
        x = torch.cat([text_emb, image_emb], dim=1)

        # 添加位置编码
        positions = torch.arange(self.total_len, device=x.device).unsqueeze(0)
        x = x + self.position_embedding(positions)

        # 添加类型编码
        type_ids = torch.zeros(self.total_len, dtype=torch.long, device=x.device)
        type_ids[self.text_len:] = 1
        x = x + self.type_embedding(type_ids)

        # Transformer层
        for layer in self.layers:
            x = layer(x, self.mask)

        # 输出(只取图像部分)
        image_logits = self.output(x[:, self.text_len:])

        return image_logits

    @torch.no_grad()
    def generate(self, text_tokens, temperature=1.0, top_k=100):
        batch_size = text_tokens.shape[0]
        device = text_tokens.device

        # 初始化图像token
        image_tokens = torch.zeros(batch_size, self.image_len, dtype=torch.long, device=device)

        for i in range(self.image_len):
            # 前向传播
            logits = self.forward(text_tokens, image_tokens)

            # 获取下一个token的logits
            next_token_logits = logits[:, i, :] / temperature

            # Top-k采样
            if top_k > 0:
                values, _ = torch.topk(next_token_logits, top_k)
                next_token_logits[next_token_logits < values[:, -1:]] = float('-inf')

            # 采样
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # 更新图像token
            image_tokens[:, i] = next_token.squeeze(-1)

        return image_tokens

11.3 简化的扩散模型实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.norm1 = nn.GroupNorm(32, in_channels)
        self.norm2 = nn.GroupNorm(32, out_channels)

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, time_emb):
        h = F.silu(self.norm1(x))
        h = self.conv1(h)

        # 添加时间嵌入
        time_emb = F.silu(self.time_mlp(time_emb))
        h = h + time_emb[:, :, None, None]

        h = F.silu(self.norm2(h))
        h = self.conv2(h)

        return h + self.shortcut(x)

class SimpleUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, time_emb_dim=256):
        super().__init__()

        # 时间嵌入
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU()
        )

        # 编码器
        self.enc1 = ResNetBlock(in_channels, 64, time_emb_dim)
        self.enc2 = ResNetBlock(64, 128, time_emb_dim)
        self.enc3 = ResNetBlock(128, 256, time_emb_dim)

        # 下采样
        self.downsample = nn.MaxPool2d(2)

        # 中间层
        self.mid = ResNetBlock(256, 256, time_emb_dim)

        # 解码器
        self.dec3 = ResNetBlock(256 + 256, 128, time_emb_dim)
        self.dec2 = ResNetBlock(128 + 128, 64, time_emb_dim)
        self.dec1 = ResNetBlock(64 + 64, 64, time_emb_dim)

        # 上采样
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

        # 输出层
        self.out = nn.Conv2d(64, out_channels, 1)

    def forward(self, x, time):
        # 时间嵌入
        time_emb = self.time_mlp(time)

        # 编码器
        e1 = self.enc1(x, time_emb)
        e2 = self.enc2(self.downsample(e1), time_emb)
        e3 = self.enc3(self.downsample(e2), time_emb)

        # 中间层
        mid = self.mid(self.downsample(e3), time_emb)

        # 解码器
        d3 = self.dec3(torch.cat([self.upsample(mid), e3], dim=1), time_emb)
        d2 = self.dec2(torch.cat([self.upsample(d3), e2], dim=1), time_emb)
        d1 = self.dec1(torch.cat([self.upsample(d2), e1], dim=1), time_emb)

        # 输出
        out = self.out(d1)

        return out

class SimpleDiffusion:
    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.num_timesteps = num_timesteps

        # 定义beta schedule
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

    def add_noise(self, x_0, t, noise=None):
        """前向过程:添加噪声"""
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alpha = self.sqrt_alphas_cumprod[t][:, None, None, None]
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t][:, None, None, None]

        return sqrt_alpha * x_0 + sqrt_one_minus_alpha * noise, noise

    def sample(self, model, shape, device):
        """反向过程:采样"""
        x = torch.randn(shape, device=device)

        for t in reversed(range(self.num_timesteps)):
            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)

            # 预测噪声
            predicted_noise = model(x, t_batch)

            # 计算去噪后的x
            alpha = self.alphas[t]
            alpha_cumprod = self.alphas_cumprod[t]
            beta = self.betas[t]

            if t > 0:
                noise = torch.randn_like(x)
            else:
                noise = torch.zeros_like(x)

            x = (1 / torch.sqrt(alpha)) * (
                x - ((1 - alpha) / torch.sqrt(1 - alpha_cumprod)) * predicted_noise
            ) + torch.sqrt(beta) * noise

        return x

11.4 使用HuggingFace Diffusers

from diffusers import StableDiffusionPipeline
import torch

# 加载预训练模型(类似DALL-E的开源实现)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# 生成图像
prompt = "A red apple on a wooden table, photorealistic"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

# 保存图像
image.save("generated_apple.png")

12. 对比分析

12.1 DALL-E系列对比

特性 DALL-E 1 DALL-E 2 DALL-E 3
发布时间 2021.1 2022.4 2023.9
核心架构 dVAE + Transformer CLIP + Diffusion 改进Diffusion + ChatGPT
生成方式 自回归 扩散 扩散
分辨率 256×256 1024×1024 1024×1024
参数量 12B ~5B 未公开
文本理解 BPE + Transformer CLIP T5 + ChatGPT改写
生成质量 良好 优秀 卓越
文本遵循 中等 良好 优秀
多样性 中等 中等
训练数据 6亿图文对 4亿图文对 改进的图文对

12.2 与其他模型对比

模型 公司 架构 分辨率 特点
DALL-E 2 OpenAI CLIP + Diffusion 1024×1024 照片级真实感
Stable Diffusion Stability AI Latent Diffusion 512×512+ 开源、高效
Midjourney Midjourney 未公开 1024×1024+ 艺术风格强
Imagen Google T5 + Diffusion 1024×1024 文本理解强
Parti Google 自回归 256×256 高保真度
CogView 清华 自回归 256×256 中文支持

12.3 技术路线对比

自回归 vs 扩散
方面 自回归 (DALL-E 1) 扩散 (DALL-E 2/3)
生成质量 中等
生成速度 慢(逐token) 快(并行去噪)
训练稳定性 中等
可控性 中等
多样性 中等
离散 vs 连续
方面 离散表示 (dVAE) 连续表示 (Latent)
信息保留 有损 更完整
训练难度 中等 较高
与语言模型兼容 需要适配
生成质量 中等

13. 参考资料

核心论文

  1. DALL-E 1:

    • Ramesh et al. (2021). “Zero-Shot Text-to-Image Generation”
    • 链接: https://arxiv.org/abs/2102.12092
  2. DALL-E 2:

    • Ramesh et al. (2022). “Hierarchical Text-Conditional Image Generation with CLIP Latents”
    • 链接: https://arxiv.org/abs/2204.06125
  3. DALL-E 3:

    • Betker et al. (2023). “Improving Image Generation with Better Captions”
    • 链接: https://cdn.openai.com/papers/dall-e-3.pdf

相关技术论文

  1. CLIP:

    • Radford et al. (2021). “Learning Transferable Visual Models From Natural Language Supervision”
    • 链接: https://arxiv.org/abs/2103.00020
  2. Diffusion Models:

    • Ho et al. (2020). “Denoising Diffusion Probabilistic Models”
    • 链接: https://arxiv.org/abs/2006.11239
  3. DDIM:

    • Song et al. (2020). “Denoising Diffusion Implicit Models”
    • 链接: https://arxiv.org/abs/2010.02502
  4. Classifier-Free Guidance:

    • Ho & Salimans (2022). “Classifier-Free Diffusion Guidance”
    • 链接: https://arxiv.org/abs/2207.12598

开源实现

  1. DALL-E Flow:

    • https://github.com/jina-ai/dalle-flow
  2. Stable Diffusion:

    • https://github.com/CompVis/stable-diffusion
  3. Diffusers:

    • https://github.com/huggingface/diffusers

推荐阅读

  1. Lilian Weng’s Blog: “What are Diffusion Models?”
  2. Jay Alammar’s Blog: “The Illustrated Stable Diffusion”
  3. OpenAI Blog: “DALL-E 2 Preview”, “DALL-E 3 System Card”

总结

DALL-E系列代表了文本到图像生成技术的重大突破:

核心创新

  1. DALL-E 1: 首次将图像表示为离散token,实现大规模文本到图像生成
  2. DALL-E 2: 引入CLIP和扩散模型,实现照片级真实感
  3. DALL-E 3: 通过ChatGPT改写标题,大幅提升文本理解能力

技术演进

自回归 (DALL-E 1) → 扩散 (DALL-E 2) → 改进扩散 (DALL-E 3)

离散token → 连续潜空间 → 改进的潜空间

简单文本理解 → CLIP理解 → ChatGPT增强理解

未来方向

  1. 更高分辨率: 4K、8K图像生成
  2. 更长文本: 支持更复杂的描述
  3. 更强控制: 精确的属性控制
  4. 视频生成: 从图像到视频
  5. 3D生成: 从2D到3D
  6. 多模态理解: 更深入的图文理解
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐