【DriveGen 文件详解】02——train.py
DriveGen/
├── configs/
│ └── default.yaml # 训练配置文件
├── DriveGen/
│ ├── __init__.py # 包初始化
│ ├── models/
│ │ ├── __init__.py
│ │ ├── embedding.py # Patch嵌入、时间步编码、位置编码
│ │ ├── attention.py # 空间注意力、时间注意力
│ │ ├── dit_block.py # AdaLN-Zero DiT Block
│ │ └── stdit.py # STDiT 完整模型
│ ├── data/
│ │ ├── __init__.py
│ │ └── dataset.py # 合成数据集 + nuScenes 适配器
│ ├── schedules/
│ │ ├── __init__.py
│ │ └── noise_schedule.py # 线性/余弦噪声调度
│ └── utils/
│ ├── __init__.py
│ ├── visualization.py # 视频保存、对比图、损失曲线
│ └── logger.py # 日志工具
├── train.py # 训练脚本
├── inference.py # 推理脚本(DDPM 采样 + CFG)
├── evaluate.py # 评估脚本(FID 计算)
├── requirements.txt # 依赖清单
├── setup.py # 安装配置
└── README.md # 本文件
LQY-hh/DriveGen-Transformer-: 自动驾驶技术的发展离不开海量数据的支撑,但稀有场景(如极端天气、突发事故)的数据采集成本极高。**DriveGen** 旨在通过扩散模型生成高质量的驾驶场景视频,为自动驾驶算法提供无限的虚拟训练数据。 ### 核心价值
https://github.com/LQY-hh/DriveGen-Transformer-DriveGen 训练脚本技术文档
1. 文件概述
train.py 是 DriveGen 项目的核心训练入口,实现了一个完整的条件视频扩散模型训练流程。该脚本基于扩散模型(Diffusion Model)原理,从驾驶场景的条件帧出发,学习预测未来帧序列,可应用于自动驾驶仿真、数据增强等场景。
核心功能
|
模块 |
功能描述 |
|---|---|
|
参数解析 |
支持命令行参数与 YAML 配置文件混合使用 |
|
噪声调度 |
管理扩散过程的加噪/去噪时序 |
|
模型构建 |
支持 STDiT(时空扩散Transformer)和 SimpleUNet |
|
训练循环 |
实现完整的扩散训练流程 |
|
检查点管理 |
支持保存/恢复训练进度 |
|
推理测试 |
训练结束后自动生成测试样本 |
2. 技术架构
2.1 整体架构图
┌─────────────────────────────────────────────────────────────────────────┐
│ train.py 训练流程 │
├─────────────────────────────────────────────────────────────────────────┤
│ 命令行参数 ──┐ │
│ │ │
│ YAML 配置 ──┼──→ 配置合并 ──→ 训练配置 │
│ │ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 初始化阶段 │ │
│ │ 随机种子 → 设备检测 → 日志设置 → 噪声调度器 → 数据加载器 │ │
│ │ → 模型创建 → 优化器 │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 训练循环 │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 前向扩散: 随机时间步 → 生成噪声 → 加噪到未来帧 │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ ↓ │ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 模型预测: noisy_future + t + condition → noise_pred │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ ↓ │ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 损失计算: MSE(noise_pred, noise) │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ ↓ │ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 梯度累积 → 梯度裁剪 → 优化器更新 │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ ↓ │ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 检查点保存 / 最佳模型保存 / 损失曲线绘制 │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 训练结束处理 │ │
│ │ 保存最终检查点 → 绘制损失曲线 → 快速推理测试 │ │
│ └─────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────┘
2.2 文件依赖关系
train.py
├── schedules/noise_schedule.py # 噪声调度器
│ └── NoiseScheduler 类
│ └── get_noise_schedule() 函数
├── data/dataset.py # 数据集模块
│ └── get_dataloader() 函数
│ └── SyntheticDrivingDataset 类
├── utils/logger.py # 日志工具
│ └── setup_logger() 函数
├── utils/visualization.py # 可视化工具
│ ├── save_video()
│ ├── save_comparison()
│ ├── plot_training_loss()
│ ├── save_frames_as_images()
│ └── create_gif()
└── DriveGen/models/ # 模型模块(可选)
└── STDiT 类
3. 核心函数详解
3.1 参数解析:parse_args()
文件位置:train.py#L47-L116
功能:解析命令行参数,支持灵活的训练配置。
|
参数 |
类型 |
默认值 |
说明 |
|---|---|---|---|
|
|
str |
|
配置文件路径 |
|
|
str |
|
从检查点恢复训练 |
|
|
int |
|
训练轮数(覆盖配置) |
|
|
float |
|
学习率(覆盖配置) |
|
|
int |
|
批量大小(覆盖配置) |
|
|
int |
|
随机种子 |
|
|
str |
|
计算设备 |
|
|
str |
|
输出目录 |
使用示例:
# 使用默认配置
python train.py
# 指定训练参数
python train.py --epochs 100 --lr 0.0005 --batch_size 8
# 从检查点恢复
python train.py --resume checkpoints/latest.pth
# 使用自定义配置文件
python train.py --config configs/custom.yaml
3.2 配置加载:load_config()
文件位置:train.py#L119-L140
功能:加载 YAML 配置文件并返回配置字典。
def load_config(config_path: str) -> Dict[str, Any]:
if not os.path.exists(config_path):
raise FileNotFoundError(f"配置文件不存在: {config_path}")
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
return config
配置文件结构(示例):
model:
hidden_dim: 192
depth: 6
num_heads: 6
patch_size: 4
num_frames: 4
condition_frames: 1
training:
num_epochs: 50
learning_rate: 0.0001
batch_size: 8
gradient_accumulation_steps: 4
noise:
num_timesteps: 1000
schedule: cosine
beta_start: 0.0001
beta_end: 0.02
3.3 设备管理:get_device()
文件位置:train.py#L166-L191
功能:自动检测并返回最佳计算设备。
设备优先级:
-
用户指定设备(
--device参数) -
CUDA(NVIDIA GPU)
-
MPS(Apple Silicon GPU)
-
CPU(兜底方案)
def get_device(device_str: Optional[str] = None) -> torch.device:
if device_str is not None:
return torch.device(device_str)
if torch.cuda.is_available():
device = torch.device('cuda')
print(f"[get_device] 使用 CUDA: {torch.cuda.get_device_name(0)}")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device = torch.device('mps')
print("[get_device] 使用 Apple MPS")
else:
device = torch.device('cpu')
return device
3.4 模型创建:create_model()
文件位置:train.py#L194-L236
功能:创建扩散模型,优先使用 STDiT,失败则回退到 SimpleUNet。
|
模型类型 |
输入格式 |
适用场景 |
|---|---|---|
|
STDiT |
|
正式训练,时空Transformer架构 |
|
SimpleUNet |
|
测试/验证,轻量级CNN |
模型配置参数:
|
参数 |
默认值 |
说明 |
|---|---|---|
|
|
192 |
隐藏层维度 |
|
|
6 |
Transformer层数 |
|
|
6 |
注意力头数 |
|
|
4 |
图像分块大小 |
|
|
4 |
总帧数(条件帧+未来帧) |
|
|
1 |
条件帧数 |
|
|
1000 |
最大时间步数 |
|
|
0.1 |
Dropout比例 |
3.5 SimpleUNet 模型实现
文件位置:train.py#L239-L372
架构设计:
输入: [noisy_frames (B,F,C,H,W) + condition_frames (B,1,C,H,W)]
↓
[展平 + 拼接] → (B, (F+1)*C, H, W)
↓
┌─────────────────────────────────────────────────────────┐
│ 编码器 │
│ enc1 → pool → enc2 → pool → enc3 → pool → bottleneck │
└─────────────────────────────────────────────────────────┘
↓
[时间步嵌入注入]
↓
┌─────────────────────────────────────────────────────────┐
│ 解码器(带跳跃连接) │
│ up3 → dec3 → up2 → dec2 → up1 → dec1 → output │
└─────────────────────────────────────────────────────────┘
↓
输出: noise_pred (B, F, C, H, W)
核心组件:
|
组件 |
功能 |
|---|---|
|
|
将时间步编码为嵌入向量 |
|
|
编码器卷积块 |
|
|
瓶颈层,融合时间信息 |
|
|
反卷积上采样 |
|
|
解码器卷积块(带跳跃连接) |
前向传播:
def forward(self, noisy_frames, t, condition_frames):
# 展平帧维度并拼接条件帧
x = noisy_frames.reshape(B, F * C, H, W)
cond = condition_frames.reshape(B, self.condition_frames * C, H, W)
x = torch.cat([x, cond], dim=1)
# 时间步嵌入
t_emb = self.time_mlp(t.float().unsqueeze(-1) / 1000.0)
# 编码器
e1 = self.enc1(x)
e2 = self.enc2(self.pool(e1))
e3 = self.enc3(self.pool(e2))
# 瓶颈层(注入时间信息)
b = self.bottleneck(self.pool(e3))
b = b + self.time_inject3(t_emb).unsqueeze(-1).unsqueeze(-1)
# 解码器(带跳跃连接)
d3 = self.up3(b)
d3 = torch.cat([d3, e3], dim=1) # 跳跃连接
# ... 后续解码层
# 输出预测噪声
out = self.output(d1)
out = out.reshape(B, F, C, H, W)
return out
3.6 检查点管理
3.6.1 保存检查点:save_checkpoint()
文件位置:train.py#L390-L426
检查点包含内容:
|
字段 |
类型 |
说明 |
|---|---|---|
|
|
int |
当前训练轮数 |
|
|
int |
全局训练步数 |
|
|
float |
最佳损失值 |
|
|
list |
损失历史记录 |
|
|
dict |
训练配置 |
|
|
dict |
模型权重 |
|
|
dict |
优化器状态 |
3.6.2 加载检查点:load_checkpoint()
文件位置:train.py#L429-L462
功能:从检查点恢复模型和优化器状态,支持断点续训。
3.7 快速推理测试:quick_inference_test()
文件位置:train.py#L466-L590
功能:训练结束后自动执行推理测试,生成样本视频验证模型效果。
采样流程:
# 1. 创建条件帧(渐变图案测试)
condition = torch.zeros(1, condition_frames, in_channels, image_size, image_size)
# 2. 从纯噪声开始
x = torch.randn(1, future_frames, in_channels, image_size, image_size)
# 3. DDPM 采样循环
for i in range(len(step_indices) - 1, -1, -1):
t = step_indices[i]
# 预测噪声
noise_pred = model(x, t, condition)
# DDPM 更新公式
# x_{t-1} = (1/sqrt(alpha_t)) * (x_t - noise_pred * (1-alpha_t)/sqrt(1-alpha_bar_t))
# + sigma_t * z
x = x_0_pred if i == 0 else torch.sqrt(alpha_t_prev) * x_0_pred + sigma_t * noise
输出格式:
|
格式 |
保存路径 |
说明 |
|---|---|---|
|
MP4 视频 |
|
生成的视频 |
|
PNG 帧 |
|
单独帧图片 |
|
GIF 动图 |
|
动画预览 |
3.8 主训练函数:train()
文件位置:train.py#L593-L851
训练流程详解:
阶段一:初始化
# 设置随机种子
torch.manual_seed(seed)
np.random.seed(seed)
# 获取计算设备
device = get_device(args.device)
# 创建噪声调度器
noise_scheduler = NoiseScheduler(
num_timesteps=noise_config['num_timesteps'],
schedule_type=noise_config['schedule'],
)
# 创建数据加载器
dataloader = get_dataloader(config)
# 创建模型
model, use_stdit = create_model(config, device)
# 创建优化器
optimizer = AdamW(
model.parameters(),
lr=train_config['learning_rate'],
betas=(0.9, 0.999),
weight_decay=0.01,
)
阶段二:恢复训练
resume_path = args.resume or train_config.get('resume')
if resume_path:
checkpoint = load_checkpoint(resume_path, model, optimizer)
start_epoch = checkpoint['epoch'] + 1
global_step = checkpoint['step']
best_loss = checkpoint['best_loss']
losses_history = checkpoint.get('losses', [])
阶段三:训练循环
单次迭代流程:
for batch in dataloader:
# 1. 数据预处理
condition_frames = batch['condition_frames'].to(device)
future_frames = batch['future_frames'].to(device)
# 归一化到 [-1, 1]
condition_frames = condition_frames * 2.0 - 1.0
future_frames = future_frames * 2.0 - 1.0
# 2. 前向扩散(加噪)
t = torch.randint(0, max_timesteps, (B,), device=device)
noise = torch.randn_like(future_frames)
noisy_future = noise_scheduler.add_noise(future_frames, t, noise)
# 3. 模型预测
if use_stdit:
noise_pred = model(noisy_future.permute(0,2,1,3,4), t, condition_frames.permute(0,2,1,3,4))
noise_pred = noise_pred.permute(0,2,1,3,4)
else:
noise_pred = model(noisy_future, t, condition_frames)
# 4. 损失计算
loss = nn.functional.mse_loss(noise_pred, noise)
# 5. 梯度累积与更新
loss = loss / grad_accum_steps
loss.backward()
if (batch_idx + 1) % grad_accum_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
optimizer.zero_grad()
global_step += 1
阶段四:检查点保存
# 定期保存检查点
if (epoch + 1) % save_every == 0:
save_checkpoint(model, optimizer, epoch, global_step, best_loss, losses_history, config, 'latest.pth')
# 保存最佳模型
if avg_epoch_loss < best_loss:
best_loss = avg_epoch_loss
save_checkpoint(model, optimizer, epoch, global_step, best_loss, losses_history, config, 'best.pth')
阶段五:训练结束处理
# 保存最终检查点
save_checkpoint(model, optimizer, num_epochs - 1, global_step, best_loss, losses_history, config, 'final.pth')
# 绘制损失曲线
plot_training_loss(losses_history, 'loss_curve_final.png')
# 快速推理测试
quick_inference_test(model, noise_scheduler, config, device, 'outputs', logger)
3.9 主入口:main()
文件位置:train.py#L855-L883
执行流程:
-
解析命令行参数
-
加载 YAML 配置文件
-
合并命令行参数与配置(命令行参数优先)
-
启动训练
-
处理异常(KeyboardInterrupt、其他异常)
4. 扩散模型核心原理
4.1 前向扩散过程
扩散模型通过逐步向数据添加噪声来学习数据分布。对于视频生成任务:
x_0 → x_1 → x_2 → ... → x_T
其中 x_0 是原始视频帧,x_T 是纯高斯噪声。
加噪公式:
x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * epsilon
-
alpha_cumprod_t:时间步 t 的累积噪声系数 -
epsilon:标准高斯噪声
4.2 反向去噪过程
训练时,模型学习预测噪声:
model(x_t, t, condition) → epsilon_pred
损失函数:
loss = MSE(epsilon_pred, epsilon)
4.3 采样过程
推理时,从纯噪声开始逐步去噪:
for t in range(T, 0, -1):
epsilon_pred = model(x_t, t, condition)
x_{t-1} = (x_t - (1-alpha_t)/sqrt(1-alpha_cumprod_t) * epsilon_pred) / sqrt(alpha_t)
+ sigma_t * noise # 仅在 t > 0 时添加噪声
5. 关键技术特性
5.1 梯度累积
作用:在显存有限的情况下,模拟更大的批量大小。
grad_accum_steps = train_config.get('gradient_accumulation_steps', 1)
loss = loss / grad_accum_steps # 损失除以累积步数
loss.backward()
if (batch_idx + 1) % grad_accum_steps == 0:
optimizer.step()
optimizer.zero_grad()
等效批量大小:batch_size * grad_accum_steps
5.2 梯度裁剪
作用:防止梯度爆炸,稳定训练。
grad_clip = train_config.get('grad_clip', 1.0)
if grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
5.3 噪声调度策略
支持的调度类型:
|
类型 |
特点 |
适用场景 |
|---|---|---|
|
|
线性增加噪声 |
简单场景 |
|
|
余弦调度,前期噪声增长慢 |
复杂数据分布 |
|
|
Sigmoid 调度 |
平滑过渡 |
6. 输出文件结构
outputs/
├── checkpoints/
│ ├── latest.pth # 最新检查点
│ ├── best.pth # 最佳模型
│ └── final.pth # 最终检查点
├── logs/
│ ├── train.log # 训练日志
│ ├── loss_curve.png # 损失曲线
│ └── loss_curve_final.png
└── test_samples/
├── sample_0/
│ ├── generated.mp4
│ ├── generated.gif
│ └── frame_00.png, frame_01.png, ...
└── sample_1/
├── generated.mp4
├── generated.gif
└── frame_00.png, frame_01.png, ...
7. 性能优化建议
7.1 训练效率优化
|
优化项 |
建议值 |
说明 |
|---|---|---|
|
批量大小 |
8-32 |
根据显存调整 |
|
梯度累积 |
4-8 |
模拟大批次 |
|
混合精度 |
FP16/FP8 |
加速训练,减少显存 |
|
数据预加载 |
开启 |
使用 |
7.2 超参数建议
|
参数 |
建议范围 |
说明 |
|---|---|---|
|
学习率 |
1e-4 ~ 5e-4 |
AdamW 优化器 |
|
权重衰减 |
1e-2 ~ 1e-4 |
防止过拟合 |
|
时间步数 |
1000 |
标准配置 |
|
beta_start |
1e-4 |
初始噪声系数 |
|
beta_end |
0.02 |
终止噪声系数 |
8. 常见问题
8.1 STDiT 导入失败
问题:运行时提示 ModuleNotFoundError: No module named 'DriveGen.models'
解决方案:
-
确保
DriveGen/models/目录存在且包含__init__.py和 STDiT 实现 -
脚本会自动回退到 SimpleUNet,可用于测试
8.2 显存不足
解决方案:
-
减小
batch_size -
增加
gradient_accumulation_steps -
使用更小的模型配置(
hidden_dim,depth) -
启用混合精度训练
8.3 训练损失不下降
排查步骤:
-
检查数据加载是否正确(帧顺序、归一化范围)
-
确认噪声调度器配置正确
-
检查学习率是否合适
-
验证模型输出形状是否与标签匹配
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)