DriveGen/
├── configs/
│   └── default.yaml              # 训练配置文件
├── DriveGen/
│   ├── __init__.py               # 包初始化
│   ├── models/
│   │   ├── __init__.py
│   │   ├── embedding.py          # Patch嵌入、时间步编码、位置编码
│   │   ├── attention.py          # 空间注意力、时间注意力
│   │   ├── dit_block.py          # AdaLN-Zero DiT Block
│   │   └── stdit.py              # STDiT 完整模型
│   ├── data/
│   │   ├── __init__.py
│   │   └── dataset.py            # 合成数据集 + nuScenes 适配器
│   ├── schedules/
│   │   ├── __init__.py
│   │   └── noise_schedule.py     # 线性/余弦噪声调度
│   └── utils/
│       ├── __init__.py
│       ├── visualization.py      # 视频保存、对比图、损失曲线
│       └── logger.py             # 日志工具
├── train.py                      # 训练脚本
├── inference.py                  # 推理脚本(DDPM 采样 + CFG)
├── evaluate.py                   # 评估脚本(FID 计算)
├── requirements.txt              # 依赖清单
├── setup.py                      # 安装配置
└── README.md                     # 本文件

LQY-hh/DriveGen-Transformer-: 自动驾驶技术的发展离不开海量数据的支撑,但稀有场景(如极端天气、突发事故)的数据采集成本极高。**DriveGen** 旨在通过扩散模型生成高质量的驾驶场景视频,为自动驾驶算法提供无限的虚拟训练数据。 ### 核心价值https://github.com/LQY-hh/DriveGen-Transformer-DriveGen 训练脚本技术文档

1. 文件概述

train.pyDriveGen 项目的核心训练入口,实现了一个完整的条件视频扩散模型训练流程。该脚本基于扩散模型(Diffusion Model)原理,从驾驶场景的条件帧出发,学习预测未来帧序列,可应用于自动驾驶仿真、数据增强等场景。

核心功能

模块

功能描述

参数解析

支持命令行参数与 YAML 配置文件混合使用

噪声调度

管理扩散过程的加噪/去噪时序

模型构建

支持 STDiT(时空扩散Transformer)和 SimpleUNet

训练循环

实现完整的扩散训练流程

检查点管理

支持保存/恢复训练进度

推理测试

训练结束后自动生成测试样本


2. 技术架构

2.1 整体架构图

┌─────────────────────────────────────────────────────────────────────────┐
│                         train.py 训练流程                               │
├─────────────────────────────────────────────────────────────────────────┤
│  命令行参数 ──┐                                                         │
│              │                                                         │
│  YAML 配置 ──┼──→ 配置合并 ──→ 训练配置                                 │
│              │                                                         │
│              ↓                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                       初始化阶段                                │    │
│  │  随机种子 → 设备检测 → 日志设置 → 噪声调度器 → 数据加载器        │    │
│  │                          → 模型创建 → 优化器                   │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                              ↓                                         │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                       训练循环                                  │    │
│  │  ┌─────────────────────────────────────────────────────────┐   │    │
│  │  │ 前向扩散: 随机时间步 → 生成噪声 → 加噪到未来帧            │   │    │
│  │  └─────────────────────────────────────────────────────────┘   │    │
│  │                              ↓                                 │   │    │
│  │  ┌─────────────────────────────────────────────────────────┐   │    │
│  │  │ 模型预测: noisy_future + t + condition → noise_pred     │   │    │
│  │  └─────────────────────────────────────────────────────────┘   │    │
│  │                              ↓                                 │   │    │
│  │  ┌─────────────────────────────────────────────────────────┐   │    │
│  │  │ 损失计算: MSE(noise_pred, noise)                        │   │    │
│  │  └─────────────────────────────────────────────────────────┘   │    │
│  │                              ↓                                 │   │    │
│  │  ┌─────────────────────────────────────────────────────────┐   │    │
│  │  │ 梯度累积 → 梯度裁剪 → 优化器更新                         │   │    │
│  │  └─────────────────────────────────────────────────────────┘   │    │
│  │                              ↓                                 │   │    │
│  │  ┌─────────────────────────────────────────────────────────┐   │    │
│  │  │ 检查点保存 / 最佳模型保存 / 损失曲线绘制                 │   │    │
│  │  └─────────────────────────────────────────────────────────┘   │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                              ↓                                         │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                    训练结束处理                                  │    │
│  │  保存最终检查点 → 绘制损失曲线 → 快速推理测试                    │    │
│  └─────────────────────────────────────────────────────────────────┘    │
└─────────────────────────────────────────────────────────────────────────┘

2.2 文件依赖关系

train.py
├── schedules/noise_schedule.py     # 噪声调度器
│   └── NoiseScheduler 类
│   └── get_noise_schedule() 函数
├── data/dataset.py                 # 数据集模块
│   └── get_dataloader() 函数
│   └── SyntheticDrivingDataset 类
├── utils/logger.py                 # 日志工具
│   └── setup_logger() 函数
├── utils/visualization.py          # 可视化工具
│   ├── save_video()
│   ├── save_comparison()
│   ├── plot_training_loss()
│   ├── save_frames_as_images()
│   └── create_gif()
└── DriveGen/models/                # 模型模块(可选)
    └── STDiT 类

3. 核心函数详解

3.1 参数解析:parse_args()

文件位置:train.py#L47-L116

功能:解析命令行参数,支持灵活的训练配置。

参数

类型

默认值

说明

--config

str

configs/default.yaml

配置文件路径

--resume

str

None

从检查点恢复训练

--epochs

int

None

训练轮数(覆盖配置)

--lr

float

None

学习率(覆盖配置)

--batch_size

int

None

批量大小(覆盖配置)

--seed

int

42

随机种子

--device

str

None

计算设备

--output_dir

str

None

输出目录

使用示例

# 使用默认配置
python train.py

# 指定训练参数
python train.py --epochs 100 --lr 0.0005 --batch_size 8

# 从检查点恢复
python train.py --resume checkpoints/latest.pth

# 使用自定义配置文件
python train.py --config configs/custom.yaml

3.2 配置加载:load_config()

文件位置:train.py#L119-L140

功能:加载 YAML 配置文件并返回配置字典。

def load_config(config_path: str) -> Dict[str, Any]:
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件不存在: {config_path}")
    
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    
    return config

配置文件结构(示例):

model:
  hidden_dim: 192
  depth: 6
  num_heads: 6
  patch_size: 4
  num_frames: 4
  condition_frames: 1

training:
  num_epochs: 50
  learning_rate: 0.0001
  batch_size: 8
  gradient_accumulation_steps: 4

noise:
  num_timesteps: 1000
  schedule: cosine
  beta_start: 0.0001
  beta_end: 0.02

3.3 设备管理:get_device()

文件位置:train.py#L166-L191

功能:自动检测并返回最佳计算设备。

设备优先级

  1. 用户指定设备(--device 参数)

  2. CUDA(NVIDIA GPU)

  3. MPS(Apple Silicon GPU)

  4. CPU(兜底方案)

def get_device(device_str: Optional[str] = None) -> torch.device:
    if device_str is not None:
        return torch.device(device_str)
    
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"[get_device] 使用 CUDA: {torch.cuda.get_device_name(0)}")
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
        print("[get_device] 使用 Apple MPS")
    else:
        device = torch.device('cpu')
    
    return device

3.4 模型创建:create_model()

文件位置:train.py#L194-L236

功能:创建扩散模型,优先使用 STDiT,失败则回退到 SimpleUNet。

模型类型

输入格式

适用场景

STDiT

(B, C, T, H, W)

正式训练,时空Transformer架构

SimpleUNet

(B, T, C, H, W)

测试/验证,轻量级CNN

模型配置参数

参数

默认值

说明

hidden_dim

192

隐藏层维度

depth

6

Transformer层数

num_heads

6

注意力头数

patch_size

4

图像分块大小

num_frames

4

总帧数(条件帧+未来帧)

condition_frames

1

条件帧数

max_timestep

1000

最大时间步数

dropout

0.1

Dropout比例


3.5 SimpleUNet 模型实现

文件位置:train.py#L239-L372

架构设计

输入: [noisy_frames (B,F,C,H,W) + condition_frames (B,1,C,H,W)]
         ↓
    [展平 + 拼接] → (B, (F+1)*C, H, W)
         ↓
    ┌─────────────────────────────────────────────────────────┐
    │                    编码器                               │
    │  enc1 → pool → enc2 → pool → enc3 → pool → bottleneck │
    └─────────────────────────────────────────────────────────┘
         ↓
    [时间步嵌入注入]
         ↓
    ┌─────────────────────────────────────────────────────────┐
    │                    解码器(带跳跃连接)                  │
    │  up3 → dec3 → up2 → dec2 → up1 → dec1 → output        │
    └─────────────────────────────────────────────────────────┘
         ↓
    输出: noise_pred (B, F, C, H, W)

核心组件

组件

功能

time_mlp

将时间步编码为嵌入向量

enc1/enc2/enc3

编码器卷积块

bottleneck

瓶颈层,融合时间信息

up1/up2/up3

反卷积上采样

dec1/dec2/dec3

解码器卷积块(带跳跃连接)

前向传播

def forward(self, noisy_frames, t, condition_frames):
    # 展平帧维度并拼接条件帧
    x = noisy_frames.reshape(B, F * C, H, W)
    cond = condition_frames.reshape(B, self.condition_frames * C, H, W)
    x = torch.cat([x, cond], dim=1)
    
    # 时间步嵌入
    t_emb = self.time_mlp(t.float().unsqueeze(-1) / 1000.0)
    
    # 编码器
    e1 = self.enc1(x)
    e2 = self.enc2(self.pool(e1))
    e3 = self.enc3(self.pool(e2))
    
    # 瓶颈层(注入时间信息)
    b = self.bottleneck(self.pool(e3))
    b = b + self.time_inject3(t_emb).unsqueeze(-1).unsqueeze(-1)
    
    # 解码器(带跳跃连接)
    d3 = self.up3(b)
    d3 = torch.cat([d3, e3], dim=1)  # 跳跃连接
    # ... 后续解码层
    
    # 输出预测噪声
    out = self.output(d1)
    out = out.reshape(B, F, C, H, W)
    
    return out

3.6 检查点管理

3.6.1 保存检查点:save_checkpoint()

文件位置:train.py#L390-L426

检查点包含内容

字段

类型

说明

epoch

int

当前训练轮数

step

int

全局训练步数

best_loss

float

最佳损失值

losses

list

损失历史记录

config

dict

训练配置

model_state_dict

dict

模型权重

optimizer_state_dict

dict

优化器状态

3.6.2 加载检查点:load_checkpoint()

文件位置:train.py#L429-L462

功能:从检查点恢复模型和优化器状态,支持断点续训。


3.7 快速推理测试:quick_inference_test()

文件位置:train.py#L466-L590

功能:训练结束后自动执行推理测试,生成样本视频验证模型效果。

采样流程

# 1. 创建条件帧(渐变图案测试)
condition = torch.zeros(1, condition_frames, in_channels, image_size, image_size)

# 2. 从纯噪声开始
x = torch.randn(1, future_frames, in_channels, image_size, image_size)

# 3. DDPM 采样循环
for i in range(len(step_indices) - 1, -1, -1):
    t = step_indices[i]
    
    # 预测噪声
    noise_pred = model(x, t, condition)
    
    # DDPM 更新公式
    # x_{t-1} = (1/sqrt(alpha_t)) * (x_t - noise_pred * (1-alpha_t)/sqrt(1-alpha_bar_t))
    #          + sigma_t * z
    x = x_0_pred if i == 0 else torch.sqrt(alpha_t_prev) * x_0_pred + sigma_t * noise

输出格式

格式

保存路径

说明

MP4 视频

outputs/test_samples/sample_X/generated.mp4

生成的视频

PNG 帧

outputs/test_samples/sample_X/frame_XX.png

单独帧图片

GIF 动图

outputs/test_samples/sample_X/generated.gif

动画预览


3.8 主训练函数:train()

文件位置:train.py#L593-L851

训练流程详解

阶段一:初始化
# 设置随机种子
torch.manual_seed(seed)
np.random.seed(seed)

# 获取计算设备
device = get_device(args.device)

# 创建噪声调度器
noise_scheduler = NoiseScheduler(
    num_timesteps=noise_config['num_timesteps'],
    schedule_type=noise_config['schedule'],
)

# 创建数据加载器
dataloader = get_dataloader(config)

# 创建模型
model, use_stdit = create_model(config, device)

# 创建优化器
optimizer = AdamW(
    model.parameters(),
    lr=train_config['learning_rate'],
    betas=(0.9, 0.999),
    weight_decay=0.01,
)
阶段二:恢复训练
resume_path = args.resume or train_config.get('resume')
if resume_path:
    checkpoint = load_checkpoint(resume_path, model, optimizer)
    start_epoch = checkpoint['epoch'] + 1
    global_step = checkpoint['step']
    best_loss = checkpoint['best_loss']
    losses_history = checkpoint.get('losses', [])
阶段三:训练循环

单次迭代流程

for batch in dataloader:
    # 1. 数据预处理
    condition_frames = batch['condition_frames'].to(device)
    future_frames = batch['future_frames'].to(device)
    # 归一化到 [-1, 1]
    condition_frames = condition_frames * 2.0 - 1.0
    future_frames = future_frames * 2.0 - 1.0
    
    # 2. 前向扩散(加噪)
    t = torch.randint(0, max_timesteps, (B,), device=device)
    noise = torch.randn_like(future_frames)
    noisy_future = noise_scheduler.add_noise(future_frames, t, noise)
    
    # 3. 模型预测
    if use_stdit:
        noise_pred = model(noisy_future.permute(0,2,1,3,4), t, condition_frames.permute(0,2,1,3,4))
        noise_pred = noise_pred.permute(0,2,1,3,4)
    else:
        noise_pred = model(noisy_future, t, condition_frames)
    
    # 4. 损失计算
    loss = nn.functional.mse_loss(noise_pred, noise)
    
    # 5. 梯度累积与更新
    loss = loss / grad_accum_steps
    loss.backward()
    
    if (batch_idx + 1) % grad_accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1
阶段四:检查点保存
# 定期保存检查点
if (epoch + 1) % save_every == 0:
    save_checkpoint(model, optimizer, epoch, global_step, best_loss, losses_history, config, 'latest.pth')

# 保存最佳模型
if avg_epoch_loss < best_loss:
    best_loss = avg_epoch_loss
    save_checkpoint(model, optimizer, epoch, global_step, best_loss, losses_history, config, 'best.pth')
阶段五:训练结束处理
# 保存最终检查点
save_checkpoint(model, optimizer, num_epochs - 1, global_step, best_loss, losses_history, config, 'final.pth')

# 绘制损失曲线
plot_training_loss(losses_history, 'loss_curve_final.png')

# 快速推理测试
quick_inference_test(model, noise_scheduler, config, device, 'outputs', logger)

3.9 主入口:main()

文件位置:train.py#L855-L883

执行流程

  1. 解析命令行参数

  2. 加载 YAML 配置文件

  3. 合并命令行参数与配置(命令行参数优先)

  4. 启动训练

  5. 处理异常(KeyboardInterrupt、其他异常)


4. 扩散模型核心原理

4.1 前向扩散过程

扩散模型通过逐步向数据添加噪声来学习数据分布。对于视频生成任务:

x_0 → x_1 → x_2 → ... → x_T

其中 x_0 是原始视频帧,x_T 是纯高斯噪声。

加噪公式

x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * epsilon
  • alpha_cumprod_t:时间步 t 的累积噪声系数

  • epsilon:标准高斯噪声

4.2 反向去噪过程

训练时,模型学习预测噪声:

model(x_t, t, condition) → epsilon_pred

损失函数

loss = MSE(epsilon_pred, epsilon)

4.3 采样过程

推理时,从纯噪声开始逐步去噪:

for t in range(T, 0, -1):
    epsilon_pred = model(x_t, t, condition)
    x_{t-1} = (x_t - (1-alpha_t)/sqrt(1-alpha_cumprod_t) * epsilon_pred) / sqrt(alpha_t)
              + sigma_t * noise  # 仅在 t > 0 时添加噪声

5. 关键技术特性

5.1 梯度累积

作用:在显存有限的情况下,模拟更大的批量大小。

grad_accum_steps = train_config.get('gradient_accumulation_steps', 1)
loss = loss / grad_accum_steps  # 损失除以累积步数
loss.backward()

if (batch_idx + 1) % grad_accum_steps == 0:
    optimizer.step()
    optimizer.zero_grad()

等效批量大小batch_size * grad_accum_steps

5.2 梯度裁剪

作用:防止梯度爆炸,稳定训练。

grad_clip = train_config.get('grad_clip', 1.0)
if grad_clip > 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

5.3 噪声调度策略

支持的调度类型

类型

特点

适用场景

linear

线性增加噪声

简单场景

cosine

余弦调度,前期噪声增长慢

复杂数据分布

sigmoid

Sigmoid 调度

平滑过渡


6. 输出文件结构

outputs/
├── checkpoints/
│   ├── latest.pth        # 最新检查点
│   ├── best.pth          # 最佳模型
│   └── final.pth         # 最终检查点
├── logs/
│   ├── train.log         # 训练日志
│   ├── loss_curve.png    # 损失曲线
│   └── loss_curve_final.png
└── test_samples/
    ├── sample_0/
    │   ├── generated.mp4
    │   ├── generated.gif
    │   └── frame_00.png, frame_01.png, ...
    └── sample_1/
        ├── generated.mp4
        ├── generated.gif
        └── frame_00.png, frame_01.png, ...

7. 性能优化建议

7.1 训练效率优化

优化项

建议值

说明

批量大小

8-32

根据显存调整

梯度累积

4-8

模拟大批次

混合精度

FP16/FP8

加速训练,减少显存

数据预加载

开启

使用 pin_memory=True

7.2 超参数建议

参数

建议范围

说明

学习率

1e-4 ~ 5e-4

AdamW 优化器

权重衰减

1e-2 ~ 1e-4

防止过拟合

时间步数

1000

标准配置

beta_start

1e-4

初始噪声系数

beta_end

0.02

终止噪声系数


8. 常见问题

8.1 STDiT 导入失败

问题:运行时提示 ModuleNotFoundError: No module named 'DriveGen.models'

解决方案

  1. 确保 DriveGen/models/ 目录存在且包含 __init__.py 和 STDiT 实现

  2. 脚本会自动回退到 SimpleUNet,可用于测试

8.2 显存不足

解决方案

  1. 减小 batch_size

  2. 增加 gradient_accumulation_steps

  3. 使用更小的模型配置(hidden_dim, depth

  4. 启用混合精度训练

8.3 训练损失不下降

排查步骤

  1. 检查数据加载是否正确(帧顺序、归一化范围)

  2. 确认噪声调度器配置正确

  3. 检查学习率是否合适

  4. 验证模型输出形状是否与标签匹配

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐