【扩散模型系列·第八篇·收官】实战与前沿:LoRA 微调、DreamBooth、FLUX 与一致性模型
【扩散模型系列·第八篇·收官】实战与前沿:LoRA 微调、DreamBooth、FLUX 与一致性模型
作者:技术博主 | 更新时间:2026-05-20 | 阅读时长:约 25 分钟
系列:扩散模型从零到实战(共 8 篇)· 收官篇
环境:Python 3.12 + PyTorch 2.x + diffusers + peft
标签:扩散模型LoRADreamBoothFLUX一致性模型微调个性化生成前沿进展

🔥 本篇目标:系列的最后一篇,聚焦"实用"与"前沿"。实用部分:用 LoRA 微调让 SD 在 1 小时内学会你的风格或人物(只需 10-20 张图)。前沿部分:FLUX(2024 年最强开源文生图模型)的架构革新,以及一致性模型(Consistency Models)如何把采样步数压缩到 1-2 步。读完本篇,你对整个扩散模型的技术栈——从数学基础到最新前沿——将有完整的认识。
系列完整进度
| 篇次 | 主题 | 状态 |
|---|---|---|
| 第一篇 | 扩散模型是什么:从加噪到去噪的直觉 | ✅ |
| 第二篇 | 数学基础:前向过程与马尔可夫链 | ✅ |
| 第三篇 | 反向过程:学习去噪 | ✅ |
| 第四篇 | U-Net 架构:去噪网络的设计 | ✅ |
| 第五篇 | DDIM:加速采样 | ✅ |
| 第六篇 | 条件生成:文本引导与 CFG | ✅ |
| 第七篇 | Stable Diffusion:潜在扩散模型 | ✅ |
| 第八篇(本篇·收官) | 实战与前沿 | — |
目录
- 一、微调扩散模型的动机
- 二、LoRA:低秩适配,高效微调
- 三、DreamBooth:个性化生成
- 四、Textual Inversion:嵌入层微调
- 五、FLUX:2024 年的架构革新
- 六、一致性模型:1 步生成
- 七、系列收官:完整知识图谱
一、微调扩散模型的动机
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional
print("为什么需要微调扩散模型?")
print()
print(" 预训练 SD 的局限:")
print(" ① 不认识特定的人物(你自己、虚拟IP)")
print(" ② 不能精确生成特定风格(某个画家/插画师的独特笔触)")
print(" ③ 不熟悉特定物体(某款汽车、特定商品)")
print(" ④ 专业领域数据不足(医疗影像、工业设计)")
print()
print(" 三种微调方案的对比:")
print()
methods = {
"全量微调(Full Fine-tuning)": {
"可训练参数": "~860M(全部)",
"训练时间": "数天(A100)",
"显存需求": "~80GB",
"数据需求": "1000+ 张",
"灵活性": "★★★★★",
"适用": "从头定制,资源充足",
},
"DreamBooth": {
"可训练参数": "~860M(全部 U-Net)",
"训练时间": "15-30 分钟",
"显存需求": "~24GB",
"数据需求": "3-30 张",
"灵活性": "★★★★☆",
"适用": "特定人物/物体",
},
"Textual Inversion": {
"可训练参数": "1 个文本 token(~768维)",
"训练时间": "1-2 小时",
"显存需求": "~8GB",
"数据需求": "5-20 张",
"灵活性": "★★★☆☆",
"适用": "新概念嵌入",
},
"LoRA": {
"可训练参数": "~3M(<0.5%)",
"训练时间": "30-60 分钟",
"显存需求": "~12GB",
"数据需求": "10-50 张",
"灵活性": "★★★★★",
"适用": "风格/人物,最流行 ⭐",
},
}
for name, info in methods.items():
print(f" [{name}]")
for k, v in info.items():
print(f" {k:12s}: {v}")
print()
二、LoRA:低秩适配,高效微调
2.1 LoRA 的数学原理
# LoRA(Low-Rank Adaptation,Hu et al. 2021)
#
# 核心思想:
# 大模型的权重矩阵 W ∈ R^{m×n} 在微调时的变化 ΔW
# 通常具有低秩结构(intrinsic rank 很低)
#
# 因此:
# 不微调 W,而是学习 ΔW = B·A
# B ∈ R^{m×r},A ∈ R^{r×n},其中 r << min(m, n)
#
# 前向传播:
# h = (W + α/r · B·A) · x = W·x + α/r · B·A·x
#
# 初始化:A 用正态分布初始化,B 初始化为 0
# → 训练开始时 ΔW = 0,不影响原始模型
class LoRALinear(nn.Module):
"""
带 LoRA 的线性层
替换原始的 nn.Linear,添加低秩旁路
"""
def __init__(
self,
in_features: int,
out_features: int,
rank: int = 4, # LoRA 秩,通常 4-64
alpha: float = 1.0, # 缩放系数
dropout: float = 0.0,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.rank = rank
self.alpha = alpha
self.scaling = alpha / rank # 实际缩放:α/r
# 原始权重(冻结,不训练)
self.weight = nn.Parameter(
torch.randn(out_features, in_features), requires_grad=False
)
self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False)
# LoRA 低秩矩阵(可训练!)
self.lora_A = nn.Parameter(
torch.randn(rank, in_features) * 0.01 # 小值初始化
)
self.lora_B = nn.Parameter(
torch.zeros(out_features, rank) # 零初始化!保证 ΔW=0
)
self.lora_dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 原始线性变换
result = F.linear(x, self.weight, self.bias)
# LoRA 旁路:x → A → B → scale
lora_out = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
lora_out = lora_out * self.scaling
return result + lora_out
@property
def effective_weight(self) -> torch.Tensor:
"""合并后的有效权重(推理时可以合并,加速推理)"""
return self.weight + self.scaling * (self.lora_B @ self.lora_A)
def merge_weights(self):
"""把 LoRA 权重合并回主权重(推理加速)"""
self.weight.data += self.scaling * (self.lora_B @ self.lora_A)
# 清空 LoRA 矩阵(已合并,不再需要)
self.lora_A.data.zero_()
self.lora_B.data.zero_()
# 验证 LoRA
lora_layer = LoRALinear(768, 768, rank=8, alpha=8.0)
# 统计参数量
total_params = sum(p.numel() for p in lora_layer.parameters())
trainable = sum(p.numel() for p in lora_layer.parameters() if p.requires_grad)
frozen = total_params - trainable
print("LoRA 层参数统计:")
print(f" 原始权重 W (768×768):{768*768:,} 个(冻结)")
print(f" LoRA A (r×768={8}×768):{8*768:,} 个(可训练)")
print(f" LoRA B (768×r=768×{8}):{768*8:,} 个(可训练)")
print(f" 可训练参数:{trainable:,}({100*trainable/total_params:.2f}%)")
print(f" 参数效率:{frozen:,} 个权重不需要梯度")
print()
# 验证:初始化时 LoRA 输出为 0(不影响原始模型)
x = torch.randn(4, 768)
# 原始权重的输出
with torch.no_grad():
original_out = F.linear(x, lora_layer.weight, lora_layer.bias)
lora_out = lora_layer(x)
delta = (lora_out - original_out).abs().mean().item()
print(f" 初始化时 LoRA 的额外输出:{delta:.8f}(应接近 0)✓")
print(f" (lora_B=0 → ΔW=B·A=0 → 不影响原始模型)")
2.2 在 SD 中应用 LoRA
def apply_lora_to_unet(unet: nn.Module, rank: int = 4, alpha: float = 4.0) -> dict:
"""
给 U-Net 的关键线性层应用 LoRA
标准做法:只对 Cross-Attention 层的 Q/K/V/Out 矩阵施加 LoRA
为什么选 Cross-Attention?
① Cross-Attention 连接文本和图像,最与"内容/风格"相关
② Self-Attention 改动影响全局结构,容易破坏原模型能力
③ Feed-Forward 矩阵也可加(有时加,有时不加)
"""
lora_layers = {}
lora_params = 0
total_params = 0
for name, module in unet.named_modules():
total_params += sum(p.numel() for p in module.parameters(recurse=False))
# 针对 Cross-Attention 的 Q、K、V、输出投影矩阵
if isinstance(module, nn.Linear):
# 判断是否在 Cross-Attention 块中(按名称过滤)
is_attn_proj = any(
key in name for key in ["to_q", "to_k", "to_v", "to_out"]
)
if is_attn_proj:
in_f = module.in_features
out_f = module.out_features
# 用 LoRALinear 替换
lora = LoRALinear(in_f, out_f, rank=rank, alpha=alpha)
lora.weight.data = module.weight.data.clone()
if module.bias is not None:
lora.bias.data = module.bias.data.clone()
lora_layers[name] = lora
lora_params += 2 * rank * in_f # A + B 的参数
print(f"LoRA 应用统计:")
print(f" 替换的层数:{len(lora_layers)}")
print(f" LoRA 可训练参数:{lora_params:,}")
print(f" 参数效率:{100 * lora_params / max(total_params, 1):.4f}%")
return lora_layers
# 统计不同 rank 的参数量
print("\n不同 rank 的 LoRA 参数量(以 SD 1.5 U-Net 为例,~180个 Cross-Attn 投影层):")
print(f"{'rank':^8} {'LoRA参数量':^16} {'比例':^12} {'文件大小':^12}")
print("─" * 50)
n_attn_layers = 180 # SD 1.5 的近似数量
avg_dim = 768 # 平均维度
for r in [1, 2, 4, 8, 16, 32, 64]:
lora_params = n_attn_layers * 2 * r * avg_dim
total = 860e6 # SD 1.5 总参数
ratio = lora_params / total * 100
filesize_mb = lora_params * 4 / 1e6 # FP32
print(f" {r:^8} {lora_params:>12,} {ratio:^12.4f}% {filesize_mb:>8.1f} MB")
print()
print(" → rank=4:只有 ~4M 参数,文件仅 ~16MB!")
print(" → 整个 SD 模型是 ~3.5GB,LoRA 仅 0.5%")
2.3 LoRA 训练循环
class LoRATrainer:
"""
LoRA 微调训练器
适用于 Stable Diffusion 的 U-Net 组件
"""
def __init__(
self,
unet: nn.Module,
vae,
text_encoder,
alphas_cumprod: torch.Tensor,
lora_rank: int = 4,
lora_alpha: float = 4.0,
learning_rate: float = 1e-4,
device: str = "cpu",
):
self.device = device
self.alphas_cumprod = alphas_cumprod
# 冻结所有基础模型
self.vae = vae.eval().to(device)
self.text_encoder = text_encoder.eval().to(device)
self.unet = unet.to(device)
for p in self.vae.parameters():
p.requires_grad = False
for p in self.text_encoder.parameters():
p.requires_grad = False
for p in self.unet.parameters():
p.requires_grad = False # 先全冻结
# 注入 LoRA 层(只解冻 LoRA 参数)
self.lora_layers = self._inject_lora(lora_rank, lora_alpha)
# 只优化 LoRA 参数
lora_params = [p for p in self.unet.parameters() if p.requires_grad]
self.optimizer = torch.optim.AdamW(
lora_params,
lr=learning_rate,
weight_decay=1e-2,
)
n_trainable = sum(p.numel() for p in lora_params)
n_total = sum(p.numel() for p in self.unet.parameters())
print(f"LoRA 初始化完成:")
print(f" 可训练参数:{n_trainable:,} / {n_total:,} "
f"({100*n_trainable/n_total:.3f}%)")
def _inject_lora(self, rank: int, alpha: float) -> list:
"""向 U-Net 的 Cross-Attention 层注入 LoRA"""
lora_layers = []
for name, module in self.unet.named_modules():
if isinstance(module, nn.Linear) and any(
k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]
):
# 把原 Linear 替换为 LoRALinear
parent = self._get_parent(name)
attr = name.split(".")[-1]
in_f = module.in_features
out_f = module.out_features
lora_layer = LoRALinear(in_f, out_f, rank=rank, alpha=alpha)
lora_layer.weight.data = module.weight.data.clone()
if module.bias is not None:
lora_layer.bias.data = module.bias.data.clone()
setattr(parent, attr, lora_layer)
lora_layers.append(lora_layer)
return lora_layers
def _get_parent(self, name: str):
"""获取模块的父模块"""
parts = name.split(".")
parent = self.unet
for part in parts[:-1]:
parent = getattr(parent, part)
return parent
def train_step(
self,
pixel_values: torch.Tensor, # 训练图像
text_features: torch.Tensor, # 文本嵌入
) -> float:
"""单步训练"""
self.unet.train()
self.optimizer.zero_grad()
B = pixel_values.shape[0]
# 1. 编码到潜空间(不计算梯度)
with torch.no_grad():
latents = self.vae.encode(pixel_values)
# 缩放(假设 scaling_factor=0.18215)
latents = latents * 0.18215
# 2. 随机时间步和噪声
t = torch.randint(0, 1000, (B,), device=self.device)
noise = torch.randn_like(latents)
# 3. 前向加噪
ab = self.alphas_cumprod[t].reshape(-1, 1, 1, 1)
xt = ab.sqrt() * latents + (1 - ab).sqrt() * noise
# 4. 预测噪声(LoRA 参数参与梯度计算)
pred_noise = self.unet(xt, t, context=text_features)
# 5. 损失(与标准 DDPM 相同)
loss = F.mse_loss(pred_noise, noise)
loss.backward()
torch.nn.utils.clip_grad_norm_(
[p for p in self.unet.parameters() if p.requires_grad],
max_norm=1.0
)
self.optimizer.step()
return loss.item()
def save_lora(self, path: str):
"""保存 LoRA 权重(只保存可训练参数,文件很小!)"""
lora_state = {
name: param
for name, param in self.unet.named_parameters()
if param.requires_grad
}
torch.save(lora_state, path)
filesize = sum(p.numel() * 4 for p in lora_state.values()) / 1e6
print(f"LoRA 保存完成:{path}({filesize:.1f} MB)")
# 演示 LoRA 训练配置
print("\nLoRA 训练推荐配置:")
configs = {
"人物风格(写实)": {
"rank": 8, "alpha": 8.0, "lr": "1e-4",
"steps": "1000-2000", "data": "15-30 张(多角度)",
"instance_prompt": '"a photo of sks person"',
"class_prompt": '"a photo of a person"',
},
"画风/艺术风格": {
"rank": 16, "alpha": 16.0, "lr": "1e-4",
"steps": "2000-3000", "data": "20-50 张(同风格)",
"instance_prompt": '"painting in sks style"',
"class_prompt": '"a painting"',
},
"特定物体": {
"rank": 4, "alpha": 4.0, "lr": "5e-5",
"steps": "500-1000", "data": "5-20 张(多视角)",
"instance_prompt": '"a photo of sks object"',
"class_prompt": '"a photo of an object"',
},
}
for task, cfg in configs.items():
print(f" [{task}]")
for k, v in cfg.items():
print(f" {k:20s}: {v}")
print()
三、DreamBooth:个性化生成
# DreamBooth(Ruiz et al., 2022,Google Research)
# 目标:让 SD 认识特定人物/物体,能生成"他在各种场景下"的图像
#
# 关键技术:
# ① 使用稀有词 "sks"(或其他稀有词)作为绑定词
# → 最小化与已有概念的干扰
# ② Prior Preservation Loss:防止"遗忘"原有概念
# → 同时用"a photo of a person"生成伪图,参与训练
# → 避免训练后 SD 只会画特定的人,忘了"人"的概念
def explain_dreambooth():
print("DreamBooth 训练流程:")
print()
print(" 数据准备:")
print(" ├── instance images(你的图): 15 张不同角度的照片")
print(" │ instance prompt: 'a photo of [sks] person'")
print(" └── class images(先验图): 200 张 SD 自己生成的")
print(" class prompt: 'a photo of a person'")
print()
print(" 训练目标:")
print(" L = L_instance + λ·L_prior")
print()
print(" L_instance:让模型能根据 'sks person' 生成你的脸")
print(" L_prior: 让模型保持对'person'概念的一般认知")
print(" 防止过拟合(catastrophic forgetting)")
print()
print(" λ(prior_loss_weight)通常设为 1.0")
print()
print(" DreamBooth vs LoRA 的选择:")
print()
differences = [
("可训练参数", "全部 U-Net(~860M)", "只有低秩矩阵(~3M)"),
("训练时间", "~15-30 分钟", "~30-60 分钟"),
("显存", "~24GB(A100)", "~12GB(消费级可)"),
("生成质量", "★★★★★(通常更好)", "★★★★☆"),
("文件大小", "~3.5GB(完整模型)", "~16-150MB"),
("可组合性", "低(独立模型)", "高(多 LoRA 叠加)"),
("开源社区", "较少(太重)", "CivitAI 数万个"),
]
print(f" {'指标':^16} {'DreamBooth':^22} {'LoRA':^22}")
print(" " + "─" * 62)
for metric, db, lora in differences:
print(f" {metric:^16} {db:^22} {lora:^22}")
print()
print(" 实践建议:")
print(" 资源有限(消费级 GPU)→ LoRA")
print(" 追求最高质量(A100)→ DreamBooth + LoRA(DreamBooth-LoRA)")
print(" 社区分享 → LoRA(CivitAI 的主流格式)")
explain_dreambooth()
四、Textual Inversion:嵌入层微调
# Textual Inversion(Gal et al., 2022,NVIDIA)
# 最轻量的个性化方法:
# 只学习一个新的文本 token 嵌入向量(~768 维)
# 其他参数完全冻结
class TextualInversionEmbedding(nn.Module):
"""
Textual Inversion:学习新的文本 token 嵌入
原理:
在 CLIP 的词汇表中添加一个新词 '*'(或 'sks')
只训练这个词的嵌入向量,使 CLIP 能用它描述新概念
"""
def __init__(
self,
text_encoder: nn.Module, # CLIP 文本编码器
placeholder: str = "*", # 占位符 token
init_token: str = "artwork", # 初始化参考词
embedding_dim: int = 768,
):
super().__init__()
self.text_encoder = text_encoder
self.placeholder = placeholder
# 新 token 的嵌入(这是唯一可训练的参数!)
self.token_embedding = nn.Parameter(
torch.randn(1, embedding_dim) * 0.01
)
# 冻结 text encoder 的所有其他参数
for param in self.text_encoder.parameters():
param.requires_grad = False
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
"""
将特殊 token 替换为学到的嵌入
其他 token 使用原始 CLIP 嵌入
"""
# 实际实现中:
# 找到 placeholder token 的位置
# 用 self.token_embedding 替换对应位置的嵌入
# 其余位置使用原始 CLIP 词嵌入
pass # 完整实现需要 CLIP 内部访问
# 统计:Textual Inversion 的参数量极少
embedding_dim = 768
n_params = embedding_dim
print(f"Textual Inversion 参数量:{n_params} 个(!)")
print(f"文件大小:{n_params * 4 / 1024:.1f} KB(极小!)")
print()
print("优点:文件极小,可以分享为 '.pt' 嵌入文件")
print("缺点:表达能力有限,生成质量不如 LoRA 和 DreamBooth")
print()
print("典型使用场景:")
print(" 学习一种美学风格(如'印象派油画感')")
print(" 学习一种情绪/氛围(如'赛博朋克感')")
print(" 在 SD 中快速引入小概念")
五、FLUX:2024 年的架构革新
# FLUX(Black Forest Labs,2024年8月)
# 由 Stable Diffusion 的核心创始团队打造
# 在文字渲染、指令遵循、图像质量上全面超越 SD XL
def explain_flux_architecture():
print("FLUX 的核心架构创新:")
print()
print(" 与 Stable Diffusion 的主要区别:")
print()
innovations = [
("模型架构",
"SD:U-Net",
"FLUX:Diffusion Transformer(DiT)\n 摒弃 U-Net,用纯 Transformer 结构"),
("文本编码器",
"SD:CLIP(77 token)",
"FLUX:CLIP + T5-XXL(512 token)\n 更强的长文本理解"),
("流匹配训练",
"SD:DDPM/DDIM(噪声预测)",
"FLUX:Flow Matching(速度预测)\n 训练更稳定,收敛更快"),
("图像 token 化",
"SD:VAE 潜空间 4×64×64",
"FLUX:VAE 潜空间 16×128×128\n 更高压缩比,更高分辨率"),
("注意力机制",
"SD:Self-Attn + Cross-Attn 分离",
"FLUX:双流 Transformer\n 图像 token 和文本 token 共同注意力"),
("文字渲染",
"SD:极差(文字经常变形)",
"FLUX:优秀(能正确渲染多行文字)"),
]
for aspect, sd, flux in innovations:
print(f" [{aspect}]")
print(f" SD: {sd}")
print(f" FLUX: {flux}")
print()
print(" FLUX 的三个版本:")
print()
versions = [
("FLUX.1 [pro]", "最高质量", "闭源,API 付费", "12B 参数"),
("FLUX.1 [dev]", "高质量", "半开源(非商用)", "12B 参数"),
("FLUX.1 [schnell]", "快速版", "Apache 2.0 开源", "12B 参数,4步"),
]
for name, quality, license, size in versions:
print(f" {name:26s}: {quality:8s} | {license:20s} | {size}")
explain_flux_architecture()
def explain_flow_matching():
"""解释 FLUX 使用的 Flow Matching 训练目标"""
print("\nFlow Matching vs DDPM 的训练目标对比:")
print()
print(" DDPM(第二、三篇介绍的):")
print(" 前向过程:xₜ = √ᾱₜ·x₀ + √(1-ᾱₜ)·ε")
print(" 训练目标:预测噪声 ε")
print(" 采样:反向 SDE(随机微分方程)")
print()
print(" Flow Matching(Lipman et al., 2022):")
print(" 前向过程:xₜ = (1-t)·x₀ + t·ε,t ∈ [0,1]")
print(" (简单线性插值!)")
print(" 训练目标:预测速度向量 v = ε - x₀")
print(" 采样:ODE(确定性常微分方程)")
print()
print(" Flow Matching 的优势:")
print(" ① 更简单的前向过程(线性插值)")
print(" ② 训练更稳定(梯度不受 ᾱₜ 的非线性影响)")
print(" ③ 采样轨迹更直(ODE 路径近似直线)")
print(" ④ 更少步数就能得到好结果")
print()
print(" 直觉类比:")
print(" DDPM:从 A 到 B 走一条复杂的曲线路")
print(" Flow Matching:走一条近似直线,更短更快")
explain_flow_matching()
5.2 DiT(Diffusion Transformer)架构
class DiTBlock(nn.Module):
"""
Diffusion Transformer Block(FLUX 使用的基本单元)
参考:Peebles & Xie (2023)
与 U-Net 的最大区别:纯 Transformer,无卷积,无 U 形结构
双流设计(FLUX 特有):
图像 token 序列 + 文本 token 序列 → 各自的注意力 + 联合注意力
"""
def __init__(
self,
hidden_dim: int = 1024,
num_heads: int = 16,
mlp_ratio: float = 4.0,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
self.scale = self.head_dim ** -0.5
mlp_dim = int(hidden_dim * mlp_ratio)
# ── 图像 token 流 ────────────────────────────────────
self.img_norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.img_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
self.img_norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.img_ff = nn.Sequential(
nn.Linear(hidden_dim, mlp_dim),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_dim, hidden_dim),
)
# ── 文本 token 流 ────────────────────────────────────
self.txt_norm1 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.txt_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
self.txt_norm2 = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.txt_ff = nn.Sequential(
nn.Linear(hidden_dim, mlp_dim),
nn.GELU(approximate="tanh"),
nn.Linear(mlp_dim, hidden_dim),
)
# ── 时间步调制(AdaLN-Zero)────────────────────────
# 时间步和文本 → 6 个调制参数(shift,scale × 2 + gate × 2)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_dim, 6 * hidden_dim),
)
def forward(
self,
img: torch.Tensor, # 图像 token (B, N_img, D)
txt: torch.Tensor, # 文本 token (B, N_txt, D)
timestep: torch.Tensor, # 时间步嵌入 (B, D)
) -> tuple:
# 从时间步获取 6 个调制参数
mod = self.adaLN_modulation(timestep).chunk(6, dim=-1)
shift1_img, scale1_img, gate1_img, shift2_img, scale2_img, gate2_img = mod
# ── 图像 token:注意力 ────────────────────────────
h_img = self.img_norm1(img)
# AdaLN 调制:缩放 + 偏移
h_img = h_img * (1 + scale1_img.unsqueeze(1)) + shift1_img.unsqueeze(1)
h_img_attn, _ = self.img_attn(h_img, h_img, h_img)
img = img + gate1_img.unsqueeze(1) * h_img_attn
# ── 图像 token:FFN ───────────────────────────────
h_img = self.img_norm2(img)
h_img = h_img * (1 + scale2_img.unsqueeze(1)) + shift2_img.unsqueeze(1)
img = img + gate2_img.unsqueeze(1) * self.img_ff(h_img)
# ── 文本 token:(简化,实际 FLUX 做联合注意力)──
txt = txt + self.txt_attn(self.txt_norm1(txt),
self.txt_norm1(txt),
self.txt_norm1(txt))[0]
txt = txt + self.txt_ff(self.txt_norm2(txt))
return img, txt
# 验证 DiT Block
B, N_img, N_txt, D = 2, 1024, 77, 256 # 简化版
dit_block = DiTBlock(hidden_dim=D, num_heads=8)
img_tokens = torch.randn(B, N_img, D)
txt_tokens = torch.randn(B, N_txt, D)
time_embed = torch.randn(B, D)
img_out, txt_out = dit_block(img_tokens, txt_tokens, time_embed)
print("DiT Block 验证:")
print(f" 图像 token 输入:{img_tokens.shape}")
print(f" 文本 token 输入:{txt_tokens.shape}")
print(f" 时间步嵌入: {time_embed.shape}")
print(f" 图像 token 输出:{img_out.shape}")
print(f" 文本 token 输出:{txt_out.shape}")
print(f" 参数量:{sum(p.numel() for p in dit_block.parameters()):,}")
六、一致性模型:1 步生成
# Consistency Models(Song et al., 2023,OpenAI)
# 目标:1-2 步就能生成高质量图像(比 DDIM 快 10-50 倍)
#
# 核心思想:
# 在同一条 ODE 轨迹上的所有点(x₀, x₁, ..., x_T)
# 应该都映射到同一个 x₀
# 训练一个"一致性函数" f_θ:任意 xₜ → x₀(一步还原!)
#
# 两种训练方式:
# Consistency Distillation(CD):从已有扩散模型蒸馏
# Consistency Training(CT):从头训练
def explain_consistency_models():
print("一致性模型(Consistency Models)核心思想:")
print()
print(" 传统扩散模型(DDIM):")
print(" x_T → x_{900} → x_{800} → ... → x₀")
print(" 每步都是 U-Net,50步 = 50次推理")
print()
print(" 一致性模型:")
print(" 定义:一致性函数 f_θ(xₜ, t) = x₀")
print(" ↑ 无论从轨迹的哪个点出发,直接跳回 x₀!")
print()
print(" 采样:x_T → f_θ(x_T, T) = x̂₀(1步!)")
print()
print(" 多步采样(更高质量):")
print(" x_T → x̂₀ → 加少量噪声到 xₜ₁ → x̂₀ → ...")
print(" 每次添加噪声后再去噪,类似'打磨'效果")
print()
# 一致性函数的形式
print(" 一致性函数的参数化:")
print(" f_θ(xₜ, t) = c_skip(t)·xₜ + c_out(t)·F_θ(xₜ, t)")
print()
print(" c_skip(t):当 t→0 时,f→xₜ(边界条件)")
print(" c_out(t): 控制网络输出的权重")
print(" F_θ: 实际的神经网络(与 U-Net 结构相同)")
print()
print(" 各方法的速度质量对比:")
print()
comparison = [
("DDPM(1000步)", 1000, "★★★★★", "1×(基准)"),
("DDIM(50步)", 50, "★★★★☆", "20×"),
("DPM-Solver++(20步)", 20, "★★★★★", "50×"),
("LCM(4步)", 4, "★★★☆☆", "250×"),
("Consistency Model(2步)",2, "★★★☆☆", "500×"),
("Consistency Model(1步)",1, "★★★☆☆", "1000×"),
("SDXL-Turbo(1步)", 1, "★★★★☆", "1000×(质量更好)"),
]
print(f" {'方法':^30} {'步数':^6} {'质量':^12} {'速度':^10}")
print(" " + "─" * 62)
for name, steps, quality, speed in comparison:
print(f" {name:30s} {steps:^6} {quality:^12} {speed:^10}")
print()
print(" 实践建议:")
print(" 普通生成:DPM-Solver++ 20步(质量最优)")
print(" 实时交互:LCM 或 SDXL-Turbo(1-4步)")
print(" 研究极限:Consistency Models(1步)")
explain_consistency_models()
class ConsistencyFunction(nn.Module):
"""
一致性函数的核心结构
f_θ(xₜ, t) = c_skip(t)·xₜ + c_out(t)·F_θ(xₜ, t)
"""
def __init__(self, base_network: nn.Module, sigma_data: float = 0.5):
super().__init__()
self.F_theta = base_network # 底层网络(可以是 U-Net 或 DiT)
self.sigma_data = sigma_data # 数据的标准差
def c_skip(self, t: torch.Tensor) -> torch.Tensor:
"""边界条件系数:t→0 时 c_skip→1,网络输出被忽略"""
return self.sigma_data ** 2 / (t ** 2 + self.sigma_data ** 2)
def c_out(self, t: torch.Tensor) -> torch.Tensor:
"""网络输出的缩放系数"""
return t * self.sigma_data / (t ** 2 + self.sigma_data ** 2) ** 0.5
def forward(
self,
xt: torch.Tensor, # 加噪输入
t: torch.Tensor, # 时间步(这里是连续值)
context: torch.Tensor = None,
) -> torch.Tensor:
"""
一步还原 x₀
无论 t 多大(噪声多少),输出都应该是对 x₀ 的估计
"""
# 各样本的缩放系数
cs = self.c_skip(t).reshape(-1, 1, 1, 1)
co = self.c_out(t).reshape(-1, 1, 1, 1)
# 网络预测
F_out = self.F_theta(xt, t, context) if context is not None \
else self.F_theta(xt, t)
# 一致性函数:边界项 + 网络项
return cs * xt + co * F_out
七、系列收官:完整知识图谱
def print_series_summary():
"""打印系列完整知识图谱"""
print("=" * 70)
print(" 扩散模型从零到实战系列:完整知识图谱")
print("=" * 70)
print()
chapters = [
("第一篇", "扩散模型是什么",
["前向过程(加噪)与反向过程(去噪)的直觉",
"与 GAN/VAE 的对比:为什么扩散模型赢了",
"DDPM 的三个核心组件",
"diffusers 代码初体验"]),
("第二篇", "前向过程数学",
["马尔可夫链:q(xₜ|xₜ₋₁) = N(√(1-β)·xₜ₋₁, βI)",
"⭐ 重参数化:xₜ = √ᾱₜ·x₀ + √(1-ᾱₜ)·ε(直跳任意步!)",
"三种噪声调度:linear / cosine / scaled_linear",
"信噪比(SNR)与 Min-SNR 加权损失"]),
("第三篇", "反向过程与训练目标",
["后验分布 q(xₜ₋₁|xₜ,x₀) 的解析形式",
"⭐ ELBO → 简化 MSE 损失:L = ||ε - ε_θ(xₜ,t)||²",
"三种参数化:预测 ε / x₀ / v",
"完整 DDPMTrainer 和 DDPMSampler 实现"]),
("第四篇", "U-Net 去噪网络",
["正弦时间步嵌入(Sinusoidal + MLP)",
"⭐ ResBlock + AdaGN(时间步注入 scale+shift)",
"SelfAttention(仅在 ≤16×16 分辨率加)",
"U 形结构:编码器 + 跳跃连接 + 解码器"]),
("第五篇", "DDIM 加速采样",
["非马尔可夫过程:边缘分布与 DDPM 相同",
"子序列采样:1000步→50步,速度 20×",
"⭐ η=0 确定性采样:相同噪声→相同图像",
"DDIM Inversion:图像反演,精确编辑基础"]),
("第六篇", "条件生成与 CFG",
["CLIP 文本编码:文字 → (B,77,768) 序列",
"⭐ 交叉注意力:Q=图像,K/V=文本",
"⭐ CFG 公式:ε_g = ε_u + γ·(ε_c - ε_u)",
"负面提示词、guidance_scale=7.5 推荐"]),
("第七篇", "Stable Diffusion / LDM",
["VAE 编码器/解码器:像素↔潜空间(压缩 48×)",
"⭐ 潜空间扩散:速度 30-50×,显存 1/48",
"三条推理管线:text2img / img2img / inpainting",
"ControlNet:Zero Conv 技巧,精确结构控制"]),
("第八篇(本篇)", "实战与前沿",
["⭐ LoRA:低秩适配,<0.5% 参数,30分钟微调",
"DreamBooth:个性化生成,Prior Preservation Loss",
"FLUX:DiT 架构 + Flow Matching + T5-XXL",
"一致性模型:1步生成,ODE 轨迹自洽"]),
]
for chapter, title, points in chapters:
print(f" ┌─ {chapter}:{title}")
for i, point in enumerate(points):
last = i == len(points) - 1
prefix = " └─── " if last else " ├─── "
print(f" {prefix}{point}")
print()
print()
print(" 贯穿全系列的三个核心公式:")
print()
print(" ① 前向过程(重参数化):")
print(" xₜ = √ᾱₜ · x₀ + √(1-ᾱₜ) · ε")
print()
print(" ② 训练目标(简化 MSE):")
print(" L = E[||ε - ε_θ(√ᾱₜ·x₀ + √(1-ᾱₜ)·ε, t)||²]")
print()
print(" ③ Classifier-Free Guidance:")
print(" ε_guided = ε_uncond + γ · (ε_cond - ε_uncond)")
print()
print("=" * 70)
print(" 感谢你跟随这个系列走完了扩散模型的完整旅程!")
print(" 从第一篇的物理直觉,到最后的 LoRA 实战和 FLUX 前沿,")
print(" 希望你不只是学会了用,更理解了为什么。")
print("=" * 70)
print_series_summary()
总结:微调方案选择指南
| 方案 | 参数量 | 时间 | 显存 | 数据 | 最佳适用 |
|---|---|---|---|---|---|
| LoRA ⭐ | ~3M(0.5%) | 30-60 min | ~12GB | 10-50张 | 风格/人物,最流行 |
| DreamBooth | ~860M(全量) | 15-30 min | ~24GB | 3-30张 | 高质量人物,资源充足 |
| Textual Inversion | ~768 维 | 1-2 h | ~8GB | 5-20张 | 轻量概念嵌入 |
| DreamBooth-LoRA | ~3M | 20-40 min | ~12GB | 5-30张 | 最佳平衡 ⭐⭐ |
2024-2026 前沿趋势:
- FLUX 取代 SDXL 成为新主流(文字渲染、指令遵循大幅提升)
- 实时生成(LCM/Turbo/Consistency)进入主流产品
- 视频扩散(Sora、CogVideoX)成为新战场
- 多模态条件(IP-Adapter、InstantID、ControlNet++)让控制更精确
💬 跟完了八篇系列,你现在最想动手实现的是哪个部分?LoRA 微调自己的人物,还是从零复现一个 DDPM? 欢迎评论区分享!
🙏 「扩散模型从零到实战」系列(八篇)完结撒花!从物理直觉到数学推导,从代码实现到最新前沿,如果整个系列帮到你,最后一次三连(点赞👍 + 收藏⭐ + 关注)!感谢一路相伴!
本文为原创技术分享。代码在 Python 3.12 + PyTorch 2.x + diffusers 0.27 下验证。最后更新:2026-05-20
扩散模型从零到实战系列(八篇)完结 🎉
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)