从零理解 DiT:新手也能看懂的 Diffusion Transformer(附代码解析)
面向:AI 初学者 / 论文复现新手 / 想读懂 DiT 的同学
仓库:Facebook Research DiT
论文:Scalable Diffusion Models with Transformers
一、为什么会有 DiT?
最近几年,生成式 AI 爆发。
从:
- Stable Diffusion
- Midjourney
- DALL·E
- Flux
到现在的视频生成模型。
这些模型背后,很多都属于:
Diffusion Model(扩散模型)
而 DiT(Diffusion Transformer)则是扩散模型发展中的一个非常重要的里程碑。
因为它第一次真正证明:
Transformer 可以替代传统 U-Net
并且效果非常强。
二、先理解:什么是 Diffusion?
很多人一上来就看公式。
结果:
5 分钟放弃。
其实 Diffusion 的核心思想非常简单。
1. 训练过程:不断加噪声
假设有一张猫咪图片:
猫咪图片
我们不断加噪声:
猫咪 → 模糊 → 雪花 → 全噪声
最后变成:
随机噪点
这个过程叫:
Forward Diffusion
即:
正向扩散
2. 生成过程:反向去噪
现在反过来。
模型从随机噪声开始:
噪声 → 模糊轮廓 → 更清晰 → 最终图片
这就是:
Reverse Diffusion
也就是:
AI 学会“去噪”
最终生成图片。
三、传统 Diffusion 模型的问题
早期 diffusion 模型大多使用:
U-Net
例如:
- DDPM
- Stable Diffusion
- Latent Diffusion
结构通常是:
噪声图片
↓
U-Net
↓
预测噪声
四、DiT 的核心思想
论文作者提出一个问题:
Transformer 已经统治 NLP。
ViT 已经统治视觉。
那能不能用 Transformer 替代 diffusion 中的 U-Net?
答案:
可以。
于是 DiT 出现了。
它的结构:
Noise Image
↓
Patchify
↓
Transformer
↓
Predict Noise
本质上:
DiT = ViT + Diffusion
五、ViT 是什么?
理解 DiT 前,必须先知道:
ViT(Vision Transformer)
ViT 做了一件很重要的事情:
把图片当成 Token 序列
例如:
256×256 图片
切成:
16×16 patch
那么:
256 / 16 = 16
所以:
16 × 16 = 256 个 patch
每个 patch:
类似 NLP 中的 token
然后输入 Transformer。
六、DiT 整体流程(最重要)
DiT 的完整流程:
带噪图像 x_t
↓
Patch Embedding
↓
位置编码
↓
时间步编码 t
↓
类别编码 y
↓
Transformer Blocks
↓
Linear Projection
↓
Unpatchify
↓
预测噪声
理解了这个流程,论文已经理解 70%。
七、DiT 的输入是什么?
模型输入:
x_t
含义:
某一个时间步的带噪图片
shape 一般:
(B, C, H, W)
例如:
(1, 4, 32, 32)
八、Patchify:图片切块
这一部分是 ViT 的核心。
DiT 会把图片切成小块:
████ ████ ████
████ ████ ████
████ ████ ████
然后:
每个 patch → 一个 token。
这一步对应代码里的:
PatchEmbed
九、Transformer 开始工作
Transformer 的核心:
Attention
简单理解:
每个 patch 都会“看”其他 patch。
例如:
- 天空 patch
- 草地 patch
- 猫脸 patch
它们之间会建立联系。
因此 Transformer 特别适合:
全局建模
这也是 DiT 强大的原因。
十、时间步 Embedding 是什么?
Diffusion 有很多时间步:
t = 1
噪声很少。
t = 999
几乎全是噪声。
模型必须知道:
当前噪声有多严重
所以需要:
Timestep Embedding
代码里:
TimestepEmbedder
十一、类别 Embedding 是什么?
如果我们告诉模型:
生成狗
或者:
生成猫
模型需要知道生成目标。
因此需要:
Label Embedding
代码里:
LabelEmbedder
十二、DiT 的真正核心:Transformer Block
代码:
class DiTBlock(nn.Module)
里面主要包含:
- Attention
- LayerNorm
- MLP
这部分其实和 GPT 非常像。
只是:
token 从文字变成了图像 patch
十三、AdaLN-Zero 是什么?
这是 DiT 论文里最容易卡住的地方。
其实它的思想很简单:
用条件信息控制 LayerNorm
条件包括:
- 时间步 t
- 类别 y
模型通过:
shift
scale
gate
控制 Transformer block 的行为。
这叫:
Adaptive LayerNorm
简称:
AdaLN
十四、DiT 为什么效果强?
论文最大的贡献:
Scaling Law
即:
模型越大 → 效果越强
作者发现:
Transformer 在 diffusion 上也满足 scaling law。
因此:
- 更深
- 更宽
- 更多 attention heads
都能提升生成质量。
十五、代码结构解析
GitHub:
https://github.com/facebookresearch/DiT
仓库结构:
DiT/
├── train.py
├── sample.py
├── models.py
├── diffusion/
├── download.py
└── README.md
十六、train.py 在干什么?
这是训练入口。
主要流程:
创建模型
创建 diffusion
加载数据
开始训练
其中最关键:
model = DiT_models[args.model](...)
这里实例化 DiT。
十七、models.py 才是真正核心
这个文件最重要。
建议重点阅读:
1. PatchEmbed
作用:
图片 → token
2. TimestepEmbedder
作用:
时间步编码
3. LabelEmbedder
作用:
类别编码
4. DiTBlock
真正的 Transformer block。
5. forward()
整个数据流都在这里。
最值得读。
十八、forward() 数据流解析
核心代码逻辑:
x = self.x_embedder(x)
图片切 patch。
x = x + pos_embed
加入位置编码。
t = self.t_embedder(t)
时间步编码。
for block in self.blocks:
进入 Transformer blocks。
x = self.final_layer(x)
输出预测噪声。
十九、新手应该怎么读代码?
不要一上来研究数学。
最有效的方法:
先看 tensor shape
例如:
print(x.shape)
在:
- patch 前
- patch 后
- attention 后
全部打印。
你会瞬间理解模型。
二十、如何运行 DiT
1. 克隆仓库
git clone https://github.com/facebookresearch/DiT.git
cd DiT
2. 创建环境
conda env create -f environment.yml
conda activate DiT
3. 下载权重
阅读 README 中的下载说明。
4. 生成图片
运行:
python sample.py
第一次建议:
先跑通 sample
不要急着训练。
二十一、为什么 DiT 很重要?
因为它影响了后面很多模型。
包括:
- PixArt
- SD3
- Flux
- Sora 系列思路
很多新一代生成模型:
都开始 Transformer 化
二十二、DiT 与 Stable Diffusion 的区别
| 对比 | Stable Diffusion | DiT |
|---|---|---|
| Backbone | U-Net | Transformer |
| 核心结构 | CNN | ViT |
| 全局建模 | 一般 | 强 |
| 扩展性 | 中等 | 很强 |
| Scaling Law | 不明显 | 很明显 |
二十三、学习建议(非常重要)
很多新手会:
直接硬啃论文
然后崩溃。
正确方法:
第一步
先理解:
输入输出
第二步
理解:
数据流
第三步
再理解:
- Attention
- AdaLN
- CFG
- Scheduler
二十四、给新手的建议
不要追求:
一次全部看懂
这是不可能的。
真正有效的方法:
反复运行 + 打印 shape
这是学习深度学习源码最快的方法。
二十五、总结
DiT 的核心思想其实很简单:
用 Transformer 替代 U-Net
整个流程:
带噪图像
↓
Patchify
↓
Transformer
↓
预测噪声
本质:
Diffusion + ViT
而它真正伟大的地方:
证明了 Transformer 在生成模型中的巨大潜力
这也是后续大量生成模型 Transformer 化的开始。
二十六、推荐下一步学习
建议按顺序继续学习:
- ViT
- Attention
- DDPM
- Stable Diffusion
- DiT
- Flux / SD3
二十七、参考资料
论文:
Scalable Diffusion Models with Transformers
GitHub:
https://github.com/facebookresearch/DiT
最后
如果你是第一次接触 diffusion:
不要怕。
因为:
每个大模型工程师都经历过“完全看不懂源码”的阶段。
真正的成长方式不是:
一次看懂
而是:
反复运行
反复打印
反复调试
当你开始能看懂 tensor shape 的时候。
你已经正式进入深度学习工程世界了。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)