面向:AI 初学者 / 论文复现新手 / 想读懂 DiT 的同学

仓库:Facebook Research DiT

论文:Scalable Diffusion Models with Transformers


一、为什么会有 DiT?

最近几年,生成式 AI 爆发。

从:

  • Stable Diffusion
  • Midjourney
  • DALL·E
  • Flux

到现在的视频生成模型。

这些模型背后,很多都属于:

Diffusion Model(扩散模型)

而 DiT(Diffusion Transformer)则是扩散模型发展中的一个非常重要的里程碑。

因为它第一次真正证明:

Transformer 可以替代传统 U-Net

并且效果非常强。


二、先理解:什么是 Diffusion?

很多人一上来就看公式。

结果:

5 分钟放弃。

其实 Diffusion 的核心思想非常简单。


1. 训练过程:不断加噪声

假设有一张猫咪图片:

猫咪图片

我们不断加噪声:

猫咪 → 模糊 → 雪花 → 全噪声

最后变成:

随机噪点

这个过程叫:

Forward Diffusion

即:

正向扩散


2. 生成过程:反向去噪

现在反过来。

模型从随机噪声开始:

噪声 → 模糊轮廓 → 更清晰 → 最终图片

这就是:

Reverse Diffusion

也就是:

AI 学会“去噪”

最终生成图片。


三、传统 Diffusion 模型的问题

早期 diffusion 模型大多使用:

U-Net

例如:

  • DDPM
  • Stable Diffusion
  • Latent Diffusion

结构通常是:

噪声图片

U-Net

预测噪声


四、DiT 的核心思想

论文作者提出一个问题:

Transformer 已经统治 NLP。

ViT 已经统治视觉。

那能不能用 Transformer 替代 diffusion 中的 U-Net?

答案:

可以。

于是 DiT 出现了。

它的结构:

Noise Image

Patchify

Transformer

Predict Noise

本质上:

DiT = ViT + Diffusion


五、ViT 是什么?

理解 DiT 前,必须先知道:

ViT(Vision Transformer)

ViT 做了一件很重要的事情:

把图片当成 Token 序列

例如:

256×256 图片

切成:

16×16 patch

那么:

256 / 16 = 16

所以:

16 × 16 = 256 个 patch

每个 patch:

类似 NLP 中的 token

然后输入 Transformer。


六、DiT 整体流程(最重要)

DiT 的完整流程:

带噪图像 x_t

Patch Embedding

位置编码

时间步编码 t

类别编码 y

Transformer Blocks

Linear Projection

Unpatchify

预测噪声

理解了这个流程,论文已经理解 70%。


七、DiT 的输入是什么?

模型输入:

x_t

含义:

某一个时间步的带噪图片

shape 一般:

(B, C, H, W)

例如:

(1, 4, 32, 32)


八、Patchify:图片切块

这一部分是 ViT 的核心。

DiT 会把图片切成小块:

████ ████ ████

████ ████ ████

████ ████ ████

然后:

每个 patch → 一个 token。

这一步对应代码里的:

PatchEmbed


九、Transformer 开始工作

Transformer 的核心:

Attention

简单理解:

每个 patch 都会“看”其他 patch。

例如:

  • 天空 patch
  • 草地 patch
  • 猫脸 patch

它们之间会建立联系。

因此 Transformer 特别适合:

全局建模

这也是 DiT 强大的原因。


十、时间步 Embedding 是什么?

Diffusion 有很多时间步:

t = 1

噪声很少。

t = 999

几乎全是噪声。

模型必须知道:

当前噪声有多严重

所以需要:

Timestep Embedding

代码里:

TimestepEmbedder


十一、类别 Embedding 是什么?

如果我们告诉模型:

生成狗

或者:

生成猫

模型需要知道生成目标。

因此需要:

Label Embedding

代码里:

LabelEmbedder


十二、DiT 的真正核心:Transformer Block

代码:

class DiTBlock(nn.Module)

里面主要包含:

  • Attention
  • LayerNorm
  • MLP

这部分其实和 GPT 非常像。

只是:

token 从文字变成了图像 patch


十三、AdaLN-Zero 是什么?

这是 DiT 论文里最容易卡住的地方。

其实它的思想很简单:

用条件信息控制 LayerNorm

条件包括:

  • 时间步 t
  • 类别 y

模型通过:

shift

scale

gate

控制 Transformer block 的行为。

这叫:

Adaptive LayerNorm

简称:

AdaLN


十四、DiT 为什么效果强?

论文最大的贡献:

Scaling Law

即:

模型越大 → 效果越强

作者发现:

Transformer 在 diffusion 上也满足 scaling law。

因此:

  • 更深
  • 更宽
  • 更多 attention heads

都能提升生成质量。


十五、代码结构解析

GitHub:

https://github.com/facebookresearch/DiT

仓库结构:

DiT/

├── train.py

├── sample.py

├── models.py

├── diffusion/

├── download.py

└── README.md


十六、train.py 在干什么?

这是训练入口。

主要流程:

创建模型

创建 diffusion

加载数据

开始训练

其中最关键:

model = DiT_models[args.model](...)

这里实例化 DiT。


十七、models.py 才是真正核心

这个文件最重要。

建议重点阅读:

1. PatchEmbed

作用:

图片 → token


2. TimestepEmbedder

作用:

时间步编码


3. LabelEmbedder

作用:

类别编码


4. DiTBlock

真正的 Transformer block。


5. forward()

整个数据流都在这里。

最值得读。


十八、forward() 数据流解析

核心代码逻辑:

x = self.x_embedder(x)

图片切 patch。


x = x + pos_embed

加入位置编码。


t = self.t_embedder(t)

时间步编码。


for block in self.blocks:

进入 Transformer blocks。


x = self.final_layer(x)

输出预测噪声。


十九、新手应该怎么读代码?

不要一上来研究数学。

最有效的方法:

先看 tensor shape

例如:

print(x.shape)

在:

  • patch 前
  • patch 后
  • attention 后

全部打印。

你会瞬间理解模型。


二十、如何运行 DiT

1. 克隆仓库

git clone https://github.com/facebookresearch/DiT.git

cd DiT


2. 创建环境

conda env create -f environment.yml

conda activate DiT


3. 下载权重

阅读 README 中的下载说明。


4. 生成图片

运行:

python sample.py

第一次建议:

先跑通 sample

不要急着训练。


二十一、为什么 DiT 很重要?

因为它影响了后面很多模型。

包括:

  • PixArt
  • SD3
  • Flux
  • Sora 系列思路

很多新一代生成模型:

都开始 Transformer 化


二十二、DiT 与 Stable Diffusion 的区别

对比 Stable Diffusion DiT
Backbone U-Net Transformer
核心结构 CNN ViT
全局建模 一般
扩展性 中等 很强
Scaling Law 不明显 很明显

二十三、学习建议(非常重要)

很多新手会:

直接硬啃论文

然后崩溃。

正确方法:

第一步

先理解:

输入输出


第二步

理解:

数据流


第三步

再理解:

  • Attention
  • AdaLN
  • CFG
  • Scheduler

二十四、给新手的建议

不要追求:

一次全部看懂

这是不可能的。

真正有效的方法:

反复运行 + 打印 shape

这是学习深度学习源码最快的方法。


二十五、总结

DiT 的核心思想其实很简单:

用 Transformer 替代 U-Net

整个流程:

带噪图像

Patchify

Transformer

预测噪声

本质:

Diffusion + ViT

而它真正伟大的地方:

证明了 Transformer 在生成模型中的巨大潜力

这也是后续大量生成模型 Transformer 化的开始。


二十六、推荐下一步学习

建议按顺序继续学习:

  1. ViT
  2. Attention
  3. DDPM
  4. Stable Diffusion
  5. DiT
  6. Flux / SD3

二十七、参考资料

论文:

Scalable Diffusion Models with Transformers

GitHub:

https://github.com/facebookresearch/DiT


最后

如果你是第一次接触 diffusion:

不要怕。

因为:

每个大模型工程师都经历过“完全看不懂源码”的阶段。

真正的成长方式不是:

一次看懂

而是:

反复运行

反复打印

反复调试

当你开始能看懂 tensor shape 的时候。

你已经正式进入深度学习工程世界了。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐