【DriveGen 文件详解】01——default.yaml
DriveGen/
├── configs/
│ └── default.yaml # 训练配置文件
├── DriveGen/
│ ├── __init__.py # 包初始化
│ ├── models/
│ │ ├── __init__.py
│ │ ├── embedding.py # Patch嵌入、时间步编码、位置编码
│ │ ├── attention.py # 空间注意力、时间注意力
│ │ ├── dit_block.py # AdaLN-Zero DiT Block
│ │ └── stdit.py # STDiT 完整模型
│ ├── data/
│ │ ├── __init__.py
│ │ └── dataset.py # 合成数据集 + nuScenes 适配器
│ ├── schedules/
│ │ ├── __init__.py
│ │ └── noise_schedule.py # 线性/余弦噪声调度
│ └── utils/
│ ├── __init__.py
│ ├── visualization.py # 视频保存、对比图、损失曲线
│ └── logger.py # 日志工具
├── train.py # 训练脚本
├── inference.py # 推理脚本(DDPM 采样 + CFG)
├── evaluate.py # 评估脚本(FID 计算)
├── requirements.txt # 依赖清单
├── setup.py # 安装配置
└── README.md # 本文件
LQY-hh/DriveGen-Transformer-: 自动驾驶技术的发展离不开海量数据的支撑,但稀有场景(如极端天气、突发事故)的数据采集成本极高。**DriveGen** 旨在通过扩散模型生成高质量的驾驶场景视频,为自动驾驶算法提供无限的虚拟训练数据。 ### 核心价值
https://github.com/LQY-hh/DriveGen-Transformer-
DriveGen 配置文件详解
本文档详细解读 default.yaml 配置文件的各项参数。
概述
该配置文件用于 DriveGen 项目,这是一个基于 扩散模型(Diffusion Model) 和 Transformer 架构的自动驾驶场景视频帧生成器。
核心功能:输入1帧条件图像,生成后续3帧连续图像,模拟自动驾驶场景的帧序列生成。
配置结构总览
|
配置模块 |
功能说明 |
|---|---|
|
|
模型架构参数 |
|
|
数据集配置 |
|
|
训练超参数 |
|
|
扩散噪声调度 |
|
|
推理生成配置 |
|
|
模型评估配置 |
|
|
日志记录配置 |
1. Model(模型参数)
model:
hidden_dim: 192
depth: 6
num_heads: 6
patch_size: 4
mlp_ratio: 4.0
in_channels: 3
out_channels: 3
num_frames: 4
condition_frames: 1
max_timestep: 1000
dropout: 0.1
|
参数 |
值 |
说明 |
|---|---|---|
|
|
192 |
Transformer 隐藏层特征维度 |
|
|
6 |
Transformer 编码器/解码器层数 |
|
|
6 |
多头注意力机制的头数 |
|
|
4 |
图像分块大小(4×4 patch) |
|
|
4.0 |
MLP 层扩展比例 |
|
|
3 |
输入通道数(RGB) |
|
|
3 |
输出通道数(RGB) |
|
|
4 |
总帧数(1条件帧 + 3生成帧) |
|
|
1 |
条件帧数量 |
|
|
1000 |
最大扩散时间步 |
|
|
0.1 |
Dropout 比例(防止过拟合) |
2. Data(数据配置)
data:
dataset_type: "synthetic"
image_size: 64
num_frames: 4
synthetic_samples: 500
synthetic_dir: "data/synthetic"
nuscenes_dir: ""
nuscenes_version: "v1.0-mini"
camera: "CAM_FRONT"
num_workers: 0
batch_size: 2
|
参数 |
值 |
说明 |
|---|---|---|
|
|
|
数据集类型选择 |
|
|
64 |
图像尺寸(64×64像素) |
|
|
4 |
每个样本帧数 |
|
|
500 |
合成数据样本数 |
|
|
|
合成数据目录 |
|
|
|
NuScenes 数据集路径(为空则不使用) |
|
|
|
NuScenes 版本 |
|
|
|
相机视角(前置相机) |
|
|
0 |
数据加载线程数(0为主线程) |
|
|
2 |
批量大小 |
3. Training(训练配置)
training:
num_epochs: 50
learning_rate: 1.0e-4
weight_decay: 0.01
beta1: 0.9
beta2: 0.999
grad_clip: 1.0
gradient_accumulation_steps: 4
save_every: 10
log_every: 10
checkpoint_dir: "checkpoints"
resume: null
|
参数 |
值 |
说明 |
|---|---|---|
|
|
50 |
训练轮数 |
|
|
1e-4 |
学习率 |
|
|
0.01 |
权重衰减(L2正则化) |
|
|
0.9 |
Adam 优化器参数 |
|
|
0.999 |
Adam 优化器参数 |
|
|
1.0 |
梯度裁剪阈值 |
|
|
4 |
梯度累积步数(等效 batch_size = 2×4 = 8) |
|
|
10 |
每10个epoch保存模型 |
|
|
10 |
每10个step记录日志 |
|
|
|
模型保存目录 |
|
|
|
恢复训练的checkpoint路径 |
4. Noise Schedule(噪声调度)
noise:
schedule: "linear"
beta_start: 0.0001
beta_end: 0.02
num_timesteps: 1000
|
参数 |
值 |
说明 |
|---|---|---|
|
|
|
噪声调度类型 |
|
|
0.0001 |
噪声起始值 |
|
|
0.02 |
噪声终止值 |
|
|
1000 |
扩散步骤数 |
扩散模型原理:通过逐步向数据添加噪声,然后学习逆向过程来生成数据。线性调度表示噪声强度从 beta_start 线性增长到 beta_end。
5. Inference(推理配置)
inference:
num_inference_steps: 100
cfg_scale: 3.0
output_dir: "outputs"
seed: 42
|
参数 |
值 |
说明 |
|---|---|---|
|
|
100 |
推理扩散步数(少于训练,加速生成) |
|
|
3.0 |
分类器自由引导系数(控制生成多样性) |
|
|
|
生成结果输出目录 |
|
|
42 |
随机种子(保证结果可复现) |
6. Evaluation(评估配置)
evaluation:
num_samples: 100
batch_size: 8
fid_batch_size: 32
|
参数 |
值 |
说明 |
|---|---|---|
|
|
100 |
评估样本数 |
|
|
8 |
评估批量大小 |
|
|
32 |
FID 指标计算批量 |
7. Logging(日志配置)
logging:
log_dir: "logs"
use_wandb: false
|
参数 |
值 |
说明 |
|---|---|---|
|
|
|
日志保存目录 |
|
|
|
是否启用 Weights & Biases 可视化 |
工作流程示意
输入条件帧 → [Transformer + Diffusion] → 生成后续3帧 → 输出视频序列
↓
训练阶段:1000步扩散
↓
推理阶段:100步快速生成
配置文件完整路径
-
配置文件:
configs/default.yaml
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)