【扩散模型系列·第八篇·收官】实战与前沿:LoRA 微调、DreamBooth、FLUX 与一致性模型

作者:技术博主 | 更新时间:2026-05-20 | 阅读时长:约 25 分钟
系列:扩散模型从零到实战(共 8 篇)· 收官篇
环境:Python 3.12 + PyTorch 2.x + diffusers + peft
标签扩散模型 LoRA DreamBooth FLUX 一致性模型 微调 个性化生成 前沿进展


在这里插入图片描述

🔥 本篇目标:系列的最后一篇,聚焦"实用"与"前沿"。实用部分:用 LoRA 微调让 SD 在 1 小时内学会你的风格或人物(只需 10-20 张图)。前沿部分:FLUX(2024 年最强开源文生图模型)的架构革新,以及一致性模型(Consistency Models)如何把采样步数压缩到 1-2 步。读完本篇,你对整个扩散模型的技术栈——从数学基础到最新前沿——将有完整的认识。


系列完整进度

篇次 主题 状态
第一篇 扩散模型是什么:从加噪到去噪的直觉
第二篇 数学基础:前向过程与马尔可夫链
第三篇 反向过程:学习去噪
第四篇 U-Net 架构:去噪网络的设计
第五篇 DDIM:加速采样
第六篇 条件生成:文本引导与 CFG
第七篇 Stable Diffusion:潜在扩散模型
第八篇(本篇·收官) 实战与前沿

目录


一、微调扩散模型的动机

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional

print("为什么需要微调扩散模型?")
print()
print("  预训练 SD 的局限:")
print("  ① 不认识特定的人物(你自己、虚拟IP)")
print("  ② 不能精确生成特定风格(某个画家/插画师的独特笔触)")
print("  ③ 不熟悉特定物体(某款汽车、特定商品)")
print("  ④ 专业领域数据不足(医疗影像、工业设计)")
print()
print("  三种微调方案的对比:")
print()

methods = {
    "全量微调(Full Fine-tuning)": {
        "可训练参数": "~860M(全部)",
        "训练时间": "数天(A100)",
        "显存需求": "~80GB",
        "数据需求": "1000+ 张",
        "灵活性": "★★★★★",
        "适用": "从头定制,资源充足",
    },
    "DreamBooth": {
        "可训练参数": "~860M(全部 U-Net)",
        "训练时间": "15-30 分钟",
        "显存需求": "~24GB",
        "数据需求": "3-30 张",
        "灵活性": "★★★★☆",
        "适用": "特定人物/物体",
    },
    "Textual Inversion": {
        "可训练参数": "1 个文本 token(~768维)",
        "训练时间": "1-2 小时",
        "显存需求": "~8GB",
        "数据需求": "5-20 张",
        "灵活性": "★★★☆☆",
        "适用": "新概念嵌入",
    },
    "LoRA": {
        "可训练参数": "~3M(<0.5%)",
        "训练时间": "30-60 分钟",
        "显存需求": "~12GB",
        "数据需求": "10-50 张",
        "灵活性": "★★★★★",
        "适用": "风格/人物,最流行 ⭐",
    },
}

for name, info in methods.items():
    print(f"  [{name}]")
    for k, v in info.items():
        print(f"    {k:12s}: {v}")
    print()

二、LoRA:低秩适配,高效微调

2.1 LoRA 的数学原理

# LoRA(Low-Rank Adaptation,Hu et al. 2021)
#
# 核心思想:
# 大模型的权重矩阵 W ∈ R^{m×n} 在微调时的变化 ΔW
# 通常具有低秩结构(intrinsic rank 很低)
#
# 因此:
# 不微调 W,而是学习 ΔW = B·A
# B ∈ R^{m×r},A ∈ R^{r×n},其中 r << min(m, n)
#
# 前向传播:
# h = (W + α/r · B·A) · x = W·x + α/r · B·A·x
#
# 初始化:A 用正态分布初始化,B 初始化为 0
# → 训练开始时 ΔW = 0,不影响原始模型

class LoRALinear(nn.Module):
    """
    带 LoRA 的线性层
    替换原始的 nn.Linear,添加低秩旁路
    """

    def __init__(
        self,
        in_features:   int,
        out_features:  int,
        rank:          int   = 4,    # LoRA 秩,通常 4-64
        alpha:         float = 1.0,  # 缩放系数
        dropout:       float = 0.0,
    ):
        super().__init__()
        self.in_features  = in_features
        self.out_features = out_features
        self.rank         = rank
        self.alpha        = alpha
        self.scaling      = alpha / rank   # 实际缩放:α/r

        # 原始权重(冻结,不训练)
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features), requires_grad=False
        )
        self.bias   = nn.Parameter(torch.zeros(out_features), requires_grad=False)

        # LoRA 低秩矩阵(可训练!)
        self.lora_A = nn.Parameter(
            torch.randn(rank, in_features) * 0.01  # 小值初始化
        )
        self.lora_B = nn.Parameter(
            torch.zeros(out_features, rank)          # 零初始化!保证 ΔW=0
        )

        self.lora_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 原始线性变换
        result = F.linear(x, self.weight, self.bias)

        # LoRA 旁路:x → A → B → scale
        lora_out = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
        lora_out = lora_out * self.scaling

        return result + lora_out

    @property
    def effective_weight(self) -> torch.Tensor:
        """合并后的有效权重(推理时可以合并,加速推理)"""
        return self.weight + self.scaling * (self.lora_B @ self.lora_A)

    def merge_weights(self):
        """把 LoRA 权重合并回主权重(推理加速)"""
        self.weight.data += self.scaling * (self.lora_B @ self.lora_A)
        # 清空 LoRA 矩阵(已合并,不再需要)
        self.lora_A.data.zero_()
        self.lora_B.data.zero_()


# 验证 LoRA
lora_layer = LoRALinear(768, 768, rank=8, alpha=8.0)

# 统计参数量
total_params = sum(p.numel() for p in lora_layer.parameters())
trainable    = sum(p.numel() for p in lora_layer.parameters() if p.requires_grad)
frozen       = total_params - trainable

print("LoRA 层参数统计:")
print(f"  原始权重 W (768×768):{768*768:,} 个(冻结)")
print(f"  LoRA A (r×768={8}×768):{8*768:,} 个(可训练)")
print(f"  LoRA B (768×r=768×{8}):{768*8:,} 个(可训练)")
print(f"  可训练参数:{trainable:,}{100*trainable/total_params:.2f}%)")
print(f"  参数效率:{frozen:,} 个权重不需要梯度")
print()

# 验证:初始化时 LoRA 输出为 0(不影响原始模型)
x = torch.randn(4, 768)
# 原始权重的输出
with torch.no_grad():
    original_out = F.linear(x, lora_layer.weight, lora_layer.bias)
    lora_out     = lora_layer(x)
    delta        = (lora_out - original_out).abs().mean().item()

print(f"  初始化时 LoRA 的额外输出:{delta:.8f}(应接近 0)✓")
print(f"  (lora_B=0 → ΔW=B·A=0 → 不影响原始模型)")

2.2 在 SD 中应用 LoRA

def apply_lora_to_unet(unet: nn.Module, rank: int = 4, alpha: float = 4.0) -> dict:
    """
    给 U-Net 的关键线性层应用 LoRA
    标准做法:只对 Cross-Attention 层的 Q/K/V/Out 矩阵施加 LoRA

    为什么选 Cross-Attention?
    ① Cross-Attention 连接文本和图像,最与"内容/风格"相关
    ② Self-Attention 改动影响全局结构,容易破坏原模型能力
    ③ Feed-Forward 矩阵也可加(有时加,有时不加)
    """
    lora_layers = {}
    lora_params = 0
    total_params = 0

    for name, module in unet.named_modules():
        total_params += sum(p.numel() for p in module.parameters(recurse=False))

        # 针对 Cross-Attention 的 Q、K、V、输出投影矩阵
        if isinstance(module, nn.Linear):
            # 判断是否在 Cross-Attention 块中(按名称过滤)
            is_attn_proj = any(
                key in name for key in ["to_q", "to_k", "to_v", "to_out"]
            )
            if is_attn_proj:
                in_f  = module.in_features
                out_f = module.out_features
                # 用 LoRALinear 替换
                lora  = LoRALinear(in_f, out_f, rank=rank, alpha=alpha)
                lora.weight.data = module.weight.data.clone()
                if module.bias is not None:
                    lora.bias.data = module.bias.data.clone()

                lora_layers[name] = lora
                lora_params += 2 * rank * in_f  # A + B 的参数

    print(f"LoRA 应用统计:")
    print(f"  替换的层数:{len(lora_layers)}")
    print(f"  LoRA 可训练参数:{lora_params:,}")
    print(f"  参数效率:{100 * lora_params / max(total_params, 1):.4f}%")

    return lora_layers


# 统计不同 rank 的参数量
print("\n不同 rank 的 LoRA 参数量(以 SD 1.5 U-Net 为例,~180个 Cross-Attn 投影层):")
print(f"{'rank':^8} {'LoRA参数量':^16} {'比例':^12} {'文件大小':^12}")
print("─" * 50)

n_attn_layers = 180   # SD 1.5 的近似数量
avg_dim       = 768   # 平均维度

for r in [1, 2, 4, 8, 16, 32, 64]:
    lora_params = n_attn_layers * 2 * r * avg_dim
    total       = 860e6   # SD 1.5 总参数
    ratio       = lora_params / total * 100
    filesize_mb = lora_params * 4 / 1e6   # FP32
    print(f"  {r:^8} {lora_params:>12,}   {ratio:^12.4f}% {filesize_mb:>8.1f} MB")

print()
print("  → rank=4:只有 ~4M 参数,文件仅 ~16MB!")
print("  → 整个 SD 模型是 ~3.5GB,LoRA 仅 0.5%")

2.3 LoRA 训练循环

class LoRATrainer:
    """
    LoRA 微调训练器
    适用于 Stable Diffusion 的 U-Net 组件
    """

    def __init__(
        self,
        unet:            nn.Module,
        vae,
        text_encoder,
        alphas_cumprod:  torch.Tensor,
        lora_rank:       int   = 4,
        lora_alpha:      float = 4.0,
        learning_rate:   float = 1e-4,
        device:          str   = "cpu",
    ):
        self.device          = device
        self.alphas_cumprod  = alphas_cumprod

        # 冻结所有基础模型
        self.vae          = vae.eval().to(device)
        self.text_encoder = text_encoder.eval().to(device)
        self.unet         = unet.to(device)

        for p in self.vae.parameters():
            p.requires_grad = False
        for p in self.text_encoder.parameters():
            p.requires_grad = False
        for p in self.unet.parameters():
            p.requires_grad = False   # 先全冻结

        # 注入 LoRA 层(只解冻 LoRA 参数)
        self.lora_layers = self._inject_lora(lora_rank, lora_alpha)

        # 只优化 LoRA 参数
        lora_params = [p for p in self.unet.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(
            lora_params,
            lr=learning_rate,
            weight_decay=1e-2,
        )

        n_trainable = sum(p.numel() for p in lora_params)
        n_total     = sum(p.numel() for p in self.unet.parameters())
        print(f"LoRA 初始化完成:")
        print(f"  可训练参数:{n_trainable:,} / {n_total:,} "
              f"({100*n_trainable/n_total:.3f}%)")

    def _inject_lora(self, rank: int, alpha: float) -> list:
        """向 U-Net 的 Cross-Attention 层注入 LoRA"""
        lora_layers = []
        for name, module in self.unet.named_modules():
            if isinstance(module, nn.Linear) and any(
                k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]
            ):
                # 把原 Linear 替换为 LoRALinear
                parent = self._get_parent(name)
                attr   = name.split(".")[-1]
                in_f   = module.in_features
                out_f  = module.out_features

                lora_layer = LoRALinear(in_f, out_f, rank=rank, alpha=alpha)
                lora_layer.weight.data = module.weight.data.clone()
                if module.bias is not None:
                    lora_layer.bias.data = module.bias.data.clone()

                setattr(parent, attr, lora_layer)
                lora_layers.append(lora_layer)

        return lora_layers

    def _get_parent(self, name: str):
        """获取模块的父模块"""
        parts  = name.split(".")
        parent = self.unet
        for part in parts[:-1]:
            parent = getattr(parent, part)
        return parent

    def train_step(
        self,
        pixel_values:  torch.Tensor,   # 训练图像
        text_features: torch.Tensor,   # 文本嵌入
    ) -> float:
        """单步训练"""
        self.unet.train()
        self.optimizer.zero_grad()

        B = pixel_values.shape[0]

        # 1. 编码到潜空间(不计算梯度)
        with torch.no_grad():
            latents = self.vae.encode(pixel_values)
            # 缩放(假设 scaling_factor=0.18215)
            latents = latents * 0.18215

        # 2. 随机时间步和噪声
        t     = torch.randint(0, 1000, (B,), device=self.device)
        noise = torch.randn_like(latents)

        # 3. 前向加噪
        ab  = self.alphas_cumprod[t].reshape(-1, 1, 1, 1)
        xt  = ab.sqrt() * latents + (1 - ab).sqrt() * noise

        # 4. 预测噪声(LoRA 参数参与梯度计算)
        pred_noise = self.unet(xt, t, context=text_features)

        # 5. 损失(与标准 DDPM 相同)
        loss = F.mse_loss(pred_noise, noise)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            [p for p in self.unet.parameters() if p.requires_grad],
            max_norm=1.0
        )
        self.optimizer.step()

        return loss.item()

    def save_lora(self, path: str):
        """保存 LoRA 权重(只保存可训练参数,文件很小!)"""
        lora_state = {
            name: param
            for name, param in self.unet.named_parameters()
            if param.requires_grad
        }
        torch.save(lora_state, path)
        filesize = sum(p.numel() * 4 for p in lora_state.values()) / 1e6
        print(f"LoRA 保存完成:{path}{filesize:.1f} MB)")


# 演示 LoRA 训练配置
print("\nLoRA 训练推荐配置:")
configs = {
    "人物风格(写实)": {
        "rank": 8, "alpha": 8.0, "lr": "1e-4",
        "steps": "1000-2000", "data": "15-30 张(多角度)",
        "instance_prompt": '"a photo of sks person"',
        "class_prompt":    '"a photo of a person"',
    },
    "画风/艺术风格": {
        "rank": 16, "alpha": 16.0, "lr": "1e-4",
        "steps": "2000-3000", "data": "20-50 张(同风格)",
        "instance_prompt": '"painting in sks style"',
        "class_prompt":    '"a painting"',
    },
    "特定物体": {
        "rank": 4, "alpha": 4.0, "lr": "5e-5",
        "steps": "500-1000", "data": "5-20 张(多视角)",
        "instance_prompt": '"a photo of sks object"',
        "class_prompt":    '"a photo of an object"',
    },
}

for task, cfg in configs.items():
    print(f"  [{task}]")
    for k, v in cfg.items():
        print(f"    {k:20s}: {v}")
    print()

三、DreamBooth:个性化生成

# DreamBooth(Ruiz et al., 2022,Google Research)
# 目标:让 SD 认识特定人物/物体,能生成"他在各种场景下"的图像
#
# 关键技术:
# ① 使用稀有词 "sks"(或其他稀有词)作为绑定词
#    → 最小化与已有概念的干扰
# ② Prior Preservation Loss:防止"遗忘"原有概念
#    → 同时用"a photo of a person"生成伪图,参与训练
#    → 避免训练后 SD 只会画特定的人,忘了"人"的概念

def explain_dreambooth():
    print("DreamBooth 训练流程:")
    print()
    print("  数据准备:")
    print("  ├── instance images(你的图): 15 张不同角度的照片")
    print("  │   instance prompt: 'a photo of [sks] person'")
    print("  └── class images(先验图): 200 张 SD 自己生成的")
    print("      class prompt: 'a photo of a person'")
    print()
    print("  训练目标:")
    print("  L = L_instance + λ·L_prior")
    print()
    print("  L_instance:让模型能根据 'sks person' 生成你的脸")
    print("  L_prior:   让模型保持对'person'概念的一般认知")
    print("             防止过拟合(catastrophic forgetting)")
    print()
    print("  λ(prior_loss_weight)通常设为 1.0")
    print()

    print("  DreamBooth vs LoRA 的选择:")
    print()
    differences = [
        ("可训练参数",  "全部 U-Net(~860M)",    "只有低秩矩阵(~3M)"),
        ("训练时间",    "~15-30 分钟",             "~30-60 分钟"),
        ("显存",        "~24GB(A100)",            "~12GB(消费级可)"),
        ("生成质量",    "★★★★★(通常更好)",    "★★★★☆"),
        ("文件大小",    "~3.5GB(完整模型)",       "~16-150MB"),
        ("可组合性",    "低(独立模型)",            "高(多 LoRA 叠加)"),
        ("开源社区",    "较少(太重)",              "CivitAI 数万个"),
    ]
    print(f"  {'指标':^16} {'DreamBooth':^22} {'LoRA':^22}")
    print("  " + "─" * 62)
    for metric, db, lora in differences:
        print(f"  {metric:^16} {db:^22} {lora:^22}")

    print()
    print("  实践建议:")
    print("  资源有限(消费级 GPU)→ LoRA")
    print("  追求最高质量(A100)→ DreamBooth + LoRA(DreamBooth-LoRA)")
    print("  社区分享 → LoRA(CivitAI 的主流格式)")


explain_dreambooth()

四、Textual Inversion:嵌入层微调

# Textual Inversion(Gal et al., 2022,NVIDIA)
# 最轻量的个性化方法:
# 只学习一个新的文本 token 嵌入向量(~768 维)
# 其他参数完全冻结

class TextualInversionEmbedding(nn.Module):
    """
    Textual Inversion:学习新的文本 token 嵌入

    原理:
    在 CLIP 的词汇表中添加一个新词 '*'(或 'sks')
    只训练这个词的嵌入向量,使 CLIP 能用它描述新概念
    """

    def __init__(
        self,
        text_encoder:   nn.Module,   # CLIP 文本编码器
        placeholder:    str  = "*",   # 占位符 token
        init_token:     str  = "artwork",  # 初始化参考词
        embedding_dim:  int  = 768,
    ):
        super().__init__()
        self.text_encoder = text_encoder
        self.placeholder  = placeholder

        # 新 token 的嵌入(这是唯一可训练的参数!)
        self.token_embedding = nn.Parameter(
            torch.randn(1, embedding_dim) * 0.01
        )

        # 冻结 text encoder 的所有其他参数
        for param in self.text_encoder.parameters():
            param.requires_grad = False

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        将特殊 token 替换为学到的嵌入
        其他 token 使用原始 CLIP 嵌入
        """
        # 实际实现中:
        # 找到 placeholder token 的位置
        # 用 self.token_embedding 替换对应位置的嵌入
        # 其余位置使用原始 CLIP 词嵌入
        pass   # 完整实现需要 CLIP 内部访问


# 统计:Textual Inversion 的参数量极少
embedding_dim = 768
n_params      = embedding_dim
print(f"Textual Inversion 参数量:{n_params} 个(!)")
print(f"文件大小:{n_params * 4 / 1024:.1f} KB(极小!)")
print()
print("优点:文件极小,可以分享为 '.pt' 嵌入文件")
print("缺点:表达能力有限,生成质量不如 LoRA 和 DreamBooth")
print()
print("典型使用场景:")
print("  学习一种美学风格(如'印象派油画感')")
print("  学习一种情绪/氛围(如'赛博朋克感')")
print("  在 SD 中快速引入小概念")

五、FLUX:2024 年的架构革新

# FLUX(Black Forest Labs,2024年8月)
# 由 Stable Diffusion 的核心创始团队打造
# 在文字渲染、指令遵循、图像质量上全面超越 SD XL

def explain_flux_architecture():
    print("FLUX 的核心架构创新:")
    print()
    print("  与 Stable Diffusion 的主要区别:")
    print()

    innovations = [
        ("模型架构",
         "SD:U-Net",
         "FLUX:Diffusion Transformer(DiT)\n         摒弃 U-Net,用纯 Transformer 结构"),
        ("文本编码器",
         "SD:CLIP(77 token)",
         "FLUX:CLIP + T5-XXL(512 token)\n         更强的长文本理解"),
        ("流匹配训练",
         "SD:DDPM/DDIM(噪声预测)",
         "FLUX:Flow Matching(速度预测)\n         训练更稳定,收敛更快"),
        ("图像 token 化",
         "SD:VAE 潜空间 4×64×64",
         "FLUX:VAE 潜空间 16×128×128\n         更高压缩比,更高分辨率"),
        ("注意力机制",
         "SD:Self-Attn + Cross-Attn 分离",
         "FLUX:双流 Transformer\n         图像 token 和文本 token 共同注意力"),
        ("文字渲染",
         "SD:极差(文字经常变形)",
         "FLUX:优秀(能正确渲染多行文字)"),
    ]

    for aspect, sd, flux in innovations:
        print(f"  [{aspect}]")
        print(f"    SD:   {sd}")
        print(f"    FLUX: {flux}")
        print()

    print("  FLUX 的三个版本:")
    print()
    versions = [
        ("FLUX.1 [pro]",    "最高质量",  "闭源,API 付费",    "12B 参数"),
        ("FLUX.1 [dev]",    "高质量",    "半开源(非商用)",   "12B 参数"),
        ("FLUX.1 [schnell]", "快速版",   "Apache 2.0 开源",   "12B 参数,4步"),
    ]
    for name, quality, license, size in versions:
        print(f"    {name:26s}: {quality:8s} | {license:20s} | {size}")


explain_flux_architecture()


def explain_flow_matching():
    """解释 FLUX 使用的 Flow Matching 训练目标"""

    print("\nFlow Matching vs DDPM 的训练目标对比:")
    print()
    print("  DDPM(第二、三篇介绍的):")
    print("    前向过程:xₜ = √ᾱₜ·x₀ + √(1-ᾱₜ)·ε")
    print("    训练目标:预测噪声 ε")
    print("    采样:反向 SDE(随机微分方程)")
    print()
    print("  Flow Matching(Lipman et al., 2022):")
    print("    前向过程:xₜ = (1-t)·x₀ + t·ε,t ∈ [0,1]")
    print("                   (简单线性插值!)")
    print("    训练目标:预测速度向量 v = ε - x₀")
    print("    采样:ODE(确定性常微分方程)")
    print()
    print("  Flow Matching 的优势:")
    print("  ① 更简单的前向过程(线性插值)")
    print("  ② 训练更稳定(梯度不受 ᾱₜ 的非线性影响)")
    print("  ③ 采样轨迹更直(ODE 路径近似直线)")
    print("  ④ 更少步数就能得到好结果")
    print()
    print("  直觉类比:")
    print("  DDPM:从 A 到 B 走一条复杂的曲线路")
    print("  Flow Matching:走一条近似直线,更短更快")


explain_flow_matching()

5.2 DiT(Diffusion Transformer)架构

class DiTBlock(nn.Module):
    """
    Diffusion Transformer Block(FLUX 使用的基本单元)
    参考:Peebles & Xie (2023)
    与 U-Net 的最大区别:纯 Transformer,无卷积,无 U 形结构

    双流设计(FLUX 特有):
    图像 token 序列 + 文本 token 序列 → 各自的注意力 + 联合注意力
    """

    def __init__(
        self,
        hidden_dim:  int = 1024,
        num_heads:   int = 16,
        mlp_ratio:   float = 4.0,
    ):
        super().__init__()
        self.num_heads  = num_heads
        self.head_dim   = hidden_dim // num_heads
        self.scale      = self.head_dim ** -0.5
        mlp_dim         = int(hidden_dim * mlp_ratio)

        # ── 图像 token 流 ────────────────────────────────────
        self.img_norm1  = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.img_attn   = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.img_norm2  = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.img_ff     = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_dim, hidden_dim),
        )

        # ── 文本 token 流 ────────────────────────────────────
        self.txt_norm1  = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.txt_attn   = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.txt_norm2  = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.txt_ff     = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_dim, hidden_dim),
        )

        # ── 时间步调制(AdaLN-Zero)────────────────────────
        # 时间步和文本 → 6 个调制参数(shift,scale × 2 + gate × 2)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_dim, 6 * hidden_dim),
        )

    def forward(
        self,
        img:      torch.Tensor,   # 图像 token (B, N_img, D)
        txt:      torch.Tensor,   # 文本 token (B, N_txt, D)
        timestep: torch.Tensor,   # 时间步嵌入 (B, D)
    ) -> tuple:

        # 从时间步获取 6 个调制参数
        mod = self.adaLN_modulation(timestep).chunk(6, dim=-1)
        shift1_img, scale1_img, gate1_img, shift2_img, scale2_img, gate2_img = mod

        # ── 图像 token:注意力 ────────────────────────────
        h_img = self.img_norm1(img)
        # AdaLN 调制:缩放 + 偏移
        h_img = h_img * (1 + scale1_img.unsqueeze(1)) + shift1_img.unsqueeze(1)
        h_img_attn, _ = self.img_attn(h_img, h_img, h_img)
        img = img + gate1_img.unsqueeze(1) * h_img_attn

        # ── 图像 token:FFN ───────────────────────────────
        h_img = self.img_norm2(img)
        h_img = h_img * (1 + scale2_img.unsqueeze(1)) + shift2_img.unsqueeze(1)
        img   = img + gate2_img.unsqueeze(1) * self.img_ff(h_img)

        # ── 文本 token:(简化,实际 FLUX 做联合注意力)──
        txt = txt + self.txt_attn(self.txt_norm1(txt),
                                   self.txt_norm1(txt),
                                   self.txt_norm1(txt))[0]
        txt = txt + self.txt_ff(self.txt_norm2(txt))

        return img, txt


# 验证 DiT Block
B, N_img, N_txt, D = 2, 1024, 77, 256  # 简化版
dit_block = DiTBlock(hidden_dim=D, num_heads=8)

img_tokens  = torch.randn(B, N_img, D)
txt_tokens  = torch.randn(B, N_txt, D)
time_embed  = torch.randn(B, D)

img_out, txt_out = dit_block(img_tokens, txt_tokens, time_embed)

print("DiT Block 验证:")
print(f"  图像 token 输入:{img_tokens.shape}")
print(f"  文本 token 输入:{txt_tokens.shape}")
print(f"  时间步嵌入:    {time_embed.shape}")
print(f"  图像 token 输出:{img_out.shape}")
print(f"  文本 token 输出:{txt_out.shape}")
print(f"  参数量:{sum(p.numel() for p in dit_block.parameters()):,}")

六、一致性模型:1 步生成

# Consistency Models(Song et al., 2023,OpenAI)
# 目标:1-2 步就能生成高质量图像(比 DDIM 快 10-50 倍)
#
# 核心思想:
# 在同一条 ODE 轨迹上的所有点(x₀, x₁, ..., x_T)
# 应该都映射到同一个 x₀
# 训练一个"一致性函数" f_θ:任意 xₜ → x₀(一步还原!)
#
# 两种训练方式:
# Consistency Distillation(CD):从已有扩散模型蒸馏
# Consistency Training(CT):从头训练

def explain_consistency_models():
    print("一致性模型(Consistency Models)核心思想:")
    print()
    print("  传统扩散模型(DDIM):")
    print("    x_T → x_{900} → x_{800} → ... → x₀")
    print("    每步都是 U-Net,50步 = 50次推理")
    print()
    print("  一致性模型:")
    print("    定义:一致性函数 f_θ(xₜ, t) = x₀")
    print("    ↑ 无论从轨迹的哪个点出发,直接跳回 x₀!")
    print()
    print("    采样:x_T → f_θ(x_T, T) = x̂₀(1步!)")
    print()
    print("    多步采样(更高质量):")
    print("    x_T → x̂₀ → 加少量噪声到 xₜ₁ → x̂₀ → ...")
    print("    每次添加噪声后再去噪,类似'打磨'效果")
    print()

    # 一致性函数的形式
    print("  一致性函数的参数化:")
    print("    f_θ(xₜ, t) = c_skip(t)·xₜ + c_out(t)·F_θ(xₜ, t)")
    print()
    print("    c_skip(t):当 t→0 时,f→xₜ(边界条件)")
    print("    c_out(t): 控制网络输出的权重")
    print("    F_θ:      实际的神经网络(与 U-Net 结构相同)")
    print()

    print("  各方法的速度质量对比:")
    print()
    comparison = [
        ("DDPM(1000步)",         1000, "★★★★★", "1×(基准)"),
        ("DDIM(50步)",           50,   "★★★★☆", "20×"),
        ("DPM-Solver++(20步)",   20,   "★★★★★", "50×"),
        ("LCM(4步)",             4,    "★★★☆☆", "250×"),
        ("Consistency Model(2步)",2,   "★★★☆☆", "500×"),
        ("Consistency Model(1步)",1,   "★★★☆☆", "1000×"),
        ("SDXL-Turbo(1步)",      1,    "★★★★☆", "1000×(质量更好)"),
    ]
    print(f"  {'方法':^30} {'步数':^6} {'质量':^12} {'速度':^10}")
    print("  " + "─" * 62)
    for name, steps, quality, speed in comparison:
        print(f"  {name:30s} {steps:^6} {quality:^12} {speed:^10}")

    print()
    print("  实践建议:")
    print("  普通生成:DPM-Solver++ 20步(质量最优)")
    print("  实时交互:LCM 或 SDXL-Turbo(1-4步)")
    print("  研究极限:Consistency Models(1步)")


explain_consistency_models()


class ConsistencyFunction(nn.Module):
    """
    一致性函数的核心结构
    f_θ(xₜ, t) = c_skip(t)·xₜ + c_out(t)·F_θ(xₜ, t)
    """

    def __init__(self, base_network: nn.Module, sigma_data: float = 0.5):
        super().__init__()
        self.F_theta   = base_network   # 底层网络(可以是 U-Net 或 DiT)
        self.sigma_data = sigma_data    # 数据的标准差

    def c_skip(self, t: torch.Tensor) -> torch.Tensor:
        """边界条件系数:t→0 时 c_skip→1,网络输出被忽略"""
        return self.sigma_data ** 2 / (t ** 2 + self.sigma_data ** 2)

    def c_out(self, t: torch.Tensor) -> torch.Tensor:
        """网络输出的缩放系数"""
        return t * self.sigma_data / (t ** 2 + self.sigma_data ** 2) ** 0.5

    def forward(
        self,
        xt:      torch.Tensor,   # 加噪输入
        t:       torch.Tensor,   # 时间步(这里是连续值)
        context: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        一步还原 x₀
        无论 t 多大(噪声多少),输出都应该是对 x₀ 的估计
        """
        # 各样本的缩放系数
        cs = self.c_skip(t).reshape(-1, 1, 1, 1)
        co = self.c_out(t).reshape(-1, 1, 1, 1)

        # 网络预测
        F_out = self.F_theta(xt, t, context) if context is not None \
                else self.F_theta(xt, t)

        # 一致性函数:边界项 + 网络项
        return cs * xt + co * F_out

七、系列收官:完整知识图谱

def print_series_summary():
    """打印系列完整知识图谱"""

    print("=" * 70)
    print("  扩散模型从零到实战系列:完整知识图谱")
    print("=" * 70)
    print()

    chapters = [
        ("第一篇", "扩散模型是什么",
         ["前向过程(加噪)与反向过程(去噪)的直觉",
          "与 GAN/VAE 的对比:为什么扩散模型赢了",
          "DDPM 的三个核心组件",
          "diffusers 代码初体验"]),

        ("第二篇", "前向过程数学",
         ["马尔可夫链:q(xₜ|xₜ₋₁) = N(√(1-β)·xₜ₋₁, βI)",
          "⭐ 重参数化:xₜ = √ᾱₜ·x₀ + √(1-ᾱₜ)·ε(直跳任意步!)",
          "三种噪声调度:linear / cosine / scaled_linear",
          "信噪比(SNR)与 Min-SNR 加权损失"]),

        ("第三篇", "反向过程与训练目标",
         ["后验分布 q(xₜ₋₁|xₜ,x₀) 的解析形式",
          "⭐ ELBO → 简化 MSE 损失:L = ||ε - ε_θ(xₜ,t)||²",
          "三种参数化:预测 ε / x₀ / v",
          "完整 DDPMTrainer 和 DDPMSampler 实现"]),

        ("第四篇", "U-Net 去噪网络",
         ["正弦时间步嵌入(Sinusoidal + MLP)",
          "⭐ ResBlock + AdaGN(时间步注入 scale+shift)",
          "SelfAttention(仅在 ≤16×16 分辨率加)",
          "U 形结构:编码器 + 跳跃连接 + 解码器"]),

        ("第五篇", "DDIM 加速采样",
         ["非马尔可夫过程:边缘分布与 DDPM 相同",
          "子序列采样:1000步→50步,速度 20×",
          "⭐ η=0 确定性采样:相同噪声→相同图像",
          "DDIM Inversion:图像反演,精确编辑基础"]),

        ("第六篇", "条件生成与 CFG",
         ["CLIP 文本编码:文字 → (B,77,768) 序列",
          "⭐ 交叉注意力:Q=图像,K/V=文本",
          "⭐ CFG 公式:ε_g = ε_u + γ·(ε_c - ε_u)",
          "负面提示词、guidance_scale=7.5 推荐"]),

        ("第七篇", "Stable Diffusion / LDM",
         ["VAE 编码器/解码器:像素↔潜空间(压缩 48×)",
          "⭐ 潜空间扩散:速度 30-50×,显存 1/48",
          "三条推理管线:text2img / img2img / inpainting",
          "ControlNet:Zero Conv 技巧,精确结构控制"]),

        ("第八篇(本篇)", "实战与前沿",
         ["⭐ LoRA:低秩适配,<0.5% 参数,30分钟微调",
          "DreamBooth:个性化生成,Prior Preservation Loss",
          "FLUX:DiT 架构 + Flow Matching + T5-XXL",
          "一致性模型:1步生成,ODE 轨迹自洽"]),
    ]

    for chapter, title, points in chapters:
        print(f"  ┌─ {chapter}{title}")
        for i, point in enumerate(points):
            last = i == len(points) - 1
            prefix = "  └─── " if last else "  ├─── "
            print(f"  {prefix}{point}")
        print()

    print()
    print("  贯穿全系列的三个核心公式:")
    print()
    print("  ① 前向过程(重参数化):")
    print("     xₜ = √ᾱₜ · x₀ + √(1-ᾱₜ) · ε")
    print()
    print("  ② 训练目标(简化 MSE):")
    print("     L = E[||ε - ε_θ(√ᾱₜ·x₀ + √(1-ᾱₜ)·ε, t)||²]")
    print()
    print("  ③ Classifier-Free Guidance:")
    print("     ε_guided = ε_uncond + γ · (ε_cond - ε_uncond)")
    print()
    print("=" * 70)
    print("  感谢你跟随这个系列走完了扩散模型的完整旅程!")
    print("  从第一篇的物理直觉,到最后的 LoRA 实战和 FLUX 前沿,")
    print("  希望你不只是学会了用,更理解了为什么。")
    print("=" * 70)


print_series_summary()

总结:微调方案选择指南

方案 参数量 时间 显存 数据 最佳适用
LoRA ~3M(0.5%) 30-60 min ~12GB 10-50张 风格/人物,最流行
DreamBooth ~860M(全量) 15-30 min ~24GB 3-30张 高质量人物,资源充足
Textual Inversion ~768 维 1-2 h ~8GB 5-20张 轻量概念嵌入
DreamBooth-LoRA ~3M 20-40 min ~12GB 5-30张 最佳平衡 ⭐⭐

2024-2026 前沿趋势:

  1. FLUX 取代 SDXL 成为新主流(文字渲染、指令遵循大幅提升)
  2. 实时生成(LCM/Turbo/Consistency)进入主流产品
  3. 视频扩散(Sora、CogVideoX)成为新战场
  4. 多模态条件(IP-Adapter、InstantID、ControlNet++)让控制更精确

💬 跟完了八篇系列,你现在最想动手实现的是哪个部分?LoRA 微调自己的人物,还是从零复现一个 DDPM? 欢迎评论区分享!

🙏 「扩散模型从零到实战」系列(八篇)完结撒花!从物理直觉到数学推导,从代码实现到最新前沿,如果整个系列帮到你,最后一次三连(点赞👍 + 收藏⭐ + 关注)!感谢一路相伴!


本文为原创技术分享。代码在 Python 3.12 + PyTorch 2.x + diffusers 0.27 下验证。最后更新:2026-05-20
扩散模型从零到实战系列(八篇)完结 🎉

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐