【扩散模型系列·第六篇】条件生成:文本引导、交叉注意力与 Classifier-Free Guidance 全解析

作者:技术博主 | 更新时间:2026-05-20 | 阅读时长:约 23 分钟
系列:扩散模型从零到实战(共 8 篇)
环境:Python 3.12 + PyTorch 2.x + transformers + diffusers
标签扩散模型 条件生成 CLIP 交叉注意力 CFG Classifier-Free Guidance 文生图 文本编码


在这里插入图片描述

🔥 本篇目标:前五篇的扩散模型是"无条件"的——从纯噪声生成图像,但无法控制生成什么。本篇解决这个问题:如何让扩散模型根据文字描述生成图像? 完整链路包含三个关键:用 CLIP/T5 把文字编码成向量、用交叉注意力把文字注入 U-Net 每一层、以及用 Classifier-Free Guidance(CFG)大幅提升文字对齐度。理解了 CFG,你就明白 Stable Diffusion 里 guidance_scale=7.5 这个数字的含义。


系列进度

篇次 主题 状态
第一篇 扩散模型是什么 ✅ 已发布
第二篇 数学基础:前向过程 ✅ 已发布
第三篇 反向过程:学习去噪 ✅ 已发布
第四篇 U-Net 架构 ✅ 已发布
第五篇 DDIM:加速采样 ✅ 已发布
第六篇(本篇) 条件生成:文本引导与 CFG
第七篇 Stable Diffusion:潜在扩散模型 即将发布
第八篇 实战与前沿:LoRA、DreamBooth、FLUX 即将发布

目录


一、条件生成的三种范式

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# 让扩散模型"听话"的三种方式:
#
# ① 条件拼接(Concatenation)
#    把条件(如类别标签的 one-hot 编码)直接拼接到输入
#    简单,但对复杂条件(如自然语言)效果有限
#
# ② 自适应归一化(Adaptive Normalization,AdaGN)
#    条件向量转化为 scale 和 shift,对特征图做仿射变换
#    扩散模型中时间步嵌入就是用这种方式注入的
#    适合:紧凑的条件信号(类别 ID、时间步)
#
# ③ 交叉注意力(Cross-Attention)← 文本条件的标准方式
#    U-Net 的每个 ResBlock 后面加 Cross-Attention 层
#    Query 来自图像特征,Key/Value 来自文本特征
#    使图像的每个位置都能"查询"文字的每个词
#    适合:序列条件(文本、音频、其他图像)

condition_methods = {
    "拼接(Concat)": {
        "条件类型": "紧凑向量",
        "实现难度": "简单",
        "灵活性": "低",
        "典型应用": "类别标签、MNIST digit",
    },
    "AdaGN(自适应归一化)": {
        "条件类型": "紧凑向量",
        "实现难度": "中等",
        "灵活性": "中",
        "典型应用": "时间步嵌入(必用)、类别嵌入",
    },
    "交叉注意力(Cross-Attention)": {
        "条件类型": "序列",
        "实现难度": "较高",
        "灵活性": "高",
        "典型应用": "文本→图像(DALL-E 2、SD)",
    },
    "混合(CFG 标配)": {
        "条件类型": "序列 + 向量",
        "实现难度": "高",
        "灵活性": "最高",
        "典型应用": "Stable Diffusion:文本+时间步",
    },
}

print("条件注入方式对比:")
print()
for method, info in condition_methods.items():
    print(f"  [{method}]")
    for k, v in info.items():
        print(f"    {k:10s}: {v}")
    print()

二、文本编码器:CLIP vs T5

2.1 CLIP 文本编码器

# CLIP(Contrastive Language-Image Pre-training,OpenAI 2021)
# 训练方式:对比学习,让文本和图像的嵌入在语义上对齐
# 输入:文字
# 输出:[batch, 77, 768](最多 77 个词的特征序列)
#
# CLIP 为什么适合文生图?
# ① 文本特征与图像特征在同一语义空间
# ② 77 个词的序列输出,为交叉注意力提供 Key 和 Value
# ③ 已在海量图文对上训练,泛化能力强

# 用 transformers 加载 CLIP 文本编码器
def encode_text_with_clip(texts: list, device: str = "cpu"):
    """
    使用 CLIP 编码文字
    texts: 文字列表
    返回: (batch, 77, 768) 的特征张量
    """
    from transformers import CLIPTokenizer, CLIPTextModel

    # 加载 CLIP tokenizer 和 text encoder
    # 实际使用时从 HuggingFace 下载,这里展示接口
    model_id = "openai/clip-vit-large-patch14"

    print(f"CLIP 文本编码流程(以 '{texts[0]}' 为例):")
    print()
    print("  ① Tokenization:文字 → token ID 序列")
    print("     'a cat' → [49406, 320, 2368, 49407, 0, 0, ...] (77 个 token)")
    print()
    print("  ② Token Embedding:token ID → 词向量(512 维)")
    print("  ③ Positional Embedding:加入位置信息")
    print("  ④ Transformer 编码器(12层):捕获词间关系")
    print("  ⑤ 输出:(B, 77, 768) 的特征序列")
    print()

    # 模拟输出(演示形状)
    batch_size     = len(texts)
    seq_len        = 77
    hidden_dim     = 768
    simulated_out  = torch.randn(batch_size, seq_len, hidden_dim)

    print(f"  输出形状:{simulated_out.shape}")
    print(f"  每个词 → 768 维特征向量")
    print(f"  所有 77 个位置都有对应的特征(padding 位置全零)")
    return simulated_out

text_features = encode_text_with_clip(["a cute orange cat sitting in sunlight"])

2.2 T5 文本编码器(更强的语言理解)

# T5(Text-to-Text Transfer Transformer,Google)
# 相比 CLIP 的优势:
# ① 更强的语言理解(在纯文本任务上训练)
# ② 无文本长度限制(CLIP 限制 77 token)
# ③ 对复杂描述(空间关系、数量、属性绑定)理解更好
#
# 使用场景:
# Imagen(Google):T5-XXL(11B 参数)
# FLUX(Black Forest Labs,2024):T5-XXL
# DALL·E 3:GPT-4 生成描述 + CLIP 编码

def compare_text_encoders():
    encoders = {
        "CLIP ViT-L/14": {
            "参数量":     "123M",
            "输出维度":   "768",
            "最大长度":   "77 tokens",
            "训练数据":   "4亿图文对(对比学习)",
            "优点":       "图文对齐好,快",
            "缺点":       "文本理解弱,长描述截断",
            "使用者":     "Stable Diffusion v1",
        },
        "CLIP ViT-bigG": {
            "参数量":     "1.8B",
            "输出维度":   "1280",
            "最大长度":   "77 tokens",
            "训练数据":   "34亿图文对(LAION-5B)",
            "优点":       "更强的图文对齐",
            "缺点":       "仍有长度限制",
            "使用者":     "SDXL",
        },
        "OpenCLIP": {
            "参数量":     "302M-2B",
            "输出维度":   "1024-1280",
            "最大长度":   "77 tokens",
            "训练数据":   "开源数据集(可复现)",
            "优点":       "开源,SDXL 使用",
            "缺点":       "同 CLIP",
            "使用者":     "SDXL(双编码器)",
        },
        "T5-XXL": {
            "参数量":     "11B",
            "输出维度":   "4096",
            "最大长度":   "512 tokens",
            "训练数据":   "C4 纯文本(语言理解)",
            "优点":       "强大的语言理解,长文本",
            "缺点":       "显存大,速度慢",
            "使用者":     "Imagen, FLUX",
        },
    }

    print("主流文本编码器对比:")
    print()
    for name, info in encoders.items():
        print(f"  [{name}]")
        for k, v in info.items():
            print(f"    {k:10s}: {v}")
        print()

compare_text_encoders()

2.3 SDXL 的双编码器策略

# SDXL 同时使用两个文本编码器:
# ① CLIP ViT-L(输出 768 维,77 token)
# ② OpenCLIP ViT-bigG(输出 1280 维,77 token)
# 拼接后:(B, 77, 768+1280) = (B, 77, 2048)
#
# 为什么用两个?
# 两个编码器捕获不同维度的语义信息
# 实验证明比单个更大的编码器效果更好

def sdxl_dual_encoding_demo():
    print("SDXL 双编码器策略演示:")
    print()

    B, seq = 2, 77

    # 模拟两个编码器的输出
    clip_vit_l_out    = torch.randn(B, seq, 768)    # CLIP ViT-L
    opencilp_bigG_out = torch.randn(B, seq, 1280)   # OpenCLIP ViT-bigG

    # 在特征维度拼接
    combined = torch.cat([clip_vit_l_out, opencilp_bigG_out], dim=-1)

    print(f"  CLIP ViT-L 输出:{clip_vit_l_out.shape}")
    print(f"  OpenCLIP bigG 输出:{opencilp_bigG_out.shape}")
    print(f"  拼接后(Cross-Attention K/V):{combined.shape}")
    print()
    print(f"  额外:OpenCLIP bigG 的池化向量([CLS] token)")
    pooled = opencilp_bigG_out[:, 0, :]   # 取第一个 token 作为全局文本特征
    print(f"  池化向量:{pooled.shape}(用于 AdaGN,补充全局语义)")

sdxl_dual_encoding_demo()

三、交叉注意力:把文字注入 U-Net

3.1 交叉注意力的机制

class CrossAttention(nn.Module):
    """
    交叉注意力(Cross-Attention)模块
    让图像特征"查询"文本特征

    关键思想:
    - Query(Q)来自图像特征(让图像决定"想知道什么")
    - Key(K)和 Value(V)来自文本特征(文字提供"答案")
    - 每个图像位置都可以独立地关注不同的文字信息

    对比自注意力(Self-Attention):
    自注意力:Q、K、V 都来自同一个序列(图像内部)
    交叉注意力:Q 来自图像,K/V 来自文本(跨模态)
    """

    def __init__(
        self,
        query_dim:   int,        # 图像特征维度
        context_dim: int,        # 文本特征维度(如 768)
        num_heads:   int = 8,
        num_groups:  int = 8,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim  = query_dim // num_heads
        self.scale     = self.head_dim ** -0.5

        # 归一化(对图像特征)
        self.norm = nn.LayerNorm(query_dim)

        # Q 来自图像特征
        self.to_q = nn.Linear(query_dim, query_dim, bias=False)
        # K, V 来自文本特征
        self.to_k = nn.Linear(context_dim, query_dim, bias=False)
        self.to_v = nn.Linear(context_dim, query_dim, bias=False)

        # 输出投影
        self.to_out = nn.Linear(query_dim, query_dim)

    def forward(
        self,
        x:       torch.Tensor,   # (B, N_img, query_dim) 图像特征
        context: torch.Tensor,   # (B, N_txt, context_dim) 文本特征
    ) -> torch.Tensor:
        """
        x: 图像特征(经过展平的特征图)
        context: 文本特征(CLIP 编码的 77 个词)
        """
        B, N_img, C = x.shape

        # 归一化
        h = self.norm(x)

        # 计算 Q(来自图像),K/V(来自文本)
        q = self.to_q(h)                # (B, N_img, C)
        k = self.to_k(context)          # (B, N_txt, C)
        v = self.to_v(context)          # (B, N_txt, C)

        # 多头变换
        def reshape_for_heads(t):
            B, N, C = t.shape
            return t.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
            # (B, heads, N, head_dim)

        q = reshape_for_heads(q)
        k = reshape_for_heads(k)
        v = reshape_for_heads(v)

        # 注意力分数:(B, heads, N_img, N_txt)
        # 图像的每个位置 → 关注文本的每个词
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)

        # 加权聚合:(B, heads, N_img, head_dim)
        out = torch.matmul(attn, v)

        # 合并头:(B, N_img, C)
        out = out.transpose(1, 2).reshape(B, N_img, C)
        out = self.to_out(out)

        # 残差连接
        return x + out


# 验证交叉注意力
cross_attn = CrossAttention(query_dim=512, context_dim=768, num_heads=8)

B        = 2
N_img    = 16 * 16   # 16×16 特征图展平
N_txt    = 77        # CLIP 最多 77 个词

img_feat  = torch.randn(B, N_img, 512)   # 图像特征
txt_feat  = torch.randn(B, N_txt, 768)   # 文本特征

out = cross_attn(img_feat, txt_feat)

print("交叉注意力验证:")
print(f"  图像特征形状:{img_feat.shape}  (B × N_img × query_dim)")
print(f"  文本特征形状:{txt_feat.shape}  (B × 77词 × 768)")
print(f"  输出形状:    {out.shape}")
print()
print("  注意力矩阵形状:(B, heads, N_img, N_txt)")
print(f"  = ({B}, {cross_attn.num_heads}, {N_img}, {N_txt})")
print(f"  = 图像的每个位置 × 文本的每个词")
print()
print("  语义含义:")
print("  attn[b, h, i, j] = 位置 i 的像素关注第 j 个词的强度")
print("  → 图像的不同区域自动学会关注相关的文字描述")

3.2 注意力模式的可视化

def visualize_cross_attention_semantics():
    """
    展示交叉注意力的语义理解能力
    (用简化的随机注意力演示概念,真实模型的注意力有明确的语义对应)
    """
    print("交叉注意力的语义对应(示意):")
    print()
    print("  提示词:'a red cat sitting on a blue sofa'")
    print()
    print("  理想的注意力模式(真实训练后的模型):")
    print()
    print("  图像区域        | 主要关注的词")
    print("  ──────────────────────────────")
    print("  猫的身体区域    | 'red' + 'cat'")
    print("  沙发区域        | 'blue' + 'sofa'")
    print("  猫的姿态区域    | 'sitting'")
    print("  整体布局        | 'on'(空间关系)")
    print()
    print("  这就是为什么交叉注意力比简单的条件拼接强:")
    print("  → 图像不同区域可以独立地对应不同的文字描述")
    print("  → 空间对应关系自然涌现(无需显式监督)")

    # 用随机权重演示注意力矩阵的形状和含义
    B, H, N_img, N_txt = 1, 1, 4, 5  # 简化版:4像素 × 5词

    tokens = ["a", "red", "cat", "on", "sofa"]
    positions = ["左上", "右上", "左下", "右下"]

    torch.manual_seed(42)
    attn_weights = F.softmax(torch.randn(N_img, N_txt), dim=-1)

    print()
    print("  简化示例(4个位置 × 5个词的注意力权重):")
    print(f"  {'':8s}", end="")
    for tok in tokens:
        print(f"  {tok:6s}", end="")
    print()
    print("  " + "─" * 45)
    for i, pos in enumerate(positions):
        print(f"  {pos:8s}", end="")
        for j in range(N_txt):
            w = attn_weights[i, j].item()
            bar = "█" * int(w * 8)
            print(f"  {w:.3f}", end="")
        # 找最大注意力的词
        max_j = attn_weights[i].argmax().item()
        print(f"  ← 主要:'{tokens[max_j]}'")


visualize_cross_attention_semantics()

四、条件 U-Net 的完整结构

class ConditionedResBlock(nn.Module):
    """
    条件化 ResBlock:同时接受时间步嵌入和文本特征
    结构:
      Conv1 → AdaGN(时间步)→ Conv2 → Cross-Attention(文本)→ + shortcut
    """

    def __init__(
        self,
        in_channels:  int,
        out_channels: int,
        time_dim:     int,
        context_dim:  int,    # 文本特征维度
        num_heads:    int = 8,
        num_groups:   int = 8,
        dropout:      float = 0.0,
    ):
        super().__init__()

        # ── ResBlock 部分(与第四篇相同)──────────────────────
        self.norm1     = nn.GroupNorm(num_groups, in_channels)
        self.conv1     = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.time_proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim, out_channels * 2),
        )
        self.norm2   = nn.GroupNorm(num_groups, out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2   = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.act     = nn.SiLU()

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.shortcut = nn.Identity()

        # ── 自注意力(图像内部)────────────────────────────
        self.self_attn_norm = nn.LayerNorm(out_channels)
        self.self_attn      = nn.MultiheadAttention(
            out_channels, num_heads, batch_first=True
        )

        # ── 交叉注意力(文本条件)──────────────────────────
        self.cross_attn = CrossAttention(
            query_dim   = out_channels,
            context_dim = context_dim,
            num_heads   = num_heads,
        )

        # ── 前馈网络(Post-Attention)──────────────────────
        self.ff_norm = nn.LayerNorm(out_channels)
        self.ff      = nn.Sequential(
            nn.Linear(out_channels, out_channels * 4),
            nn.GELU(),
            nn.Linear(out_channels * 4, out_channels),
        )

    def forward(
        self,
        x:        torch.Tensor,   # (B, C, H, W)
        time_emb: torch.Tensor,   # (B, time_dim)
        context:  torch.Tensor,   # (B, N_txt, context_dim) 文本特征
    ) -> torch.Tensor:
        B, C, H, W = x.shape

        # ── ResBlock 前向(时间步注入)────────────────────
        h = self.act(self.norm1(x))
        h = self.conv1(h)

        t_proj = self.time_proj(time_emb)
        scale, shift = t_proj.chunk(2, dim=1)
        h = self.norm2(h) * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
        h = self.act(h)
        h = self.dropout(h)
        h = self.conv2(h)
        h = h + self.shortcut(x)

        # ── 展平为序列,进行注意力操作 ────────────────────
        # (B, C, H, W) → (B, H*W, C)
        h_flat = h.reshape(B, C, H * W).transpose(1, 2)

        # ── 自注意力(图像区域间的关联)──────────────────
        h_norm = self.self_attn_norm(h_flat)
        sa_out, _ = self.self_attn(h_norm, h_norm, h_norm)
        h_flat = h_flat + sa_out

        # ── 交叉注意力(文本条件注入)────────────────────
        h_flat = self.cross_attn(h_flat, context)

        # ── 前馈网络(增强表达)──────────────────────────
        h_flat = h_flat + self.ff(self.ff_norm(h_flat))

        # ── reshape 回 (B, C, H, W) ──────────────────────
        return h_flat.transpose(1, 2).reshape(B, C, H, W)


# 验证条件化 ResBlock
cond_block = ConditionedResBlock(
    in_channels  = 256,
    out_channels = 256,
    time_dim     = 512,
    context_dim  = 768,
    num_heads    = 8,
)

B, C, H, W = 2, 256, 16, 16
x        = torch.randn(B, C, H, W)
time_emb = torch.randn(B, 512)
context  = torch.randn(B, 77, 768)   # 文本特征(77词,768维)

out = cond_block(x, time_emb, context)

print("条件化 ResBlock 验证:")
print(f"  图像特征输入:{x.shape}")
print(f"  时间步嵌入:  {time_emb.shape}")
print(f"  文本特征:    {context.shape}")
print(f"  输出形状:    {out.shape}")
print()

# 验证文本条件确实影响了输出
context_diff = torch.randn(B, 77, 768)   # 不同的文本
out_diff     = cond_block(x, time_emb, context_diff)
diff = (out - out_diff).abs().mean().item()
print(f"  不同文本 → 不同输出(差异:{diff:.4f})✓")
print(f"  参数量:{sum(p.numel() for p in cond_block.parameters()):,}")

五、Classifier Guidance(分类器引导)

# 条件生成的第一种方法:训练一个分类器,在采样时引导梯度
#
# Dhariwal & Nichol (2021) 提出:
# 修改采样公式:
# ∇_xₜ log p_θ(xₜ) → ∇_xₜ [log p_θ(xₜ) + γ · log p_φ(y|xₜ)]
# p_θ:扩散模型
# p_φ:噪声鲁棒的分类器(在加噪的图像上训练)
# y:目标类别
# γ:引导强度

def classifier_guidance_pseudocode():
    print("Classifier Guidance 采样伪代码:")
    print()
    print("  # 需要:1个扩散模型 + 1个噪声鲁棒分类器")
    print("  # 分类器必须在加噪图像上训练!")
    print()
    print("  for t = T, T-1, ..., 1:")
    print("    # 扩散模型预测")
    print("    eps_unconditional = diffusion_model(xt, t)")
    print()
    print("    # 分类器梯度")
    print("    xₜ.requires_grad = True")
    print("    log_p = classifier(xₜ, t).log_softmax()[target_class]")
    print("    grad = ∇_xₜ log_p   # 梯度!")
    print()
    print("    # 修改分数函数")
    print("    eps_guided = eps_unconditional - √(1-ᾱₜ) · γ · grad")
    print()
    print("    # 用 eps_guided 代替 eps_unconditional 采样")
    print("    xt_prev = ddim_step(xt, eps_guided, t)")
    print()
    print("  缺点:")
    print("  ① 需要额外训练噪声鲁棒的分类器")
    print("  ② 每步都要计算分类器梯度(慢)")
    print("  ③ 只能做类别条件,无法做文本条件")
    print("  → 被 Classifier-Free Guidance 取代")


classifier_guidance_pseudocode()

六、Classifier-Free Guidance:核心技术

6.1 CFG 的关键思想

# Classifier-Free Guidance (Ho & Salimans, 2021)
# 不需要额外的分类器!
# 通过同时训练有条件和无条件两种模式,在推理时组合

# 训练阶段:
#   以概率 p_uncond(通常 10-20%)把条件 c 替换为空条件 ∅
#   网络同时学会:
#     ε_θ(xₜ, t, c)    有条件去噪(给定文本)
#     ε_θ(xₜ, t, ∅)    无条件去噪(不给文本)
#
# 推理阶段(CFG 公式):
#   ε_guided = ε_θ(xₜ, t, ∅) + γ · [ε_θ(xₜ, t, c) - ε_θ(xₜ, t, ∅)]
#
#   其中 γ 是 guidance_scale(引导强度)
#   γ=1:等同于有条件去噪(无引导)
#   γ=7.5:Stable Diffusion 默认值,较强的文字对齐
#   γ>15:文字对齐极强,但可能过饱和/失真

print("Classifier-Free Guidance 的直觉理解:")
print()
print("  ε_guided = ε_uncond + γ · (ε_cond - ε_uncond)")
print()
print("  拆解:")
print("  ε_uncond           = 无条件方向(随机漫步)")
print("  ε_cond - ε_uncond  = 条件带来的额外方向(文字的影响)")
print("  γ · (...)          = 放大文字影响(γ越大,文字越重要)")
print()
print("  类比:")
print("  ε_uncond           = 没有地图乱走(随机)")
print("  ε_cond - ε_uncond  = 地图指的方向")
print("  γ · (...)          = 按地图方向走 γ 倍(越守规矩越准)")
print()

# 数学角度:CFG 等价于隐式分类器引导
print("  数学等价性(Ho & Salimans 的推导):")
print("  ε_guided = ε_uncond + γ · (ε_cond - ε_uncond)")
print("  ≡ (1-γ) · ε_uncond + γ · ε_cond  (改写)")
print("  ≡ ε_θ(xₜ, t, c) + (γ-1) · [ε_cond - ε_uncond]")
print()
print("  = 有条件去噪 + 放大了 (γ-1) 倍的「条件与无条件的差」")
print("  → 本质:把有条件和无条件的差距放大 γ 倍")

6.2 不同 guidance_scale 的效果

def analyze_cfg_scale():
    """分析不同 guidance_scale 对生成质量的影响"""

    print("不同 guidance_scale 的效果分析:")
    print()
    print(f"  {'γ':^6} {'文字对齐':^12} {'图像多样性':^14} {'图像质量':^12} {'推荐场景':^20}")
    print("  " + "─" * 68)

    configs = [
        (1.0,  "★☆☆☆☆", "★★★★★", "★★★☆☆", "无引导(随机生成)"),
        (2.0,  "★★☆☆☆", "★★★★☆", "★★★★☆", "轻度引导"),
        (5.0,  "★★★★☆", "★★★☆☆", "★★★★★", "平衡点(艺术创作)"),
        (7.5,  "★★★★★", "★★★☆☆", "★★★★★", "SD 默认(推荐)⭐"),
        (10.0, "★★★★★", "★★☆☆☆", "★★★★☆", "强调细节匹配"),
        (15.0, "★★★★★", "★★☆☆☆", "★★★☆☆", "极端对齐,可能过饱和"),
        (20.0, "★★★★★", "★☆☆☆☆", "★★☆☆☆", "通常过度,质量下降"),
    ]

    for gamma, text_align, diversity, quality, scene in configs:
        marker = " ←" if gamma == 7.5 else ""
        print(f"  {gamma:^6.1f} {text_align:^14} {diversity:^14} {quality:^14} {scene}{marker}")

    print()
    print("  关键观察:")
    print("  γ 越大 → 文字对齐越好,但多样性降低,可能过饱和")
    print("  γ=1    → 等价于无 CFG,图像多样但可能不符合提示词")
    print("  最优 γ  → 取决于任务:写实摄影偏低,艺术创作偏高")


analyze_cfg_scale()

七、CFG 的数学原理与实现

7.1 CFG 与分类器引导的等价性推导

def derive_cfg_equivalence():
    """推导 CFG 与隐式分类器引导的等价关系"""

    print("CFG 与分类器引导的数学等价性推导:")
    print()
    print("  贝叶斯公式:")
    print("  p(xₜ|c) = p(c|xₜ) · p(xₜ) / p(c)")
    print()
    print("  取对数梯度(得分函数):")
    print("  ∇_xₜ log p(xₜ|c)")
    print("  = ∇_xₜ log p(c|xₜ) + ∇_xₜ log p(xₜ)")
    print()
    print("  替换扩散模型的得分函数:")
    print("  -1/√(1-ᾱₜ) · ε_θ(xₜ,t,c) ≈ ∇_xₜ log p(xₜ|c)")
    print()
    print("  Classifier Guidance 的引导方向:")
    print("  ∇_xₜ [log p(xₜ) + γ·log p(c|xₜ)]")
    print("  = ∇_xₜ log p(xₜ) + γ·∇_xₜ log p(c|xₜ)")
    print("  ≈ -1/√(1-ᾱₜ) · ε_uncond + γ·(-1/√(1-ᾱₜ))·(ε_cond-ε_uncond)")
    print("  ≈ -1/√(1-ᾱₜ) · [ε_uncond + γ·(ε_cond-ε_uncond)]")
    print()
    print("  ✓ 这正是 CFG 公式!")
    print("  ε_guided = ε_uncond + γ·(ε_cond - ε_uncond)")
    print()
    print("  关键结论:")
    print("  CFG 无需额外分类器,用无条件模型隐式充当分类器")
    print("  γ-1 正好等于隐式分类器的引导强度")

derive_cfg_equivalence()

7.2 负面提示词(Negative Prompt)

# CFG 的扩展:负面提示词(Negative Prompt)
#
# 标准 CFG:
#   ε_guided = ε_θ(xₜ, t, ∅) + γ · [ε_θ(xₜ, t, c_pos) - ε_θ(xₜ, t, ∅)]
#
# 带负面提示词的 CFG:
#   用负面提示词 c_neg 代替空条件 ∅
#   ε_guided = ε_θ(xₜ, t, c_neg) + γ · [ε_θ(xₜ, t, c_pos) - ε_θ(xₜ, t, c_neg)]
#
# 直觉:
#   "远离 c_neg 的方向" + γ × "从 c_neg 走向 c_pos"
#   → 既避开负面内容,又朝向正面内容

def negative_prompt_demo():
    print("负面提示词(Negative Prompt)的原理:")
    print()
    print("  正面提示:'a beautiful landscape, high quality, detailed'")
    print("  负面提示:'blurry, ugly, bad anatomy, low quality, nsfw'")
    print()
    print("  CFG 公式(带负面提示词):")
    print("  ε_guided = ε_θ(xₜ, c_neg) + γ × [ε_θ(xₜ, c_pos) - ε_θ(xₜ, c_neg)]")
    print()
    print("  几何直觉:")
    print("  ε(c_neg) → 远离'模糊、低质量'的方向")
    print("  ε(c_pos) → 朝向'美丽风景、高质量'的方向")
    print("  组合:远离负面 + γ倍的(正面-负面)方向")
    print()
    print("  常用负面提示词:")
    negative_examples = [
        "图像质量", "blurry, low quality, jpeg artifacts, noisy",
        "人物质量", "bad anatomy, extra fingers, mutated hands, ugly face",
        "内容过滤", "nsfw, explicit, violence, gore",
        "风格排除", "cartoon, anime, 3d render(如果要求摄影风格)",
    ]
    for category, prompt in zip(negative_examples[::2], negative_examples[1::2]):
        print(f"    [{category}]: {prompt}")


negative_prompt_demo()

八、完整条件采样代码

class ConditionalDDIMSampler:
    """
    支持文本条件和 CFG 的 DDIM 采样器
    核心特性:
    1. 接受文本编码作为条件
    2. 支持 CFG(guidance_scale > 1 时自动启用)
    3. 支持负面提示词
    4. 每步只需运行 1 次网络(批次维度展开技巧)
    """

    def __init__(
        self,
        denoiser,                    # 条件化 U-Net(接受 context 参数)
        alphas_cumprod: torch.Tensor,
        num_train_steps: int = 1000,
        device: str = "cpu",
    ):
        self.denoiser        = denoiser
        self.alphas_cumprod  = alphas_cumprod
        self.num_train_steps = num_train_steps
        self.device          = device

    def get_timesteps(self, num_inference_steps: int) -> list:
        step_ratio = self.num_train_steps // num_inference_steps
        return list(reversed(range(0, self.num_train_steps, step_ratio)))

    def ddim_step(
        self,
        xt:      torch.Tensor,
        t_cur:   int,
        t_prev:  int,
        eps:     torch.Tensor,
        eta:     float = 0.0,
    ) -> torch.Tensor:
        ab_t  = self.alphas_cumprod[t_cur]
        ab_tp = self.alphas_cumprod[t_prev] if t_prev >= 0 else torch.ones(1)

        x0_pred = (xt - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()
        x0_pred = x0_pred.clamp(-1, 1)

        dir_xt  = (1 - ab_tp - eta**2 * (1-ab_t)/ab_t * (ab_t-ab_tp)).clamp(0).sqrt() * eps
        noise   = eta * ((1-ab_tp)/(1-ab_t)).sqrt() * \
                  (1 - ab_t/ab_tp).clamp(0).sqrt() * torch.randn_like(xt)

        return ab_tp.sqrt() * x0_pred + dir_xt + noise

    @torch.no_grad()
    def sample(
        self,
        shape:          tuple,
        text_embeddings: torch.Tensor,   # (B, 77, context_dim) 正面提示词
        neg_embeddings:  torch.Tensor = None,  # 负面提示词(None 则用空)
        guidance_scale:  float = 7.5,
        num_steps:       int   = 50,
        eta:             float = 0.0,
        init_noise:      torch.Tensor = None,
        verbose:         bool  = False,
    ) -> torch.Tensor:
        """
        条件 DDIM 采样

        工程技巧:
        当 guidance_scale > 1 时,每步需要计算:
        ① ε_θ(xₜ, t, c_pos)(正面条件)
        ② ε_θ(xₜ, t, c_neg)(负面条件或空条件)
        通常做法:把两者拼成一个 batch,一次前向传播同时得到两个结果
        → 比两次单独调用快约 1.5 倍(CUDA 更好地利用并行)
        """
        self.denoiser.eval()
        B = shape[0]

        # 初始噪声
        if init_noise is None:
            xt = torch.randn(*shape, device=self.device)
        else:
            xt = init_noise.clone().to(self.device)

        # 准备条件嵌入
        # 负面提示词:如果未提供则使用全零(空条件)
        if neg_embeddings is None:
            neg_embeddings = torch.zeros_like(text_embeddings)

        # ── CFG 批次展开技巧 ─────────────────────────────────
        # 将正面和负面提示词拼成一个 batch,一次前向得到两个结果
        # (2B, 77, context_dim)
        batch_context = torch.cat([neg_embeddings, text_embeddings], dim=0)

        timesteps = self.get_timesteps(num_steps)

        for step_idx, t in enumerate(timesteps):
            t_prev = timesteps[step_idx + 1] if step_idx + 1 < len(timesteps) else -1

            t_tensor = torch.full((B,), t, dtype=torch.long, device=self.device)

            if guidance_scale == 1.0:
                # 无 CFG:只做有条件推理
                eps_pred = self.denoiser(xt, t_tensor, context=text_embeddings)
            else:
                # CFG:一次前向同时计算有条件和无条件
                # 把 xt 复制一份:(2B, C, H, W)
                xt_doubled = torch.cat([xt, xt], dim=0)
                t_doubled  = torch.cat([t_tensor, t_tensor], dim=0)

                # 一次前向传播
                eps_both = self.denoiser(xt_doubled, t_doubled, context=batch_context)

                # 分离负面和正面的预测
                eps_neg, eps_pos = eps_both.chunk(2, dim=0)

                # CFG 公式:引导方向 = 无条件 + γ·(有条件-无条件)
                eps_pred = eps_neg + guidance_scale * (eps_pos - eps_neg)

            xt = self.ddim_step(xt, t, t_prev, eps_pred, eta)

            if verbose and step_idx % (num_steps // 5) == 0:
                print(f"  步骤 {step_idx+1}/{num_steps}: "
                      f"t={t}, xt.std={xt.std():.4f}")

        # 归一化到 [0, 1]
        return (xt.clamp(-1, 1) + 1) / 2


# ── 用简单模型验证完整流程 ──────────────────────────────────────

class SimpleCFGDenoiser(nn.Module):
    """
    极简的条件去噪网络(仅用于验证 CFG 框架)
    实际应用中替换为第四篇的 U-Net
    """

    def __init__(self, channels=3, hidden=32, context_dim=768):
        super().__init__()
        # 时间步嵌入
        self.t_emb = nn.Sequential(
            nn.Linear(1, hidden), nn.SiLU(), nn.Linear(hidden, hidden)
        )
        # 文本特征池化(简化)
        self.ctx_pool = nn.Linear(context_dim, hidden)
        # 主干网络
        self.net = nn.Sequential(
            nn.Conv2d(channels + hidden, hidden, 3, padding=1), nn.SiLU(),
            nn.Conv2d(hidden, channels, 3, padding=1),
        )

    def forward(self, x, t, context=None):
        B, C, H, W = x.shape
        # 时间步嵌入
        t_feat = self.t_emb(t.float().unsqueeze(1) / 1000)  # (B, hidden)
        if context is not None:
            # 文本特征:池化 77 个词 → 一个向量
            ctx_feat = self.ctx_pool(context.mean(dim=1))     # (B, hidden)
            t_feat   = t_feat + ctx_feat                      # 简单融合
        # 拼接时间特征到图像
        t_map = t_feat[:, :, None, None].expand(B, -1, H, W)
        return self.net(torch.cat([x, t_map], dim=1))


# 创建并测试
T = 1000
betas          = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1 - betas, 0)

denoiser_cfg = SimpleCFGDenoiser(channels=3, hidden=16, context_dim=768)

sampler = ConditionalDDIMSampler(
    denoiser_cfg, alphas_cumprod, num_train_steps=T
)

# 模拟文本嵌入(真实应用中来自 CLIP)
text_emb = torch.randn(2, 77, 768)  # (B=2, 77词, 768维)
neg_emb  = torch.zeros(2, 77, 768)  # 空负面提示词

print("条件 DDIM 采样验证:")
print()
print("测试 1:无 CFG(guidance_scale=1.0)")
img1 = sampler.sample(
    shape=(2, 3, 16, 16),
    text_embeddings=text_emb,
    guidance_scale=1.0,
    num_steps=10,
    verbose=True,
)
print(f"  生成图像形状:{img1.shape},值域:[{img1.min():.3f}, {img1.max():.3f}]")

print()
print("测试 2:启用 CFG(guidance_scale=7.5)")
img2 = sampler.sample(
    shape=(2, 3, 16, 16),
    text_embeddings=text_emb,
    neg_embeddings=neg_emb,
    guidance_scale=7.5,
    num_steps=10,
    verbose=True,
)
print(f"  生成图像形状:{img2.shape},值域:[{img2.min():.3f}, {img2.max():.3f}]")

print()
print("测试 3:不同 guidance_scale 的输出差异")
for gs in [1.0, 3.0, 7.5, 15.0]:
    torch.manual_seed(42)
    img = sampler.sample(
        shape=(1, 3, 16, 16),
        text_embeddings=text_emb[:1],
        guidance_scale=gs,
        num_steps=5,
    )
    print(f"  guidance_scale={gs:5.1f}: mean={img.mean():.4f}, std={img.std():.4f}")

总结

概念 说明 关键参数
文本编码 CLIP/T5 把文字变成 (B,77,768) 的特征序列 编码器维度、最大长度
交叉注意力 Q=图像特征,K/V=文本特征,图像"查询"文字 query_dim, context_dim
Classifier Guidance 用分类器梯度引导,需要额外分类器 guidance_scale
CFG 公式 ϵ g = ϵ u + γ ( ϵ c − ϵ u ) \epsilon_g = \epsilon_u + \gamma(\epsilon_c - \epsilon_u) ϵg=ϵu+γ(ϵcϵu) guidance_scale γ
负面提示词 c n e g c_{neg} cneg 代替空条件,双向控制生成方向 negative_prompt
CFG 批次技巧 正负提示词合并成 2B,一次前向得两个结果 速度提升 ~1.5×

guidance_scale 速查表:

  • γ=1.0:无 CFG,完全随机
  • γ=5.0:平衡多样性与对齐(艺术风格推荐)
  • γ=7.5:SD 默认,写实摄影推荐 ⭐
  • γ≥15:极强对齐,可能过饱和

下一篇预告:Stable Diffusion——潜在扩散模型(LDM),为什么要在"潜空间"做扩散而不是像素空间?VAE 的角色是什么?以及 ControlNet 如何通过附加条件(深度图、姿态图)实现精确的结构控制。


💬 你用 Stable Diffusion 时 guidance_scale 一般设多少?发现过什么规律? 欢迎评论区分享!

🙏 如果这篇帮到你,点赞 + 收藏,系列持续更新!


本文为原创技术分享。代码在 Python 3.12 + PyTorch 2.x + transformers 下验证。最后更新:2026-05-20

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐