【DriveGen 文件详解】03——inference.py
DriveGen/
├── configs/
│ └── default.yaml # 训练配置文件
├── DriveGen/
│ ├── __init__.py # 包初始化
│ ├── models/
│ │ ├── __init__.py
│ │ ├── embedding.py # Patch嵌入、时间步编码、位置编码
│ │ ├── attention.py # 空间注意力、时间注意力
│ │ ├── dit_block.py # AdaLN-Zero DiT Block
│ │ └── stdit.py # STDiT 完整模型
│ ├── data/
│ │ ├── __init__.py
│ │ └── dataset.py # 合成数据集 + nuScenes 适配器
│ ├── schedules/
│ │ ├── __init__.py
│ │ └── noise_schedule.py # 线性/余弦噪声调度
│ └── utils/
│ ├── __init__.py
│ ├── visualization.py # 视频保存、对比图、损失曲线
│ └── logger.py # 日志工具
├── train.py # 训练脚本
├── inference.py # 推理脚本(DDPM 采样 + CFG)
├── evaluate.py # 评估脚本(FID 计算)
├── requirements.txt # 依赖清单
├── setup.py # 安装配置
└── README.md # 本文件
概述
inference.py 是 DriveGen 项目的推理入口文件,负责加载训练好的扩散模型(Diffusion Model),执行 DDPM(Denoising Diffusion Probabilistic Models)采样过程,从条件帧生成后续的未来视频帧。
核心功能
|
功能模块 |
说明 |
实现位置 |
|---|---|---|
|
DDPM 迭代去噪采样 |
从纯噪声开始逐步去噪生成视频帧 |
|
|
Classifier-Free Guidance (CFG) |
通过对比有条件/无条件预测提高生成质量 |
|
|
条件帧生成 |
支持从数据集采样或生成合成条件帧 |
|
|
多格式输出 |
MP4 视频、单帧图片、GIF 动画、对比图 |
执行流程
┌─────────────────────────────────────────────────────────────────┐
│ 推理执行流程 │
├─────────────────────────────────────────────────────────────────┤
│ 1. 解析命令行参数 │
│ ↓ │
│ 2. 加载配置文件(YAML) │
│ ↓ │
│ 3. 初始化计算设备(CUDA/MPS/CPU) │
│ ↓ │
│ 4. 创建噪声调度器 │
│ ↓ │
│ 5. 加载预训练模型(STDiT / SimpleUNet) │
│ ↓ │
│ 6. 生成条件帧(起始帧) │
│ ↓ │
│ 7. 执行 DDPM 采样(迭代去噪) │
│ ↓ │
│ 8. 拼接条件帧+生成帧,保存输出 │
└─────────────────────────────────────────────────────────────────┘
模块详解
1. 参数解析 (parse_args)
解析命令行参数,支持灵活配置推理行为:
|
参数 |
类型 |
默认值 |
说明 |
|---|---|---|---|
|
|
str |
必需 |
模型检查点路径 |
|
|
str |
|
配置文件路径 |
|
|
int |
4 |
生成样本数量 |
|
|
int |
配置文件值 |
推理步数 |
|
|
float |
配置文件值 |
CFG 缩放因子 |
|
|
str |
配置文件值 |
输出目录 |
|
|
int |
配置文件值 |
随机种子 |
|
|
str |
自动检测 |
计算设备 |
|
|
flag |
False |
是否保存单帧图片 |
|
|
flag |
False |
是否保存 GIF 动画 |
代码实现:
parser = argparse.ArgumentParser(
description="DriveGen 推理脚本 - 使用训练好的模型生成驾驶场景视频",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument('--checkpoint', '-c', type=str, required=True, ...)
# ... 其他参数定义
2. DDPM 采样核心 (ddpm_sample)
这是文件的核心算法,实现了 DDPM 的逆向采样过程。
数学原理
DDPM 采样遵循以下数学公式:
-
预测噪声:
-
估计
:
-
生成
:
其中
,最后一步不加噪声。
Classifier-Free Guidance
CFG 通过对比有条件和无条件预测来增强生成质量:
if use_cfg and cfg_scale > 1.0:
null_condition = torch.zeros_like(condition)
noise_uncond = model(x, t_batch, null_condition) # 无条件预测
noise_cond = model(x, t_batch, condition) # 有条件预测
noise_pred = noise_uncond + cfg_scale * (noise_cond - noise_uncond)
CFG 效果:
-
cfg_scale = 1.0:不使用 CFG -
cfg_scale > 1.0:增强条件约束,生成结果更符合条件帧 -
典型值:3.0 ~ 7.0
采样循环
# 从纯噪声开始
x = torch.randn(B, future_frames_count, C, H, W, device=device)
# 从最大时间步到最小时间步迭代去噪
for i in range(len(step_indices) - 1, -1, -1):
t = step_indices[i]
# 预测噪声 → 估计 x0 → 生成 x_{t-1}
3. 模型加载 (create_model_from_checkpoint)
支持两种模型架构的动态加载:
|
模型类型 |
优先级 |
特点 |
|---|---|---|
|
STDiT |
优先 |
Spatio-Temporal Diffusion Transformer,性能更优 |
|
SimpleUNet |
备选 |
轻量级 U-Net 架构,兼容性更好 |
加载逻辑:
try:
from DriveGen.models import STDiT
model = STDiT(...)
model.load_state_dict(checkpoint['model_state_dict'])
return model, True # use_stdit=True
except (ImportError, ModuleNotFoundError):
from train import SimpleUNet
model = SimpleUNet(...)
model.load_state_dict(checkpoint['model_state_dict'])
return model, False # use_stdit=False
4. 条件帧生成 (generate_condition_frames)
支持两种条件帧来源:
模式一:从数据集采样
from data.dataset import SyntheticDrivingDataset
dataset = SyntheticDrivingDataset(num_samples=100, ...)
indices = torch.randint(0, len(dataset), (num_samples,))
conditions = [dataset[idx.item()]['condition_frames'] for idx in indices]
模式二:生成合成条件帧(Fallback)
当数据集加载失败时,生成渐变图案作为条件帧:
for i in range(num_samples):
for c in range(in_channels):
base_color = torch.rand(1) * 0.6 + 0.2
direction = torch.rand(1)
if direction < 0.33:
# 水平渐变
condition[i, 0, c] = torch.linspace(base_color*0.5, base_color, image_size)
elif direction < 0.66:
# 垂直渐变
...
else:
# 对角渐变
...
5. 主函数 (main)
协调整个推理流程:
配置加载与合并
config = load_config(args.config)
# 检查点中的配置优先
checkpoint = torch.load(args.checkpoint, map_location='cpu')
if 'config' in checkpoint:
saved_config = checkpoint['config']
for key in saved_config:
if key not in config:
config[key] = saved_config[key]
设备自动检测
def get_device(device_str=None):
if device_str:
return torch.device(device_str)
if torch.cuda.is_available():
return torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return torch.device('mps')
else:
return torch.device('cpu')
输出保存
支持多种输出格式:
|
输出类型 |
函数 |
说明 |
|---|---|---|
|
MP4 视频 |
|
生成的完整视频序列 |
|
单帧图片 |
|
可选,单独保存每一帧 |
|
GIF 动画 |
|
可选,生成 GIF 动图 |
|
对比图 |
|
条件帧与生成帧对比 |
使用示例
基础使用
# 最小配置:仅指定检查点
python inference.py --checkpoint checkpoints/best.pth
完整配置
python inference.py \
--checkpoint checkpoints/best.pth \
--num_samples 5 \
--steps 200 \
--cfg_scale 5.0 \
--seed 123 \
--output_dir results/ \
--device cuda \
--save_frames \
--save_gif
输出结构
outputs/ # 输出根目录
├── sample_000/ # 第一个样本
│ ├── generated.mp4 # 生成的完整视频
│ ├── frames/ # 单帧图片目录(可选)
│ │ ├── frame_000.png
│ │ ├── frame_001.png
│ │ └── ...
│ └── generated.gif # GIF 动画(可选)
├── sample_001/ # 第二个样本
│ └── ...
└── comparison.png # 对比图
关键技术参数
推理步数 (--steps)
-
值越大:生成质量越高,但耗时越长
-
推荐值:100 ~ 500
-
典型配置:训练时 1000 步,推理时可减少到 100 ~ 200 步
CFG 缩放因子 (--cfg_scale)
-
值越大:条件约束越强,生成结果越"确定"
-
值越小:随机性越强,生成结果更多样
-
推荐值:3.0 ~ 7.0
随机种子 (--seed)
-
固定种子可复现相同的生成结果
-
不同种子产生不同的随机噪声,导致不同的生成结果
依赖关系
|
依赖模块 |
路径 |
作用 |
|---|---|---|
|
|
|
噪声调度器 |
|
|
|
日志配置 |
|
|
|
可视化工具 |
|
|
|
主模型 |
|
|
|
备选模型 |
|
|
|
数据集 |
代码优化建议
1. 性能优化
当前实现中,每个样本单独处理:
for sample_idx in range(num_samples):
cond = condition_norm[sample_idx:sample_idx + 1]
generated_future = ddpm_sample(model, ...)
优化方向:支持批量采样,减少循环开销。
2. 内存优化
对于大批次或高分辨率输入,建议添加梯度检查点(Gradient Checkpointing):
model.gradient_checkpointing_enable()
3. 代码健壮性
当前 ddpm_sample 中 future_frames_count 硬编码为 3:
future_frames_count = 3 # 默认值,会根据模型调整
优化方向:从配置或模型属性中动态获取。
总结
inference.py 是 DriveGen 项目的推理核心,实现了从条件帧生成未来视频帧的完整流程。其关键特性包括:
-
DDPM 采样:基于扩散模型的迭代去噪过程
-
CFG 增强:提高生成质量和可控性
-
多模型支持:动态加载 STDiT 或 SimpleUNet
-
灵活配置:支持命令行参数和配置文件
-
多格式输出:满足不同的可视化需求
该脚本为自动驾驶场景生成、数据增强等应用提供了便捷的推理接口。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)