【扩散模型系列·第六篇】条件生成:文本引导、交叉注意力与 Classifier-Free Guidance 全解析
【扩散模型系列·第六篇】条件生成:文本引导、交叉注意力与 Classifier-Free Guidance 全解析
作者:技术博主 | 更新时间:2026-05-20 | 阅读时长:约 23 分钟
系列:扩散模型从零到实战(共 8 篇)
环境:Python 3.12 + PyTorch 2.x + transformers + diffusers
标签:扩散模型条件生成CLIP交叉注意力CFGClassifier-Free Guidance文生图文本编码

🔥 本篇目标:前五篇的扩散模型是"无条件"的——从纯噪声生成图像,但无法控制生成什么。本篇解决这个问题:如何让扩散模型根据文字描述生成图像? 完整链路包含三个关键:用 CLIP/T5 把文字编码成向量、用交叉注意力把文字注入 U-Net 每一层、以及用 Classifier-Free Guidance(CFG)大幅提升文字对齐度。理解了 CFG,你就明白 Stable Diffusion 里
guidance_scale=7.5这个数字的含义。
系列进度
| 篇次 | 主题 | 状态 |
|---|---|---|
| 第一篇 | 扩散模型是什么 | ✅ 已发布 |
| 第二篇 | 数学基础:前向过程 | ✅ 已发布 |
| 第三篇 | 反向过程:学习去噪 | ✅ 已发布 |
| 第四篇 | U-Net 架构 | ✅ 已发布 |
| 第五篇 | DDIM:加速采样 | ✅ 已发布 |
| 第六篇(本篇) | 条件生成:文本引导与 CFG | — |
| 第七篇 | Stable Diffusion:潜在扩散模型 | 即将发布 |
| 第八篇 | 实战与前沿:LoRA、DreamBooth、FLUX | 即将发布 |
目录
- 一、条件生成的三种范式
- 二、文本编码器:CLIP vs T5
- 三、交叉注意力:把文字注入 U-Net
- 四、条件 U-Net 的完整结构
- 五、Classifier Guidance(分类器引导)
- 六、Classifier-Free Guidance:核心技术
- 七、CFG 的数学原理与实现
- 八、完整条件采样代码
一、条件生成的三种范式
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
# 让扩散模型"听话"的三种方式:
#
# ① 条件拼接(Concatenation)
# 把条件(如类别标签的 one-hot 编码)直接拼接到输入
# 简单,但对复杂条件(如自然语言)效果有限
#
# ② 自适应归一化(Adaptive Normalization,AdaGN)
# 条件向量转化为 scale 和 shift,对特征图做仿射变换
# 扩散模型中时间步嵌入就是用这种方式注入的
# 适合:紧凑的条件信号(类别 ID、时间步)
#
# ③ 交叉注意力(Cross-Attention)← 文本条件的标准方式
# U-Net 的每个 ResBlock 后面加 Cross-Attention 层
# Query 来自图像特征,Key/Value 来自文本特征
# 使图像的每个位置都能"查询"文字的每个词
# 适合:序列条件(文本、音频、其他图像)
condition_methods = {
"拼接(Concat)": {
"条件类型": "紧凑向量",
"实现难度": "简单",
"灵活性": "低",
"典型应用": "类别标签、MNIST digit",
},
"AdaGN(自适应归一化)": {
"条件类型": "紧凑向量",
"实现难度": "中等",
"灵活性": "中",
"典型应用": "时间步嵌入(必用)、类别嵌入",
},
"交叉注意力(Cross-Attention)": {
"条件类型": "序列",
"实现难度": "较高",
"灵活性": "高",
"典型应用": "文本→图像(DALL-E 2、SD)",
},
"混合(CFG 标配)": {
"条件类型": "序列 + 向量",
"实现难度": "高",
"灵活性": "最高",
"典型应用": "Stable Diffusion:文本+时间步",
},
}
print("条件注入方式对比:")
print()
for method, info in condition_methods.items():
print(f" [{method}]")
for k, v in info.items():
print(f" {k:10s}: {v}")
print()
二、文本编码器:CLIP vs T5
2.1 CLIP 文本编码器
# CLIP(Contrastive Language-Image Pre-training,OpenAI 2021)
# 训练方式:对比学习,让文本和图像的嵌入在语义上对齐
# 输入:文字
# 输出:[batch, 77, 768](最多 77 个词的特征序列)
#
# CLIP 为什么适合文生图?
# ① 文本特征与图像特征在同一语义空间
# ② 77 个词的序列输出,为交叉注意力提供 Key 和 Value
# ③ 已在海量图文对上训练,泛化能力强
# 用 transformers 加载 CLIP 文本编码器
def encode_text_with_clip(texts: list, device: str = "cpu"):
"""
使用 CLIP 编码文字
texts: 文字列表
返回: (batch, 77, 768) 的特征张量
"""
from transformers import CLIPTokenizer, CLIPTextModel
# 加载 CLIP tokenizer 和 text encoder
# 实际使用时从 HuggingFace 下载,这里展示接口
model_id = "openai/clip-vit-large-patch14"
print(f"CLIP 文本编码流程(以 '{texts[0]}' 为例):")
print()
print(" ① Tokenization:文字 → token ID 序列")
print(" 'a cat' → [49406, 320, 2368, 49407, 0, 0, ...] (77 个 token)")
print()
print(" ② Token Embedding:token ID → 词向量(512 维)")
print(" ③ Positional Embedding:加入位置信息")
print(" ④ Transformer 编码器(12层):捕获词间关系")
print(" ⑤ 输出:(B, 77, 768) 的特征序列")
print()
# 模拟输出(演示形状)
batch_size = len(texts)
seq_len = 77
hidden_dim = 768
simulated_out = torch.randn(batch_size, seq_len, hidden_dim)
print(f" 输出形状:{simulated_out.shape}")
print(f" 每个词 → 768 维特征向量")
print(f" 所有 77 个位置都有对应的特征(padding 位置全零)")
return simulated_out
text_features = encode_text_with_clip(["a cute orange cat sitting in sunlight"])
2.2 T5 文本编码器(更强的语言理解)
# T5(Text-to-Text Transfer Transformer,Google)
# 相比 CLIP 的优势:
# ① 更强的语言理解(在纯文本任务上训练)
# ② 无文本长度限制(CLIP 限制 77 token)
# ③ 对复杂描述(空间关系、数量、属性绑定)理解更好
#
# 使用场景:
# Imagen(Google):T5-XXL(11B 参数)
# FLUX(Black Forest Labs,2024):T5-XXL
# DALL·E 3:GPT-4 生成描述 + CLIP 编码
def compare_text_encoders():
encoders = {
"CLIP ViT-L/14": {
"参数量": "123M",
"输出维度": "768",
"最大长度": "77 tokens",
"训练数据": "4亿图文对(对比学习)",
"优点": "图文对齐好,快",
"缺点": "文本理解弱,长描述截断",
"使用者": "Stable Diffusion v1",
},
"CLIP ViT-bigG": {
"参数量": "1.8B",
"输出维度": "1280",
"最大长度": "77 tokens",
"训练数据": "34亿图文对(LAION-5B)",
"优点": "更强的图文对齐",
"缺点": "仍有长度限制",
"使用者": "SDXL",
},
"OpenCLIP": {
"参数量": "302M-2B",
"输出维度": "1024-1280",
"最大长度": "77 tokens",
"训练数据": "开源数据集(可复现)",
"优点": "开源,SDXL 使用",
"缺点": "同 CLIP",
"使用者": "SDXL(双编码器)",
},
"T5-XXL": {
"参数量": "11B",
"输出维度": "4096",
"最大长度": "512 tokens",
"训练数据": "C4 纯文本(语言理解)",
"优点": "强大的语言理解,长文本",
"缺点": "显存大,速度慢",
"使用者": "Imagen, FLUX",
},
}
print("主流文本编码器对比:")
print()
for name, info in encoders.items():
print(f" [{name}]")
for k, v in info.items():
print(f" {k:10s}: {v}")
print()
compare_text_encoders()
2.3 SDXL 的双编码器策略
# SDXL 同时使用两个文本编码器:
# ① CLIP ViT-L(输出 768 维,77 token)
# ② OpenCLIP ViT-bigG(输出 1280 维,77 token)
# 拼接后:(B, 77, 768+1280) = (B, 77, 2048)
#
# 为什么用两个?
# 两个编码器捕获不同维度的语义信息
# 实验证明比单个更大的编码器效果更好
def sdxl_dual_encoding_demo():
print("SDXL 双编码器策略演示:")
print()
B, seq = 2, 77
# 模拟两个编码器的输出
clip_vit_l_out = torch.randn(B, seq, 768) # CLIP ViT-L
opencilp_bigG_out = torch.randn(B, seq, 1280) # OpenCLIP ViT-bigG
# 在特征维度拼接
combined = torch.cat([clip_vit_l_out, opencilp_bigG_out], dim=-1)
print(f" CLIP ViT-L 输出:{clip_vit_l_out.shape}")
print(f" OpenCLIP bigG 输出:{opencilp_bigG_out.shape}")
print(f" 拼接后(Cross-Attention K/V):{combined.shape}")
print()
print(f" 额外:OpenCLIP bigG 的池化向量([CLS] token)")
pooled = opencilp_bigG_out[:, 0, :] # 取第一个 token 作为全局文本特征
print(f" 池化向量:{pooled.shape}(用于 AdaGN,补充全局语义)")
sdxl_dual_encoding_demo()
三、交叉注意力:把文字注入 U-Net
3.1 交叉注意力的机制
class CrossAttention(nn.Module):
"""
交叉注意力(Cross-Attention)模块
让图像特征"查询"文本特征
关键思想:
- Query(Q)来自图像特征(让图像决定"想知道什么")
- Key(K)和 Value(V)来自文本特征(文字提供"答案")
- 每个图像位置都可以独立地关注不同的文字信息
对比自注意力(Self-Attention):
自注意力:Q、K、V 都来自同一个序列(图像内部)
交叉注意力:Q 来自图像,K/V 来自文本(跨模态)
"""
def __init__(
self,
query_dim: int, # 图像特征维度
context_dim: int, # 文本特征维度(如 768)
num_heads: int = 8,
num_groups: int = 8,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = query_dim // num_heads
self.scale = self.head_dim ** -0.5
# 归一化(对图像特征)
self.norm = nn.LayerNorm(query_dim)
# Q 来自图像特征
self.to_q = nn.Linear(query_dim, query_dim, bias=False)
# K, V 来自文本特征
self.to_k = nn.Linear(context_dim, query_dim, bias=False)
self.to_v = nn.Linear(context_dim, query_dim, bias=False)
# 输出投影
self.to_out = nn.Linear(query_dim, query_dim)
def forward(
self,
x: torch.Tensor, # (B, N_img, query_dim) 图像特征
context: torch.Tensor, # (B, N_txt, context_dim) 文本特征
) -> torch.Tensor:
"""
x: 图像特征(经过展平的特征图)
context: 文本特征(CLIP 编码的 77 个词)
"""
B, N_img, C = x.shape
# 归一化
h = self.norm(x)
# 计算 Q(来自图像),K/V(来自文本)
q = self.to_q(h) # (B, N_img, C)
k = self.to_k(context) # (B, N_txt, C)
v = self.to_v(context) # (B, N_txt, C)
# 多头变换
def reshape_for_heads(t):
B, N, C = t.shape
return t.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)
# (B, heads, N, head_dim)
q = reshape_for_heads(q)
k = reshape_for_heads(k)
v = reshape_for_heads(v)
# 注意力分数:(B, heads, N_img, N_txt)
# 图像的每个位置 → 关注文本的每个词
attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn = F.softmax(attn, dim=-1)
# 加权聚合:(B, heads, N_img, head_dim)
out = torch.matmul(attn, v)
# 合并头:(B, N_img, C)
out = out.transpose(1, 2).reshape(B, N_img, C)
out = self.to_out(out)
# 残差连接
return x + out
# 验证交叉注意力
cross_attn = CrossAttention(query_dim=512, context_dim=768, num_heads=8)
B = 2
N_img = 16 * 16 # 16×16 特征图展平
N_txt = 77 # CLIP 最多 77 个词
img_feat = torch.randn(B, N_img, 512) # 图像特征
txt_feat = torch.randn(B, N_txt, 768) # 文本特征
out = cross_attn(img_feat, txt_feat)
print("交叉注意力验证:")
print(f" 图像特征形状:{img_feat.shape} (B × N_img × query_dim)")
print(f" 文本特征形状:{txt_feat.shape} (B × 77词 × 768)")
print(f" 输出形状: {out.shape}")
print()
print(" 注意力矩阵形状:(B, heads, N_img, N_txt)")
print(f" = ({B}, {cross_attn.num_heads}, {N_img}, {N_txt})")
print(f" = 图像的每个位置 × 文本的每个词")
print()
print(" 语义含义:")
print(" attn[b, h, i, j] = 位置 i 的像素关注第 j 个词的强度")
print(" → 图像的不同区域自动学会关注相关的文字描述")
3.2 注意力模式的可视化
def visualize_cross_attention_semantics():
"""
展示交叉注意力的语义理解能力
(用简化的随机注意力演示概念,真实模型的注意力有明确的语义对应)
"""
print("交叉注意力的语义对应(示意):")
print()
print(" 提示词:'a red cat sitting on a blue sofa'")
print()
print(" 理想的注意力模式(真实训练后的模型):")
print()
print(" 图像区域 | 主要关注的词")
print(" ──────────────────────────────")
print(" 猫的身体区域 | 'red' + 'cat'")
print(" 沙发区域 | 'blue' + 'sofa'")
print(" 猫的姿态区域 | 'sitting'")
print(" 整体布局 | 'on'(空间关系)")
print()
print(" 这就是为什么交叉注意力比简单的条件拼接强:")
print(" → 图像不同区域可以独立地对应不同的文字描述")
print(" → 空间对应关系自然涌现(无需显式监督)")
# 用随机权重演示注意力矩阵的形状和含义
B, H, N_img, N_txt = 1, 1, 4, 5 # 简化版:4像素 × 5词
tokens = ["a", "red", "cat", "on", "sofa"]
positions = ["左上", "右上", "左下", "右下"]
torch.manual_seed(42)
attn_weights = F.softmax(torch.randn(N_img, N_txt), dim=-1)
print()
print(" 简化示例(4个位置 × 5个词的注意力权重):")
print(f" {'':8s}", end="")
for tok in tokens:
print(f" {tok:6s}", end="")
print()
print(" " + "─" * 45)
for i, pos in enumerate(positions):
print(f" {pos:8s}", end="")
for j in range(N_txt):
w = attn_weights[i, j].item()
bar = "█" * int(w * 8)
print(f" {w:.3f}", end="")
# 找最大注意力的词
max_j = attn_weights[i].argmax().item()
print(f" ← 主要:'{tokens[max_j]}'")
visualize_cross_attention_semantics()
四、条件 U-Net 的完整结构
class ConditionedResBlock(nn.Module):
"""
条件化 ResBlock:同时接受时间步嵌入和文本特征
结构:
Conv1 → AdaGN(时间步)→ Conv2 → Cross-Attention(文本)→ + shortcut
"""
def __init__(
self,
in_channels: int,
out_channels: int,
time_dim: int,
context_dim: int, # 文本特征维度
num_heads: int = 8,
num_groups: int = 8,
dropout: float = 0.0,
):
super().__init__()
# ── ResBlock 部分(与第四篇相同)──────────────────────
self.norm1 = nn.GroupNorm(num_groups, in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.time_proj = nn.Sequential(
nn.SiLU(),
nn.Linear(time_dim, out_channels * 2),
)
self.norm2 = nn.GroupNorm(num_groups, out_channels)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.act = nn.SiLU()
if in_channels != out_channels:
self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
else:
self.shortcut = nn.Identity()
# ── 自注意力(图像内部)────────────────────────────
self.self_attn_norm = nn.LayerNorm(out_channels)
self.self_attn = nn.MultiheadAttention(
out_channels, num_heads, batch_first=True
)
# ── 交叉注意力(文本条件)──────────────────────────
self.cross_attn = CrossAttention(
query_dim = out_channels,
context_dim = context_dim,
num_heads = num_heads,
)
# ── 前馈网络(Post-Attention)──────────────────────
self.ff_norm = nn.LayerNorm(out_channels)
self.ff = nn.Sequential(
nn.Linear(out_channels, out_channels * 4),
nn.GELU(),
nn.Linear(out_channels * 4, out_channels),
)
def forward(
self,
x: torch.Tensor, # (B, C, H, W)
time_emb: torch.Tensor, # (B, time_dim)
context: torch.Tensor, # (B, N_txt, context_dim) 文本特征
) -> torch.Tensor:
B, C, H, W = x.shape
# ── ResBlock 前向(时间步注入)────────────────────
h = self.act(self.norm1(x))
h = self.conv1(h)
t_proj = self.time_proj(time_emb)
scale, shift = t_proj.chunk(2, dim=1)
h = self.norm2(h) * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
h = self.act(h)
h = self.dropout(h)
h = self.conv2(h)
h = h + self.shortcut(x)
# ── 展平为序列,进行注意力操作 ────────────────────
# (B, C, H, W) → (B, H*W, C)
h_flat = h.reshape(B, C, H * W).transpose(1, 2)
# ── 自注意力(图像区域间的关联)──────────────────
h_norm = self.self_attn_norm(h_flat)
sa_out, _ = self.self_attn(h_norm, h_norm, h_norm)
h_flat = h_flat + sa_out
# ── 交叉注意力(文本条件注入)────────────────────
h_flat = self.cross_attn(h_flat, context)
# ── 前馈网络(增强表达)──────────────────────────
h_flat = h_flat + self.ff(self.ff_norm(h_flat))
# ── reshape 回 (B, C, H, W) ──────────────────────
return h_flat.transpose(1, 2).reshape(B, C, H, W)
# 验证条件化 ResBlock
cond_block = ConditionedResBlock(
in_channels = 256,
out_channels = 256,
time_dim = 512,
context_dim = 768,
num_heads = 8,
)
B, C, H, W = 2, 256, 16, 16
x = torch.randn(B, C, H, W)
time_emb = torch.randn(B, 512)
context = torch.randn(B, 77, 768) # 文本特征(77词,768维)
out = cond_block(x, time_emb, context)
print("条件化 ResBlock 验证:")
print(f" 图像特征输入:{x.shape}")
print(f" 时间步嵌入: {time_emb.shape}")
print(f" 文本特征: {context.shape}")
print(f" 输出形状: {out.shape}")
print()
# 验证文本条件确实影响了输出
context_diff = torch.randn(B, 77, 768) # 不同的文本
out_diff = cond_block(x, time_emb, context_diff)
diff = (out - out_diff).abs().mean().item()
print(f" 不同文本 → 不同输出(差异:{diff:.4f})✓")
print(f" 参数量:{sum(p.numel() for p in cond_block.parameters()):,}")
五、Classifier Guidance(分类器引导)
# 条件生成的第一种方法:训练一个分类器,在采样时引导梯度
#
# Dhariwal & Nichol (2021) 提出:
# 修改采样公式:
# ∇_xₜ log p_θ(xₜ) → ∇_xₜ [log p_θ(xₜ) + γ · log p_φ(y|xₜ)]
# p_θ:扩散模型
# p_φ:噪声鲁棒的分类器(在加噪的图像上训练)
# y:目标类别
# γ:引导强度
def classifier_guidance_pseudocode():
print("Classifier Guidance 采样伪代码:")
print()
print(" # 需要:1个扩散模型 + 1个噪声鲁棒分类器")
print(" # 分类器必须在加噪图像上训练!")
print()
print(" for t = T, T-1, ..., 1:")
print(" # 扩散模型预测")
print(" eps_unconditional = diffusion_model(xt, t)")
print()
print(" # 分类器梯度")
print(" xₜ.requires_grad = True")
print(" log_p = classifier(xₜ, t).log_softmax()[target_class]")
print(" grad = ∇_xₜ log_p # 梯度!")
print()
print(" # 修改分数函数")
print(" eps_guided = eps_unconditional - √(1-ᾱₜ) · γ · grad")
print()
print(" # 用 eps_guided 代替 eps_unconditional 采样")
print(" xt_prev = ddim_step(xt, eps_guided, t)")
print()
print(" 缺点:")
print(" ① 需要额外训练噪声鲁棒的分类器")
print(" ② 每步都要计算分类器梯度(慢)")
print(" ③ 只能做类别条件,无法做文本条件")
print(" → 被 Classifier-Free Guidance 取代")
classifier_guidance_pseudocode()
六、Classifier-Free Guidance:核心技术
6.1 CFG 的关键思想
# Classifier-Free Guidance (Ho & Salimans, 2021)
# 不需要额外的分类器!
# 通过同时训练有条件和无条件两种模式,在推理时组合
# 训练阶段:
# 以概率 p_uncond(通常 10-20%)把条件 c 替换为空条件 ∅
# 网络同时学会:
# ε_θ(xₜ, t, c) 有条件去噪(给定文本)
# ε_θ(xₜ, t, ∅) 无条件去噪(不给文本)
#
# 推理阶段(CFG 公式):
# ε_guided = ε_θ(xₜ, t, ∅) + γ · [ε_θ(xₜ, t, c) - ε_θ(xₜ, t, ∅)]
#
# 其中 γ 是 guidance_scale(引导强度)
# γ=1:等同于有条件去噪(无引导)
# γ=7.5:Stable Diffusion 默认值,较强的文字对齐
# γ>15:文字对齐极强,但可能过饱和/失真
print("Classifier-Free Guidance 的直觉理解:")
print()
print(" ε_guided = ε_uncond + γ · (ε_cond - ε_uncond)")
print()
print(" 拆解:")
print(" ε_uncond = 无条件方向(随机漫步)")
print(" ε_cond - ε_uncond = 条件带来的额外方向(文字的影响)")
print(" γ · (...) = 放大文字影响(γ越大,文字越重要)")
print()
print(" 类比:")
print(" ε_uncond = 没有地图乱走(随机)")
print(" ε_cond - ε_uncond = 地图指的方向")
print(" γ · (...) = 按地图方向走 γ 倍(越守规矩越准)")
print()
# 数学角度:CFG 等价于隐式分类器引导
print(" 数学等价性(Ho & Salimans 的推导):")
print(" ε_guided = ε_uncond + γ · (ε_cond - ε_uncond)")
print(" ≡ (1-γ) · ε_uncond + γ · ε_cond (改写)")
print(" ≡ ε_θ(xₜ, t, c) + (γ-1) · [ε_cond - ε_uncond]")
print()
print(" = 有条件去噪 + 放大了 (γ-1) 倍的「条件与无条件的差」")
print(" → 本质:把有条件和无条件的差距放大 γ 倍")
6.2 不同 guidance_scale 的效果
def analyze_cfg_scale():
"""分析不同 guidance_scale 对生成质量的影响"""
print("不同 guidance_scale 的效果分析:")
print()
print(f" {'γ':^6} {'文字对齐':^12} {'图像多样性':^14} {'图像质量':^12} {'推荐场景':^20}")
print(" " + "─" * 68)
configs = [
(1.0, "★☆☆☆☆", "★★★★★", "★★★☆☆", "无引导(随机生成)"),
(2.0, "★★☆☆☆", "★★★★☆", "★★★★☆", "轻度引导"),
(5.0, "★★★★☆", "★★★☆☆", "★★★★★", "平衡点(艺术创作)"),
(7.5, "★★★★★", "★★★☆☆", "★★★★★", "SD 默认(推荐)⭐"),
(10.0, "★★★★★", "★★☆☆☆", "★★★★☆", "强调细节匹配"),
(15.0, "★★★★★", "★★☆☆☆", "★★★☆☆", "极端对齐,可能过饱和"),
(20.0, "★★★★★", "★☆☆☆☆", "★★☆☆☆", "通常过度,质量下降"),
]
for gamma, text_align, diversity, quality, scene in configs:
marker = " ←" if gamma == 7.5 else ""
print(f" {gamma:^6.1f} {text_align:^14} {diversity:^14} {quality:^14} {scene}{marker}")
print()
print(" 关键观察:")
print(" γ 越大 → 文字对齐越好,但多样性降低,可能过饱和")
print(" γ=1 → 等价于无 CFG,图像多样但可能不符合提示词")
print(" 最优 γ → 取决于任务:写实摄影偏低,艺术创作偏高")
analyze_cfg_scale()
七、CFG 的数学原理与实现
7.1 CFG 与分类器引导的等价性推导
def derive_cfg_equivalence():
"""推导 CFG 与隐式分类器引导的等价关系"""
print("CFG 与分类器引导的数学等价性推导:")
print()
print(" 贝叶斯公式:")
print(" p(xₜ|c) = p(c|xₜ) · p(xₜ) / p(c)")
print()
print(" 取对数梯度(得分函数):")
print(" ∇_xₜ log p(xₜ|c)")
print(" = ∇_xₜ log p(c|xₜ) + ∇_xₜ log p(xₜ)")
print()
print(" 替换扩散模型的得分函数:")
print(" -1/√(1-ᾱₜ) · ε_θ(xₜ,t,c) ≈ ∇_xₜ log p(xₜ|c)")
print()
print(" Classifier Guidance 的引导方向:")
print(" ∇_xₜ [log p(xₜ) + γ·log p(c|xₜ)]")
print(" = ∇_xₜ log p(xₜ) + γ·∇_xₜ log p(c|xₜ)")
print(" ≈ -1/√(1-ᾱₜ) · ε_uncond + γ·(-1/√(1-ᾱₜ))·(ε_cond-ε_uncond)")
print(" ≈ -1/√(1-ᾱₜ) · [ε_uncond + γ·(ε_cond-ε_uncond)]")
print()
print(" ✓ 这正是 CFG 公式!")
print(" ε_guided = ε_uncond + γ·(ε_cond - ε_uncond)")
print()
print(" 关键结论:")
print(" CFG 无需额外分类器,用无条件模型隐式充当分类器")
print(" γ-1 正好等于隐式分类器的引导强度")
derive_cfg_equivalence()
7.2 负面提示词(Negative Prompt)
# CFG 的扩展:负面提示词(Negative Prompt)
#
# 标准 CFG:
# ε_guided = ε_θ(xₜ, t, ∅) + γ · [ε_θ(xₜ, t, c_pos) - ε_θ(xₜ, t, ∅)]
#
# 带负面提示词的 CFG:
# 用负面提示词 c_neg 代替空条件 ∅
# ε_guided = ε_θ(xₜ, t, c_neg) + γ · [ε_θ(xₜ, t, c_pos) - ε_θ(xₜ, t, c_neg)]
#
# 直觉:
# "远离 c_neg 的方向" + γ × "从 c_neg 走向 c_pos"
# → 既避开负面内容,又朝向正面内容
def negative_prompt_demo():
print("负面提示词(Negative Prompt)的原理:")
print()
print(" 正面提示:'a beautiful landscape, high quality, detailed'")
print(" 负面提示:'blurry, ugly, bad anatomy, low quality, nsfw'")
print()
print(" CFG 公式(带负面提示词):")
print(" ε_guided = ε_θ(xₜ, c_neg) + γ × [ε_θ(xₜ, c_pos) - ε_θ(xₜ, c_neg)]")
print()
print(" 几何直觉:")
print(" ε(c_neg) → 远离'模糊、低质量'的方向")
print(" ε(c_pos) → 朝向'美丽风景、高质量'的方向")
print(" 组合:远离负面 + γ倍的(正面-负面)方向")
print()
print(" 常用负面提示词:")
negative_examples = [
"图像质量", "blurry, low quality, jpeg artifacts, noisy",
"人物质量", "bad anatomy, extra fingers, mutated hands, ugly face",
"内容过滤", "nsfw, explicit, violence, gore",
"风格排除", "cartoon, anime, 3d render(如果要求摄影风格)",
]
for category, prompt in zip(negative_examples[::2], negative_examples[1::2]):
print(f" [{category}]: {prompt}")
negative_prompt_demo()
八、完整条件采样代码
class ConditionalDDIMSampler:
"""
支持文本条件和 CFG 的 DDIM 采样器
核心特性:
1. 接受文本编码作为条件
2. 支持 CFG(guidance_scale > 1 时自动启用)
3. 支持负面提示词
4. 每步只需运行 1 次网络(批次维度展开技巧)
"""
def __init__(
self,
denoiser, # 条件化 U-Net(接受 context 参数)
alphas_cumprod: torch.Tensor,
num_train_steps: int = 1000,
device: str = "cpu",
):
self.denoiser = denoiser
self.alphas_cumprod = alphas_cumprod
self.num_train_steps = num_train_steps
self.device = device
def get_timesteps(self, num_inference_steps: int) -> list:
step_ratio = self.num_train_steps // num_inference_steps
return list(reversed(range(0, self.num_train_steps, step_ratio)))
def ddim_step(
self,
xt: torch.Tensor,
t_cur: int,
t_prev: int,
eps: torch.Tensor,
eta: float = 0.0,
) -> torch.Tensor:
ab_t = self.alphas_cumprod[t_cur]
ab_tp = self.alphas_cumprod[t_prev] if t_prev >= 0 else torch.ones(1)
x0_pred = (xt - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()
x0_pred = x0_pred.clamp(-1, 1)
dir_xt = (1 - ab_tp - eta**2 * (1-ab_t)/ab_t * (ab_t-ab_tp)).clamp(0).sqrt() * eps
noise = eta * ((1-ab_tp)/(1-ab_t)).sqrt() * \
(1 - ab_t/ab_tp).clamp(0).sqrt() * torch.randn_like(xt)
return ab_tp.sqrt() * x0_pred + dir_xt + noise
@torch.no_grad()
def sample(
self,
shape: tuple,
text_embeddings: torch.Tensor, # (B, 77, context_dim) 正面提示词
neg_embeddings: torch.Tensor = None, # 负面提示词(None 则用空)
guidance_scale: float = 7.5,
num_steps: int = 50,
eta: float = 0.0,
init_noise: torch.Tensor = None,
verbose: bool = False,
) -> torch.Tensor:
"""
条件 DDIM 采样
工程技巧:
当 guidance_scale > 1 时,每步需要计算:
① ε_θ(xₜ, t, c_pos)(正面条件)
② ε_θ(xₜ, t, c_neg)(负面条件或空条件)
通常做法:把两者拼成一个 batch,一次前向传播同时得到两个结果
→ 比两次单独调用快约 1.5 倍(CUDA 更好地利用并行)
"""
self.denoiser.eval()
B = shape[0]
# 初始噪声
if init_noise is None:
xt = torch.randn(*shape, device=self.device)
else:
xt = init_noise.clone().to(self.device)
# 准备条件嵌入
# 负面提示词:如果未提供则使用全零(空条件)
if neg_embeddings is None:
neg_embeddings = torch.zeros_like(text_embeddings)
# ── CFG 批次展开技巧 ─────────────────────────────────
# 将正面和负面提示词拼成一个 batch,一次前向得到两个结果
# (2B, 77, context_dim)
batch_context = torch.cat([neg_embeddings, text_embeddings], dim=0)
timesteps = self.get_timesteps(num_steps)
for step_idx, t in enumerate(timesteps):
t_prev = timesteps[step_idx + 1] if step_idx + 1 < len(timesteps) else -1
t_tensor = torch.full((B,), t, dtype=torch.long, device=self.device)
if guidance_scale == 1.0:
# 无 CFG:只做有条件推理
eps_pred = self.denoiser(xt, t_tensor, context=text_embeddings)
else:
# CFG:一次前向同时计算有条件和无条件
# 把 xt 复制一份:(2B, C, H, W)
xt_doubled = torch.cat([xt, xt], dim=0)
t_doubled = torch.cat([t_tensor, t_tensor], dim=0)
# 一次前向传播
eps_both = self.denoiser(xt_doubled, t_doubled, context=batch_context)
# 分离负面和正面的预测
eps_neg, eps_pos = eps_both.chunk(2, dim=0)
# CFG 公式:引导方向 = 无条件 + γ·(有条件-无条件)
eps_pred = eps_neg + guidance_scale * (eps_pos - eps_neg)
xt = self.ddim_step(xt, t, t_prev, eps_pred, eta)
if verbose and step_idx % (num_steps // 5) == 0:
print(f" 步骤 {step_idx+1}/{num_steps}: "
f"t={t}, xt.std={xt.std():.4f}")
# 归一化到 [0, 1]
return (xt.clamp(-1, 1) + 1) / 2
# ── 用简单模型验证完整流程 ──────────────────────────────────────
class SimpleCFGDenoiser(nn.Module):
"""
极简的条件去噪网络(仅用于验证 CFG 框架)
实际应用中替换为第四篇的 U-Net
"""
def __init__(self, channels=3, hidden=32, context_dim=768):
super().__init__()
# 时间步嵌入
self.t_emb = nn.Sequential(
nn.Linear(1, hidden), nn.SiLU(), nn.Linear(hidden, hidden)
)
# 文本特征池化(简化)
self.ctx_pool = nn.Linear(context_dim, hidden)
# 主干网络
self.net = nn.Sequential(
nn.Conv2d(channels + hidden, hidden, 3, padding=1), nn.SiLU(),
nn.Conv2d(hidden, channels, 3, padding=1),
)
def forward(self, x, t, context=None):
B, C, H, W = x.shape
# 时间步嵌入
t_feat = self.t_emb(t.float().unsqueeze(1) / 1000) # (B, hidden)
if context is not None:
# 文本特征:池化 77 个词 → 一个向量
ctx_feat = self.ctx_pool(context.mean(dim=1)) # (B, hidden)
t_feat = t_feat + ctx_feat # 简单融合
# 拼接时间特征到图像
t_map = t_feat[:, :, None, None].expand(B, -1, H, W)
return self.net(torch.cat([x, t_map], dim=1))
# 创建并测试
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1 - betas, 0)
denoiser_cfg = SimpleCFGDenoiser(channels=3, hidden=16, context_dim=768)
sampler = ConditionalDDIMSampler(
denoiser_cfg, alphas_cumprod, num_train_steps=T
)
# 模拟文本嵌入(真实应用中来自 CLIP)
text_emb = torch.randn(2, 77, 768) # (B=2, 77词, 768维)
neg_emb = torch.zeros(2, 77, 768) # 空负面提示词
print("条件 DDIM 采样验证:")
print()
print("测试 1:无 CFG(guidance_scale=1.0)")
img1 = sampler.sample(
shape=(2, 3, 16, 16),
text_embeddings=text_emb,
guidance_scale=1.0,
num_steps=10,
verbose=True,
)
print(f" 生成图像形状:{img1.shape},值域:[{img1.min():.3f}, {img1.max():.3f}]")
print()
print("测试 2:启用 CFG(guidance_scale=7.5)")
img2 = sampler.sample(
shape=(2, 3, 16, 16),
text_embeddings=text_emb,
neg_embeddings=neg_emb,
guidance_scale=7.5,
num_steps=10,
verbose=True,
)
print(f" 生成图像形状:{img2.shape},值域:[{img2.min():.3f}, {img2.max():.3f}]")
print()
print("测试 3:不同 guidance_scale 的输出差异")
for gs in [1.0, 3.0, 7.5, 15.0]:
torch.manual_seed(42)
img = sampler.sample(
shape=(1, 3, 16, 16),
text_embeddings=text_emb[:1],
guidance_scale=gs,
num_steps=5,
)
print(f" guidance_scale={gs:5.1f}: mean={img.mean():.4f}, std={img.std():.4f}")
总结
| 概念 | 说明 | 关键参数 |
|---|---|---|
| 文本编码 | CLIP/T5 把文字变成 (B,77,768) 的特征序列 | 编码器维度、最大长度 |
| 交叉注意力 | Q=图像特征,K/V=文本特征,图像"查询"文字 | query_dim, context_dim |
| Classifier Guidance | 用分类器梯度引导,需要额外分类器 | guidance_scale |
| CFG 公式 | ϵ g = ϵ u + γ ( ϵ c − ϵ u ) \epsilon_g = \epsilon_u + \gamma(\epsilon_c - \epsilon_u) ϵg=ϵu+γ(ϵc−ϵu) | guidance_scale γ |
| 负面提示词 | 用 c n e g c_{neg} cneg 代替空条件,双向控制生成方向 | negative_prompt |
| CFG 批次技巧 | 正负提示词合并成 2B,一次前向得两个结果 | 速度提升 ~1.5× |
guidance_scale 速查表:
- γ=1.0:无 CFG,完全随机
- γ=5.0:平衡多样性与对齐(艺术风格推荐)
- γ=7.5:SD 默认,写实摄影推荐 ⭐
- γ≥15:极强对齐,可能过饱和
下一篇预告:Stable Diffusion——潜在扩散模型(LDM),为什么要在"潜空间"做扩散而不是像素空间?VAE 的角色是什么?以及 ControlNet 如何通过附加条件(深度图、姿态图)实现精确的结构控制。
💬 你用 Stable Diffusion 时 guidance_scale 一般设多少?发现过什么规律? 欢迎评论区分享!
🙏 如果这篇帮到你,点赞 + 收藏,系列持续更新!
本文为原创技术分享。代码在 Python 3.12 + PyTorch 2.x + transformers 下验证。最后更新:2026-05-20
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)