DriveGen/
├── configs/
│   └── default.yaml              # 训练配置文件
├── DriveGen/
│   ├── __init__.py               # 包初始化
│   ├── models/
│   │   ├── __init__.py
│   │   ├── embedding.py          # Patch嵌入、时间步编码、位置编码
│   │   ├── attention.py          # 空间注意力、时间注意力
│   │   ├── dit_block.py          # AdaLN-Zero DiT Block
│   │   └── stdit.py              # STDiT 完整模型
│   ├── data/
│   │   ├── __init__.py
│   │   └── dataset.py            # 合成数据集 + nuScenes 适配器
│   ├── schedules/
│   │   ├── __init__.py
│   │   └── noise_schedule.py     # 线性/余弦噪声调度
│   └── utils/
│       ├── __init__.py
│       ├── visualization.py      # 视频保存、对比图、损失曲线
│       └── logger.py             # 日志工具
├── train.py                      # 训练脚本
├── inference.py                  # 推理脚本(DDPM 采样 + CFG)
├── evaluate.py                   # 评估脚本(FID 计算)
├── requirements.txt              # 依赖清单
├── setup.py                      # 安装配置
└── README.md                     # 本文件

LQY-hh/DriveGen-Transformer-: 自动驾驶技术的发展离不开海量数据的支撑,但稀有场景(如极端天气、突发事故)的数据采集成本极高。**DriveGen** 旨在通过扩散模型生成高质量的驾驶场景视频,为自动驾驶算法提供无限的虚拟训练数据。 ### 核心价值https://github.com/LQY-hh/DriveGen-Transformer-

概述

inference.py 是 DriveGen 项目的推理入口文件,负责加载训练好的扩散模型(Diffusion Model),执行 DDPM(Denoising Diffusion Probabilistic Models)采样过程,从条件帧生成后续的未来视频帧。

核心功能

功能模块

说明

实现位置

DDPM 迭代去噪采样

从纯噪声开始逐步去噪生成视频帧

ddpm_sample

Classifier-Free Guidance (CFG)

通过对比有条件/无条件预测提高生成质量

ddpm_sample

条件帧生成

支持从数据集采样或生成合成条件帧

generate_condition_frames

多格式输出

MP4 视频、单帧图片、GIF 动画、对比图

main


执行流程

┌─────────────────────────────────────────────────────────────────┐
│                      推理执行流程                                │
├─────────────────────────────────────────────────────────────────┤
│  1. 解析命令行参数                                               │
│         ↓                                                      │
│  2. 加载配置文件(YAML)                                         │
│         ↓                                                      │
│  3. 初始化计算设备(CUDA/MPS/CPU)                               │
│         ↓                                                      │
│  4. 创建噪声调度器                                               │
│         ↓                                                      │
│  5. 加载预训练模型(STDiT / SimpleUNet)                         │
│         ↓                                                      │
│  6. 生成条件帧(起始帧)                                         │
│         ↓                                                      │
│  7. 执行 DDPM 采样(迭代去噪)                                   │
│         ↓                                                      │
│  8. 拼接条件帧+生成帧,保存输出                                   │
└─────────────────────────────────────────────────────────────────┘

模块详解

1. 参数解析 (parse_args)

解析命令行参数,支持灵活配置推理行为:

参数

类型

默认值

说明

--checkpoint, -c

str

必需

模型检查点路径

--config

str

configs/default.yaml

配置文件路径

--num_samples, -n

int

4

生成样本数量

--steps

int

配置文件值

推理步数

--cfg_scale

float

配置文件值

CFG 缩放因子

--output_dir, -o

str

配置文件值

输出目录

--seed

int

配置文件值

随机种子

--device

str

自动检测

计算设备

--save_frames

flag

False

是否保存单帧图片

--save_gif

flag

False

是否保存 GIF 动画

代码实现

parser = argparse.ArgumentParser(
    description="DriveGen 推理脚本 - 使用训练好的模型生成驾驶场景视频",
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument('--checkpoint', '-c', type=str, required=True, ...)
# ... 其他参数定义

2. DDPM 采样核心 (ddpm_sample)

这是文件的核心算法,实现了 DDPM 的逆向采样过程。

数学原理

DDPM 采样遵循以下数学公式:

  1. 预测噪声$\epsilon_\theta(x_t, t, \text{condition})$

  2. 估计$x_0$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}t} \cdot \epsilon\theta}{\sqrt{\bar{\alpha}_t}}$

  3. 生成 $x_{t-1}$$x_{t-1} = \sqrt{\bar{\alpha}{t-1}} \cdot \hat{x}0 + \sqrt{1-\bar{\alpha}{t-1}} \cdot \epsilon\theta + \sigma_t \cdot z$ 其中 $z \sim \mathcal{N}(0,I)$,最后一步不加噪声。

Classifier-Free Guidance

CFG 通过对比有条件和无条件预测来增强生成质量:

if use_cfg and cfg_scale > 1.0:
    null_condition = torch.zeros_like(condition)
    noise_uncond = model(x, t_batch, null_condition)      # 无条件预测
    noise_cond = model(x, t_batch, condition)             # 有条件预测
    noise_pred = noise_uncond + cfg_scale * (noise_cond - noise_uncond)

CFG 效果

  • cfg_scale = 1.0:不使用 CFG

  • cfg_scale > 1.0:增强条件约束,生成结果更符合条件帧

  • 典型值:3.0 ~ 7.0

采样循环
# 从纯噪声开始
x = torch.randn(B, future_frames_count, C, H, W, device=device)

# 从最大时间步到最小时间步迭代去噪
for i in range(len(step_indices) - 1, -1, -1):
    t = step_indices[i]
    # 预测噪声 → 估计 x0 → 生成 x_{t-1}

3. 模型加载 (create_model_from_checkpoint)

支持两种模型架构的动态加载:

模型类型

优先级

特点

STDiT

优先

Spatio-Temporal Diffusion Transformer,性能更优

SimpleUNet

备选

轻量级 U-Net 架构,兼容性更好

加载逻辑

try:
    from DriveGen.models import STDiT
    model = STDiT(...)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, True  # use_stdit=True
except (ImportError, ModuleNotFoundError):
    from train import SimpleUNet
    model = SimpleUNet(...)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model, False  # use_stdit=False

4. 条件帧生成 (generate_condition_frames)

支持两种条件帧来源:

模式一:从数据集采样
from data.dataset import SyntheticDrivingDataset
dataset = SyntheticDrivingDataset(num_samples=100, ...)
indices = torch.randint(0, len(dataset), (num_samples,))
conditions = [dataset[idx.item()]['condition_frames'] for idx in indices]
模式二:生成合成条件帧(Fallback)

当数据集加载失败时,生成渐变图案作为条件帧:

for i in range(num_samples):
    for c in range(in_channels):
        base_color = torch.rand(1) * 0.6 + 0.2
        direction = torch.rand(1)
        
        if direction < 0.33:
            # 水平渐变
            condition[i, 0, c] = torch.linspace(base_color*0.5, base_color, image_size)
        elif direction < 0.66:
            # 垂直渐变
            ...
        else:
            # 对角渐变
            ...

5. 主函数 (main)

协调整个推理流程:

配置加载与合并
config = load_config(args.config)

# 检查点中的配置优先
checkpoint = torch.load(args.checkpoint, map_location='cpu')
if 'config' in checkpoint:
    saved_config = checkpoint['config']
    for key in saved_config:
        if key not in config:
            config[key] = saved_config[key]
设备自动检测
def get_device(device_str=None):
    if device_str:
        return torch.device(device_str)
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')
输出保存

支持多种输出格式:

输出类型

函数

说明

MP4 视频

save_video()

生成的完整视频序列

单帧图片

save_frames_as_images()

可选,单独保存每一帧

GIF 动画

create_gif()

可选,生成 GIF 动图

对比图

save_comparison()

条件帧与生成帧对比


使用示例

基础使用

# 最小配置:仅指定检查点
python inference.py --checkpoint checkpoints/best.pth

完整配置

python inference.py \
    --checkpoint checkpoints/best.pth \
    --num_samples 5 \
    --steps 200 \
    --cfg_scale 5.0 \
    --seed 123 \
    --output_dir results/ \
    --device cuda \
    --save_frames \
    --save_gif

输出结构

outputs/                              # 输出根目录
├── sample_000/                       # 第一个样本
│   ├── generated.mp4                 # 生成的完整视频
│   ├── frames/                       # 单帧图片目录(可选)
│   │   ├── frame_000.png
│   │   ├── frame_001.png
│   │   └── ...
│   └── generated.gif                 # GIF 动画(可选)
├── sample_001/                       # 第二个样本
│   └── ...
└── comparison.png                    # 对比图

关键技术参数

推理步数 (--steps)

  • 值越大:生成质量越高,但耗时越长

  • 推荐值:100 ~ 500

  • 典型配置:训练时 1000 步,推理时可减少到 100 ~ 200 步

CFG 缩放因子 (--cfg_scale)

  • 值越大:条件约束越强,生成结果越"确定"

  • 值越小:随机性越强,生成结果更多样

  • 推荐值:3.0 ~ 7.0

随机种子 (--seed)

  • 固定种子可复现相同的生成结果

  • 不同种子产生不同的随机噪声,导致不同的生成结果


依赖关系

依赖模块

路径

作用

NoiseScheduler

schedules/noise_schedule.py

噪声调度器

setup_logger

utils/logger.py

日志配置

save_video, save_comparison

utils/visualization.py

可视化工具

STDiT

DriveGen/models/__init__.py

主模型

SimpleUNet

train.py

备选模型

SyntheticDrivingDataset

data/dataset.py

数据集


代码优化建议

1. 性能优化

当前实现中,每个样本单独处理:

for sample_idx in range(num_samples):
    cond = condition_norm[sample_idx:sample_idx + 1]
    generated_future = ddpm_sample(model, ...)

优化方向:支持批量采样,减少循环开销。

2. 内存优化

对于大批次或高分辨率输入,建议添加梯度检查点(Gradient Checkpointing):

model.gradient_checkpointing_enable()

3. 代码健壮性

当前 ddpm_samplefuture_frames_count 硬编码为 3:

future_frames_count = 3  # 默认值,会根据模型调整

优化方向:从配置或模型属性中动态获取。


总结

inference.py 是 DriveGen 项目的推理核心,实现了从条件帧生成未来视频帧的完整流程。其关键特性包括:

  1. DDPM 采样:基于扩散模型的迭代去噪过程

  2. CFG 增强:提高生成质量和可控性

  3. 多模型支持:动态加载 STDiT 或 SimpleUNet

  4. 灵活配置:支持命令行参数和配置文件

  5. 多格式输出:满足不同的可视化需求

该脚本为自动驾驶场景生成、数据增强等应用提供了便捷的推理接口。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐