VideoMamba: Pipeline Walkthrough (UA-DETRAC Dataset)
1. Dataset
UA-DETRAC dataset:
Download: fetch UA-DETRAC.zip from Kaggle, then upload it to the remote host.
Directory layout of the DETRAC dataset:
UA-DETRAC/
├── DETRAC-Images/ # all image data lives here
│ ├── MVI_20011/ # each folder is one complete video sequence
│ │ ├── img00001.jpg
│ │ ├── img00002.jpg
│ │ └── ... (possibly hundreds to thousands of sequentially named images)
│ ├── MVI_20012/
│ └── MVI_...
└── annotations/ # the official release ships XML bounding-box annotations
└── ...

Note that this is not a video dataset but a collection of frame sequences, whereas VideoMamba assumes by default that your data is a standard video action-recognition dataset.


2. Data Preprocessing
convert_detrac.py:
Preprocessing script that converts the UA-DETRAC dataset into CSV annotation files suitable for a video classification task.
Core function:
UA-DETRAC raw data (XML annotations + image frames)
↓
[convert_detrac.py]
↓
train.csv / val.csv (video-level classification annotations)
2.1 Key Configuration
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
DATA_ROOT = os.path.join(BASE_PATH, "data/UA-DETRAC/DETRAC-Images")  # image-frame directory
XML_ROOT = os.path.join(BASE_PATH, "data/UA-DETRAC/DETRAC-Train-Annotations-XML")  # XML annotation directory
OUTPUT_DIR = os.path.join(BASE_PATH, "data/UA-DETRAC/annotations")  # output CSV directory
VEHICLE_MAP = {'car': 0, 'bus': 1, 'van': 2, 'others': 3}  # vehicle class → numeric label mapping
2.2 Core Functions
The core functions, get_major_class and generate_csv, are listed in full in Section 6.1.

2.3 Overview of the os.path Module
1. os.path.join(path1, path2, ...)
Joins multiple path components into one path, handling separators automatically. Example:
os.path.join('folder', 'subfolder', 'file.txt')  # 'folder/subfolder/file.txt' (Unix) or 'folder\\subfolder\\file.txt' (Windows)
2. os.path.split(path)
Splits a path into a (head, tail) tuple, where tail is the last directory or file name. Example:
os.path.split('/home/user/file.txt')  # ('/home/user', 'file.txt')
3. os.path.splitext(path)
Splits off the extension, returning a (root, ext) tuple (ext includes the dot):
os.path.splitext('data.json')  # ('data', '.json')
4. os.path.exists(path)
Checks whether a path exists (file or directory):
os.path.exists('/etc/passwd')  # returns True or False
5. os.path.dirname(path)
Extracts the directory part of a path:
os.path.dirname('/home/user/file.txt')  # '/home/user'
6. os.path.abspath(path)
Converts a relative path into an absolute one:
os.path.abspath('./file.txt')  # absolute path under the current directory
Putting it together:
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
os.path.abspath(__file__): first returns the absolute path of the current script (including the file name).
os.path.dirname(): then extracts its directory part, i.e. strips the file name.
This resolves sibling directories dynamically; further paths are then built with os.path.join().
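A minimal runnable sketch of this pattern (assuming it executes as a script, so __file__ is defined; the CSV path is illustrative):
import os

# Directory containing the current script, independent of the working directory
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
# Build further paths relative to it
csv_path = os.path.join(BASE_PATH, "data", "UA-DETRAC", "annotations", "train.csv")
print(csv_path)  # absolute path, stable no matter where the script is launched from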
3. Building the DataLoader
Without modifying the official code, we hand-write a script that turns the .jpg sequences into the 5D tensor the model can consume.
3.1 End-to-End Data Flow
Raw data
│
▼
[convert_detrac.py]
│ ├─ Read XML → count majority class → map to 0/1/2/3
│ ├─ Filter short videos (<8 frames)
│ └─ Output: train.csv / val.csv (no header)
│
▼
[DetracImageDataset.__getitem__]
│ ├─ Parse CSV → get folder_name, label
│ ├─ Sample frames by the clip_len×sample_rate policy
│ ├─ Per frame: PIL→Resize→ToTensor→[0,1]
│ ├─ Stack + permute: [T,C,H,W]→[C,T,H,W]
│ └─ ImageNet normalization
│
▼
[DataLoader]
│ └─ Batching: [C,T,H,W]×B → [B,C,T,H,W]
│
▼
[VideoMamba]
│ ├─ 3D Patch Embed: [B,3,T,224,224]→[B,L,C]
│ ├─ Positional encoding + Bidirectional Mamba Blocks
│ └─ [CLS]→Linear→Logits→Loss
│
▼
[Training/inference done]
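A small sketch of the per-clip tensor assembly step in the middle of this pipeline (illustrative values; the actual loader code is in Section 6.4):
import torch

T = 8
frames = [torch.rand(3, 224, 224) for _ in range(T)]  # T frames, each [C, H, W] in [0, 1]
clip = torch.stack(frames)                            # [T, C, H, W]
clip = clip.permute(1, 0, 2, 3)                       # [C, T, H, W]
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
clip = (clip - mean) / std                            # ImageNet normalization
print(clip.shape)                                     # torch.Size([3, 8, 224, 224])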
4. Training Script
4.1 Module Imports
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets.detrac_loader import DetracImageDataset  # custom Dataset
from models.videomamba import videomamba_tiny  # VideoMamba model definition
4.2 Configuration Breakdown
def train():
    # --- Hardware ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Auto-detect the GPU; fall back to the CPU if none is available
    # --- Training hyperparameters ---
    batch_size = 2  # ⚠️ small batch to avoid running out of GPU memory (ALLOC_FAILED)
    epochs = 30     # number of training epochs
    lr = 1e-5       # learning rate: keep it small for fine-tuning
    # --- Paths ---
    BASE = "/home/chenchenyang/pro26/VideoMamba-main"
    train_csv = f"{BASE}/data/UA-DETRAC/annotations/train.csv"  # train-set annotations
    img_prefix = f"{BASE}/data/UA-DETRAC/DETRAC-Images"  # image root directory
    ckpt_path = f"{BASE}/checkpoints/videomamba_t16_k400_f8_res224.pth"  # pretrained weights
    output_dir = f"{BASE}/outputs/detrac_checkpoints"  # checkpoint output directory
    os.makedirs(output_dir, exist_ok=True)
4.3 Training Hyperparameters
4.3.1 batch_size (batch size)
batch_size is the number of samples used per training iteration; here it is the number of video clips loaded per step.
Setting it to 2 means each forward and backward pass processes 2 samples.
A smaller batch_size lowers GPU-memory usage and avoids out-of-memory errors (such as ALLOC_FAILED), but makes training noisier (more frequent, higher-variance gradient updates).
A larger batch_size stabilizes the gradient estimate but needs more memory and can affect generalization.
4.3.2 epochs (number of epochs)
epochs is the number of complete passes over the training set.
Setting it to 30 means the model sees the training data 30 times. More epochs may improve performance, but watch for overfitting (training error keeps falling while validation error rises). If validation performance plateaus early, stop early (Early Stopping) to avoid wasted training.
4.3.3 lr (learning rate)
lr (learning rate) controls the step size of parameter updates.
1e-5 (0.00001) is a small learning rate, suitable for fine-tuning a pretrained model without destroying its weights. Too large a rate makes training unstable or non-convergent; too small slows convergence. It is usually combined with a scheduling strategy (e.g. cosine annealing, warmup); a minimal sketch follows.
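A sketch of attaching such a schedule with PyTorch's built-in cosine annealing (illustrative only; the training script in Section 6.2 does not use a scheduler, and the Linear model here is a stand-in):
import torch

model = torch.nn.Linear(192, 4)  # stand-in model for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

for epoch in range(30):
    # ... train one epoch, calling optimizer.step() per batch ...
    scheduler.step()  # decay the learning rate along a cosine curve, once per epoch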
4.4 Data Preparation
# --- Data preparation ---
dataset = DetracImageDataset(train_csv, prefix=img_prefix, mode='train')
num_classes = len(dataset.label_map)  # infer the number of classes automatically
print(f"Detected Number of Classes: {num_classes}")
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,     # shuffle during training
    num_workers=2,    # 2 worker processes for loading
    pin_memory=True   # speeds up CPU→GPU transfers
)

4.5 Model Loading and Weight Transfer (Transfer Learning)
# --- Model loading ---
model = videomamba_tiny(num_classes=num_classes)  # new model whose classification head fits DETRAC
if os.path.exists(ckpt_path):
    print(f"Loading weights from {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location='cpu')  # load to CPU first to avoid GPU-memory fragmentation
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    # 🔹 Key step: drop the pretrained classification head
    for k in ['head.weight', 'head.bias']:
        if k in state_dict:
            del state_dict[k]
    # 🔹 Load weights: strict=False tolerates partial mismatches
    msg = model.load_state_dict(state_dict, strict=False)
    print(f"Checkpoint loaded: {msg}")
model.to(device)  # move the model and all of its parameters (weights) to the target device
The chosen variant is videomamba_tiny (matching the pretrained checkpoint); the constructor is defined in videomamba.py.
PyTorch .pth files (checkpoint files) are usually stored as dictionaries.
Common structures:
# Structure 1: simplest form (model parameters only)
torch.save(model.state_dict(), 'model.pth')
# Contents: {'layer1.weight': tensor(...), 'layer1.bias': tensor(...), ...}

# Structure 2: full checkpoint (with metadata)
torch.save({
    'epoch': 10,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.05,
}, 'checkpoint.pth')
# Contents: {'epoch': 10, 'model_state_dict': {...}, 'optimizer_state_dict': {...}, ...}

# Structure 3: naming convention used by some frameworks
torch.save({'model': model.state_dict()}, 'checkpoint.pth')
# Contents: {'model': {...}}
1. First load the .pth checkpoint file with torch.load.
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
2. state_dict is the model's parameter dictionary. The code checks for a "model" key in the loaded dictionary: if it exists, the file follows structure 3; if not, it is structure 1 and the whole checkpoint is itself the parameter dictionary.
Example parameter dictionary:
# A hypothetical model's state_dict
state_dict_example = {
    'patch_embed.proj.weight': torch.randn(96, 3, 2, 2),
    'blocks.0.norm1.weight': torch.randn(96),
    ...,
    'head.weight': torch.randn(1000, 768),  # classification-head weight
    'head.bias': torch.randn(1000),         # classification-head bias
}
3. Drop the classification-head weights
Why transfer the weights (transfer learning)?
Different datasets have different numbers of classes.
In this experiment the pretrained weights come from K400 (Kinetics-400, 400 classes), while UA-DETRAC has only 4, so the head shapes in the parameter dictionary do not match: the pretrained head.weight is (400, 192), whereas the model trained on UA-DETRAC needs (num_classes, 192) with num_classes = 4. We therefore keep the backbone weights, discard the mismatched classification head, and retrain the head from scratch:
1. Load pretrained weights → keep only the backbone parameters
2. Create a new classification head → randomly initialized
3. Fine-tune → small updates to the backbone, fast learning for the head (see the sketch below)
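One hedged way to realize "small backbone updates, fast head learning" is per-parameter-group learning rates (a sketch, not what the training script in Section 6.2 does; it assumes model is the videomamba_tiny instance, whose classifier is named head):
import torch

head_params = [p for n, p in model.named_parameters() if n.startswith('head.')]
backbone_params = [p for n, p in model.named_parameters() if not n.startswith('head.')]
optimizer = torch.optim.AdamW([
    {'params': backbone_params, 'lr': 1e-5},  # gentle updates preserve pretrained features
    {'params': head_params, 'lr': 1e-4},      # the randomly initialized head learns faster
])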
4. Check the load status
To see which weights loaded successfully and whether loading went as expected:
# 4. Load the weights
msg = model.load_state_dict(state_dict, strict=False)
model is the videomamba_tiny created in step 1.
strict=True: requires the state_dict keys to match the model exactly.
strict=False: allows a partial load.

Format of the returned msg:
# Actual structure of msg
class _IncompatibleKeys:
    missing_keys: List[str]     # keys the model has but the state_dict lacks
    unexpected_keys: List[str]  # keys the state_dict has but the model lacks

_IncompatibleKeys(
    missing_keys=['head.weight', 'head.bias'],  # in the model but not in the state_dict
    unexpected_keys=['some_unused_key']         # in the state_dict but not in the model
)
Loading, visualized:
Pretrained state_dict:                       Your model structure:
┌──────────────────────┐                     ┌──────────────────────┐
│ patch_embed.weight   │ ←──match──→         │ patch_embed.weight   │
│ blocks.0.norm1.weight│ ←──match──→         │ blocks.0.norm1.weight│
│ ...                  │                     │ ...                  │
│ head.weight [400,192]│ ✗                   │ head.weight [4,192]  │ ← mismatch, dropped
│ head.bias [400]      │ ✗                   │ head.bias [4]        │ ← mismatch, dropped
└──────────────────────┘                     └──────────────────────┘
          ↓                                            ↓
   delete head.*                           keep the random initialization
Actual output example:
# Print the load message
print(f"Checkpoint loaded: {msg}")
# Possible outputs:
# Checkpoint loaded: <All keys matched successfully>  # ideal case
# or
# Checkpoint loaded: _IncompatibleKeys(
#     missing_keys=['head.weight', 'head.bias'],  # expected, since we deleted them
#     unexpected_keys=[]
# )
4.6 Optimizer and Loss Function
# --- Optimizer ---
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=0.05  # recommended L2 regularization to curb overfitting
)
criterion = nn.CrossEntropyLoss()  # multi-class cross-entropy loss

4.7 Training Loop in Detail
print("Start Training...")
for epoch in range(epochs):
    model.train()  # switch to training mode (enables Dropout, etc.)
    epoch_loss = 0
    for i, (frames, labels) in enumerate(loader):
        # 1️⃣ Move data to the GPU
        frames = frames.to(device, non_blocking=True)  # non_blocking speeds up the transfer
        labels = labels.to(device, non_blocking=True)
        # 2️⃣ Forward pass
        outputs = model(frames)            # [B, num_classes]
        loss = criterion(outputs, labels)  # scalar loss
        # 3️⃣ Backward pass
        optimizer.zero_grad()  # clear the previous gradients
        loss.backward()        # compute gradients
        optimizer.step()       # update parameters
        # 4️⃣ Logging
        epoch_loss += loss.item()
        if i % 5 == 0:  # print every 5 batches
            print(f"Epoch [{epoch}] Batch [{i}/{len(loader)}] Loss: {loss.item():.4f}")
    # 5️⃣ Save a checkpoint
    save_path = os.path.join(output_dir, f"mamba_detrac_epoch_{epoch}.pth")
    torch.save(model.state_dict(), save_path)
    # 6️⃣ Epoch statistics
    avg_loss = epoch_loss / len(loader)
    print(f"Epoch {epoch} Average Loss: {avg_loss:.4f}")
4.7.1 Variables and the loader
epoch_loss accumulates the loss of every batch within the current epoch; each epoch (one full pass over the dataset) recomputes its own average loss.
for i, (frames, labels) in enumerate(loader):
- i: index of the current batch, starting at 0
- frames: the batch data (video-frame tensor) of shape [B,C,T,H,W]
- labels: the corresponding labels/targets for the batch
- loader: the data-loader object (PyTorch DataLoader) that yields training batches
In the training loop, loader returns one batch per iteration.
Concrete data structures:
# Assume batch_size=4, clip_len=8, image size 224×224
print(type(frames))   # <class 'torch.Tensor'>
print(frames.shape)   # torch.Size([4, 3, 8, 224, 224])
print(type(labels))   # <class 'torch.Tensor'>
print(labels.shape)   # torch.Size([4])
print(labels)         # tensor([2, 0, 3, 1], device='cuda:0')
Visualized:
Shape of frames, [4, 3, 8, 224, 224]:
┌─────────────────────────────────────────────────┐
│ Batch 0: video clip 0                           │
│ Shape: [3, 8, 224, 224] (RGB, 8 frames, H, W)   │
├─────────────────────────────────────────────────┤
│ Batch 1: video clip 1                           │
│ Shape: [3, 8, 224, 224]                         │
├─────────────────────────────────────────────────┤
│ Batch 2: video clip 2                           │
│ Shape: [3, 8, 224, 224]                         │
├─────────────────────────────────────────────────┤
│ Batch 3: video clip 3                           │
│ Shape: [3, 8, 224, 224]                         │
└─────────────────────────────────────────────────┘
labels: [2, 0, 3, 1] ← class label of each clip
End-to-end data-flow example:
Dataset: [video 0, video 1, ..., video 999]
↓
DataLoader shuffles: [video 34, video 12, video 8, video 25, video 77, ...]
↓
First batch: [video 34, video 12, video 8, video 25] → (frames0, labels0)
↓
Model input: frames0.shape = [4, 3, 8, 224, 224]
labels0 = [2, 0, 3, 1]
↓
Second batch: [video 77, video 91, video 3, video 61] → (frames1, labels1)
4.7.2 Forward and Backward Passes
outputs = model(frames)            # [B, num_classes]
loss = criterion(outputs, labels)  # scalar loss
The outputs and labels are passed to the loss function (criterion) to obtain the loss.
optimizer.zero_grad()  # clear the previous gradients
loss.backward()        # compute gradients
optimizer.step()       # update parameters
epoch_loss += loss.item()
if i % 5 == 0:
    print(f"Epoch [{epoch}] Batch [{i}/{len(loader)}] Loss: {loss.item():.4f}")
loss: the loss tensor, a PyTorch tensor carrying gradient information
loss.item(): extracts the loss as a plain Python scalar
+=: accumulates it into epoch_loss
# Example: the difference between loss and loss.item()
loss = criterion(outputs, labels)
print(type(loss))   # <class 'torch.Tensor'> (shape: torch.Size([]))
print(loss)         # tensor(2.134, device='cuda:0', grad_fn=<NllLossBackward0>)
print(loss.item())  # 2.134 (a Python float)
# What would happen if we added the tensor directly?
# epoch_loss += loss  # ❌ wrong! this accumulates the tensors themselves
# epoch_loss would end up a sum of tensors (keeping their graphs alive), not a number
4.7.3 Saving Checkpoints and Per-Epoch Average Loss
save_path = os.path.join(output_dir, f"mamba_detrac_epoch_{epoch}.pth")
torch.save(model.state_dict(), save_path)
avg_loss = epoch_loss / len(loader)
print(f"Epoch {epoch} Average Loss: {avg_loss:.4f}")
4.7.4 Tensor Shapes Along the Data Flow
Input: frames [B, 3, 8, 224, 224]
├─ B: batch_size=2
├─ 3: RGB channels
├─ 8: number of frames in time (clip_len)
└─ 224×224: spatial resolution
Inside VideoMamba:
1. 3D Patch Embed: [2,3,8,224,224] → [2, L, 192]
   L = 8 × (224/16)² = 8 × 196 = 1568 tokens
2. B-Mamba Blocks × 24: [2, 1568, 192] → [2, 1568, 192]
3. [CLS] token extraction: [2, 192]
4. Classification Head: [2, 192] → [2, num_classes]
Output: outputs [B, num_classes], labels [B]
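The token count L can be sanity-checked directly (assuming, as above, spatial patch size 16 and no temporal downsampling in the patch embed):
# frames T, height H, width W, spatial patch size P
T, H, W, P = 8, 224, 224, 16
L = T * (H // P) * (W // P)  # 8 * 14 * 14
print(L)  # 1568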
5. Validation Script
val_detrac.py
5.1 Configuration
def inference():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # --- 1. Paths ---
    BASE = "/home/chenchenyang/pro26/VideoMamba-main"
    test_csv = f"{BASE}/data/UA-DETRAC/annotations/train.csv"  # ⚠️ test on the train set first to check overfitting
    img_prefix = f"{BASE}/data/UA-DETRAC/DETRAC-Images"
    model_path = f"{BASE}/outputs/detrac_checkpoints/mamba_detrac_epoch_29.pth"
For now, test_csv points at the training set so we can check whether the model has overfit.
5.2 Loading Data
    # --- 2. Data ---
    # mode='val' uses center sampling, which gives more stable results
    dataset = DetracImageDataset(test_csv, prefix=img_prefix, mode='val')
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    # Inverse mapping table (ID -> original label)
    inv_label_map = {v: k for k, v in dataset.label_map.items()}
    num_classes = len(dataset.label_map)
- mode='val': center sampling, a stable frame-sampling policy for inference
- batch_size=1: one sample at a time, convenient for inspecting individual predictions
- shuffle=False: keep the sample order, convenient for debugging
Inverse mapping table:
# If dataset.label_map = {0: 0, 1: 1, 2: 2, 3: 3}
# then inv_label_map = {0: 0, 1: 1, 2: 2, 3: 3}
# If the original labels were strings:
# label_map = {'car': 0, 'bus': 1, 'van': 2}
# inv_label_map = {0: 'car', 1: 'bus', 2: 'van'}
- Training needs the forward mapping: label_map = {'car': 0, 'bus': 1, 'van': 2}
- Inference needs the inverse mapping: inv_label_map = {0: 'car', 1: 'bus', 2: 'van'}; the model outputs a number, which is mapped back to a label such as "car"
5.3 Model Loading
    # --- 3. Initialize the model and load the weights ---
    model = videomamba_tiny(num_classes=num_classes)
    if os.path.exists(model_path):
        print(f"Loading trained model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
    else:
        print("Model file not found!")
        return
    model.to(device)
    model.eval()  # 🔹 key step: switch to evaluation mode (disables Dropout, etc.)
1. torch.load(model_path, map_location='cpu')
- reads the .pth file from disk
- loads the weight data into CPU memory
- returns a dictionary of weights (a state_dict)
2. model.load_state_dict(...):
- copies those weights into the model's parameters
- the model's device placement is unchanged
Up to this point the model lives on the CPU: the .pth file is loaded from disk into a dictionary, and that dictionary is copied into the model's parameters before the model is moved to the GPU.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
3. model.eval(): switch to evaluation mode
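What eval() actually changes can be seen on a single Dropout layer (a minimal illustration):
import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(1, 4)
drop.train()
print(drop(x))  # e.g. tensor([[2., 0., 2., 0.]]): random entries zeroed, survivors scaled by 1/(1-p)
drop.eval()
print(drop(x))  # tensor([[1., 1., 1., 1.]]): dropout disabled, identity mapping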

5.4 Inference Loop
    # --- 4. Inference ---
    correct = 0
    total = 0
    print(f"\n{'Sample Index':<15} | {'True Label':<12} | {'Pred Label':<12} | {'Status'}")
    print("-" * 60)
    with torch.no_grad():
        for i, (frames, label) in enumerate(loader):
            frames = frames.to(device)
            label = label.to(device)
            outputs = model(frames)
            prob = F.softmax(outputs, dim=1)
            pred = torch.argmax(prob, dim=1)
            true_origin = inv_label_map[label.item()]
            pred_origin = inv_label_map[pred.item()]
            status = "✅" if label == pred else "❌"
            if label == pred:
                correct += 1
            total += 1
            print(f"{i:<15} | {true_origin:<12} | {pred_origin:<12} | {status}")
    acc = 100 * correct / total
    print("-" * 60)
    print(f"Final Accuracy: {acc:.2f}% ({correct}/{total})")
5.4.1 Variables and the loader
- correct: number of correctly classified samples
- total: total number of samples
for i, (frames, label) in enumerate(loader):
Here the loader serves the test split (batch_size = 1).

Move the data to the device (the model and the data must live on the same device, or PyTorch raises an error):
frames = frames.to(device)
label = label.to(device)
5.4.2 Model Forward Pass
outputs = model(frames)  # [1, num_classes]
Shape changes along the way:
Input frames: [1, 3, 8, 224, 224]
├─ 1: batch_size
├─ 3: RGB channels
├─ 8: number of frames in time (clip_len)
└─ 224×224: spatial resolution
↓ 3D Patch Embed (1×16×16)
Patch sequence: [1, 1568, 192]
├─ 1568 = 8 × (224/16)² = 8 × 196 tokens
└─ 192: embed_dim (Tiny variant)
↓ B-Mamba Blocks × 24
Hidden states: [1, 1568, 192]
↓ [CLS] token extraction
CLS representation: [1, 192]
↓ Classification Head
Output outputs: [1, num_classes]
e.g. [1, 4] → logits for 4 classes
Example outputs:
outputs = tensor([[2.3, -0.5, 1.1, -1.2]])  # raw logits (unnormalized)
# corresponding classes: [car, bus, van, others]
5.4.3 From Probabilities to Labels
1. Softmax to probabilities: converts the logits into a probability distribution (summing to 1)
- lets you inspect the model's confidence per class
- lets you apply a confidence threshold (e.g. only trust predictions with >0.8)
prob = F.softmax(outputs, dim=1)  # [1, num_classes], sums to 1
2. Take the most probable class
pred = torch.argmax(prob, dim=1)  # [1]
argmax returns the index of the maximum value.
Worked example:
# Input logits
outputs = tensor([[2.3, -0.5, 1.1, -1.2]])
# Softmax computation
exp_outputs = [e^2.3, e^-0.5, e^1.1, e^-1.2]
            = [9.97, 0.61, 3.00, 0.30]
sum = 9.97 + 0.61 + 3.00 + 0.30 = 13.88
prob = [9.97/13.88, 0.61/13.88, 3.00/13.88, 0.30/13.88]
     = [0.718, 0.044, 0.216, 0.022]  # sums to 1.0
# Output
prob = tensor([[0.718, 0.044, 0.216, 0.022]])
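The same computation, verified in PyTorch:
import torch
import torch.nn.functional as F

outputs = torch.tensor([[2.3, -0.5, 1.1, -1.2]])
prob = F.softmax(outputs, dim=1)  # ≈ tensor([[0.7183, 0.0437, 0.2164, 0.0217]])
pred = torch.argmax(prob, dim=1)  # tensor([0]) → class 'car'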
3. Map labels back
true_origin = inv_label_map[label.item()]
pred_origin = inv_label_map[pred.item()]
- item(): a Tensor method that extracts the single number inside a tensor
- items(): a dict method that yields all key-value pairs
# Original label table (class name → ID)
label_map = {'car': 0, 'bus': 1, 'van': 2}
# .items() yields each (key, value) pair
label_map.items()
# output: dict_items([('car', 0), ('bus', 1), ('van', 2)])
# Swap key and value inside a dict comprehension: {v: k for k, v in ...}
# giving: {0: 'car', 1: 'bus', 2: 'van'}
5.4.4 Complete Data Flow
┌─────────────────────────────────────────────────────────────┐
│ Inference-loop data flow                                    │
└─────────────────────────────────────────────────────────────┘
1️⃣ Data loading
loader → (frames, label)
[1,3,8,224,224] [1]
2️⃣ Device transfer
frames.to(device), label.to(device)
3️⃣ Forward pass
model(frames) → outputs
[1, num_classes]
e.g. [1, 4] = [[2.3, -0.5, 1.1, -1.2]]
4️⃣ Softmax
F.softmax(outputs, dim=1) → prob
[1, 4] = [[0.718, 0.044, 0.216, 0.022]]
5️⃣ Argmax
torch.argmax(prob, dim=1) → pred
[1] = [0]
6️⃣ Label mapping
inv_label_map[pred.item()] → pred_origin
0 → 'car'
7️⃣ Bookkeeping
if label == pred: correct += 1
total += 1
8️⃣ Print
print(f"{i} | {true_origin} | {pred_origin} | {status}")
6. Script Source Code
6.1 convert_detrac.py
import os
import xml.etree.ElementTree as ET
import pandas as pd
from sklearn.model_selection import train_test_split

# --- Path configuration ---
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
DATA_ROOT = os.path.join(BASE_PATH, "data/UA-DETRAC/DETRAC-Images")
XML_ROOT = os.path.join(BASE_PATH, "data/UA-DETRAC/DETRAC-Train-Annotations-XML")
OUTPUT_DIR = os.path.join(BASE_PATH, "data/UA-DETRAC/annotations")
os.makedirs(OUTPUT_DIR, exist_ok=True)

VEHICLE_MAP = {'car': 0, 'bus': 1, 'van': 2, 'others': 3}

def get_major_class(xml_path):
    """
    Parse an XML file and count the most frequent vehicle type in the video.
    DETRAC XML structure:
    <sequence>
        <frame num="1">
            <target_list>
                <target id="1">
                    <box .../>
                    <attribute vehicle_type="car" .../>
                </target>
            </target_list>
        </frame>
    </sequence>
    """
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        counts = {'car': 0, 'bus': 0, 'van': 0, 'others': 0}
        # Correct traversal: walk every frame's target_list/target/attribute
        for frame in root.findall('.//frame'):
            target_list = frame.find('target_list')
            if target_list is not None:
                for target in target_list.findall('target'):
                    attribute = target.find('attribute')
                    if attribute is not None:
                        v_type = attribute.get('vehicle_type')
                        if v_type in counts:
                            counts[v_type] += 1
                        else:
                            counts['others'] += 1
        # Debug info: print the counts
        total = sum(counts.values())
        if total > 0:
            print(f"  XML parsed successfully: {os.path.basename(xml_path)}")
            print(f"  Vehicle counts: {counts}")
        # Return the numeric label of the most frequent class
        major_class = max(counts, key=counts.get)
        return VEHICLE_MAP[major_class]
    except Exception as e:
        print(f"  ⚠️ XML parsing failed {xml_path}: {e}")
        return 0  # default to the 'car' class when parsing fails

def generate_csv():
    data_list = []
    print("=" * 60)
    print("Processing the UA-DETRAC dataset...")
    print("=" * 60)
    # Step 1: collect valid video folders (names starting with MVI_)
    video_folders = [f for f in os.listdir(DATA_ROOT)
                     if os.path.isdir(os.path.join(DATA_ROOT, f)) and f.startswith('MVI')]
    print(f"\n📁 Found {len(video_folders)} candidate video folders")
    print(f"📂 Data directory: {DATA_ROOT}")
    print(f"📄 Annotation directory: {XML_ROOT}")
    print()
    for idx, folder in enumerate(video_folders):
        folder_path = os.path.join(DATA_ROOT, folder)
        # Step 2: match the XML annotation file (compatible with the _v3 suffix)
        xml_file = f"{folder}_v3.xml" if os.path.exists(os.path.join(XML_ROOT, f"{folder}_v3.xml")) else f"{folder}.xml"
        xml_path = os.path.join(XML_ROOT, xml_file)
        if not os.path.exists(xml_path):
            print(f"⚠️ [{idx + 1}/{len(video_folders)}] {folder}: XML file not found, skipping")
            continue
        # Step 3: count valid image frames and filter out short videos
        images = [img for img in os.listdir(folder_path) if img.lower().endswith(('.jpg', '.jpeg'))]
        if len(images) < 8:  # too few frames for the f8 pretrained model's input
            print(f"⚠️ [{idx + 1}/{len(video_folders)}] {folder}: not enough frames ({len(images)}<8), skipping")
            continue
        # Step 4: get the majority-class label
        print(f"[{idx + 1}/{len(video_folders)}] Processing: {folder} ({len(images)} frames)")
        label = get_major_class(xml_path)
        data_list.append([folder, len(images), label])
        # Print a progress summary every 10 videos
        if (idx + 1) % 10 == 0:
            labels_count = {}
            for _, _, l in data_list:
                labels_count[l] = labels_count.get(l, 0) + 1
            print(f"  Progress: {len(data_list)} videos processed, label distribution: {labels_count}\n")
    print("\n" + "=" * 60)
    print("Done! Statistics:")
    print("=" * 60)
    if not data_list:
        print("\n❌ Error: no video was processed successfully!")
        print(f"Please check that the following paths are correct:")
        print(f"  - Data directory: {DATA_ROOT}")
        print(f"  - Annotation directory: {XML_ROOT}")
        print(f"\nSuggested commands to inspect them:")
        print(f"  ls {DATA_ROOT} | head -5")
        print(f"  ls {XML_ROOT} | head -5")
        return
    # Label distribution statistics
    label_stats = {}
    for _, _, label in data_list:
        label_stats[label] = label_stats.get(label, 0) + 1
    print(f"\n✅ Successfully processed {len(data_list)} videos")
    print(f"📊 Label distribution:")
    reverse_map = {v: k for k, v in VEHICLE_MAP.items()}
    for label in sorted(label_stats.keys()):
        print(f"  {reverse_map[label]} (label {label}): {label_stats[label]}")
    # Step 5: build a DataFrame and split train/val sets
    df = pd.DataFrame(data_list, columns=['path', 'frames', 'label'])
    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
    # Step 6: save as header-less CSV
    train_df.to_csv(os.path.join(OUTPUT_DIR, "train.csv"), index=False, header=False)
    val_df.to_csv(os.path.join(OUTPUT_DIR, "val.csv"), index=False, header=False)
    print(f"\n💾 Files saved:")
    print(f"  Train set: {os.path.join(OUTPUT_DIR, 'train.csv')} ({len(train_df)} rows)")
    print(f"  Val set: {os.path.join(OUTPUT_DIR, 'val.csv')} ({len(val_df)} rows)")
    print(f"  Output directory: {OUTPUT_DIR}")
    # Show the first 5 rows as a sample
    print(f"\n📋 First 5 rows of the train set:")
    print(train_df.head().to_string(index=False))

if __name__ == "__main__":
    generate_csv()
Console output:
(videomamba) chenchenyang@musk:~/pro26/VideoMamba-main$ python convert_detrac.py
============================================================
Processing the UA-DETRAC dataset...
============================================================
📁 Found 100 candidate video folders
📂 Data directory: /home/chenchenyang/pro26/VideoMamba-main/data/UA-DETRAC/DETRAC-Images
📄 Annotation directory: /home/chenchenyang/pro26/VideoMamba-main/data/UA-DETRAC/DETRAC-Train-Annotations-XML
[1/100] Processing: MVI_20011 (664 frames)
XML parsed successfully: MVI_20011.xml
Vehicle counts: {'car': 7053, 'bus': 95, 'van': 296, 'others': 211}
[2/100] Processing: MVI_20012 (936 frames)
XML parsed successfully: MVI_20012.xml
Vehicle counts: {'car': 6481, 'bus': 1337, 'van': 790, 'others': 0}
[3/100] Processing: MVI_20032 (437 frames)
XML parsed successfully: MVI_20032.xml
Vehicle counts: {'car': 1657, 'bus': 0, 'van': 0, 'others': 0}
[4/100] Processing: MVI_20033 (784 frames)
XML parsed successfully: MVI_20033.xml
Vehicle counts: {'car': 4188, 'bus': 327, 'van': 759, 'others': 0}
[5/100] Processing: MVI_20034 (800 frames)
XML parsed successfully: MVI_20034.xml
Vehicle counts: {'car': 8073, 'bus': 1373, 'van': 488, 'others': 0}
[6/100] Processing: MVI_20035 (800 frames)
XML parsed successfully: MVI_20035.xml
Vehicle counts: {'car': 10396, 'bus': 0, 'van': 1459, 'others': 0}
[7/100] Processing: MVI_20051 (906 frames)
XML parsed successfully: MVI_20051.xml
Vehicle counts: {'car': 7738, 'bus': 906, 'van': 290, 'others': 0}
[8/100] Processing: MVI_20052 (694 frames)
XML parsed successfully: MVI_20052.xml
Vehicle counts: {'car': 6982, 'bus': 694, 'van': 683, 'others': 0}
[9/100] Processing: MVI_20061 (800 frames)
XML parsed successfully: MVI_20061.xml
Vehicle counts: {'car': 7348, 'bus': 873, 'van': 1035, 'others': 0}
[10/100] Processing: MVI_20062 (800 frames)
XML parsed successfully: MVI_20062.xml
Vehicle counts: {'car': 3049, 'bus': 501, 'van': 725, 'others': 187}
Progress: 10 videos processed, label distribution: {0: 10}
[11/100] Processing: MVI_20063 (800 frames)
XML parsed successfully: MVI_20063.xml
Vehicle counts: {'car': 4789, 'bus': 771, 'van': 712, 'others': 120}
[12/100] Processing: MVI_20064 (800 frames)
XML parsed successfully: MVI_20064.xml
Vehicle counts: {'car': 13478, 'bus': 241, 'van': 465, 'others': 0}
[13/100] Processing: MVI_20065 (1200 frames)
XML parsed successfully: MVI_20065.xml
Vehicle counts: {'car': 14205, 'bus': 1252, 'van': 1699, 'others': 0}
⚠️ [14/100] MVI_39031: XML file not found, skipping
⚠️ [15/100] MVI_39051: XML file not found, skipping
⚠️ [16/100] MVI_39211: XML file not found, skipping
⚠️ [17/100] MVI_39271: XML file not found, skipping
⚠️ [18/100] MVI_39311: XML file not found, skipping
⚠️ [19/100] MVI_39371: XML file not found, skipping
⚠️ [20/100] MVI_39401: XML file not found, skipping
⚠️ [21/100] MVI_39501: XML file not found, skipping
⚠️ [22/100] MVI_39511: XML file not found, skipping
[23/100] Processing: MVI_39761 (1660 frames)
XML parsed successfully: MVI_39761.xml
Vehicle counts: {'car': 3316, 'bus': 402, 'van': 0, 'others': 0}
[24/100] Processing: MVI_39771 (570 frames)
XML parsed successfully: MVI_39771.xml
Vehicle counts: {'car': 2933, 'bus': 511, 'van': 161, 'others': 0}
[25/100] Processing: MVI_39781 (1865 frames)
XML parsed successfully: MVI_39781.xml
Vehicle counts: {'car': 4707, 'bus': 2628, 'van': 91, 'others': 221}
[26/100] Processing: MVI_39801 (885 frames)
XML parsed successfully: MVI_39801.xml
Vehicle counts: {'car': 4828, 'bus': 0, 'van': 0, 'others': 25}
[27/100] Processing: MVI_39811 (1070 frames)
XML parsed successfully: MVI_39811.xml
Vehicle counts: {'car': 599, 'bus': 0, 'van': 0, 'others': 0}
[28/100] Processing: MVI_39821 (880 frames)
XML parsed successfully: MVI_39821.xml
Vehicle counts: {'car': 3405, 'bus': 181, 'van': 609, 'others': 0}
[29/100] Processing: MVI_39851 (1420 frames)
XML parsed successfully: MVI_39851.xml
Vehicle counts: {'car': 4007, 'bus': 777, 'van': 342, 'others': 0}
[30/100] Processing: MVI_39861 (745 frames)
XML parsed successfully: MVI_39861.xml
Vehicle counts: {'car': 2298, 'bus': 96, 'van': 0, 'others': 0}
Progress: 21 videos processed, label distribution: {0: 21}
[31/100] Processing: MVI_39931 (1270 frames)
XML parsed successfully: MVI_39931.xml
Vehicle counts: {'car': 3495, 'bus': 436, 'van': 0, 'others': 0}
[32/100] Processing: MVI_40131 (1645 frames)
XML parsed successfully: MVI_40131.xml
Vehicle counts: {'car': 12117, 'bus': 2238, 'van': 969, 'others': 0}
[33/100] Processing: MVI_40141 (1600 frames)
XML parsed successfully: MVI_40141.xml
Vehicle counts: {'car': 4916, 'bus': 0, 'van': 1306, 'others': 0}
[34/100] Processing: MVI_40152 (1750 frames)
XML parsed successfully: MVI_40152.xml
Vehicle counts: {'car': 4599, 'bus': 256, 'van': 631, 'others': 126}
[35/100] Processing: MVI_40161 (1490 frames)
XML parsed successfully: MVI_40161.xml
Vehicle counts: {'car': 4803, 'bus': 847, 'van': 975, 'others': 0}
[36/100] Processing: MVI_40162 (1765 frames)
XML parsed successfully: MVI_40162.xml
Vehicle counts: {'car': 9380, 'bus': 1427, 'van': 454, 'others': 0}
[37/100] Processing: MVI_40171 (1150 frames)
XML parsed successfully: MVI_40171.xml
Vehicle counts: {'car': 7084, 'bus': 1697, 'van': 147, 'others': 0}
[38/100] Processing: MVI_40181 (1700 frames)
XML parsed successfully: MVI_40181.xml
Vehicle counts: {'car': 4261, 'bus': 2350, 'van': 417, 'others': 96}
[39/100] Processing: MVI_40191 (2495 frames)
XML parsed successfully: MVI_40191.xml
Vehicle counts: {'car': 33647, 'bus': 0, 'van': 4633, 'others': 121}
[40/100] Processing: MVI_40192 (2195 frames)
XML parsed successfully: MVI_40192.xml
Vehicle counts: {'car': 23573, 'bus': 346, 'van': 4097, 'others': 167}
Progress: 31 videos processed, label distribution: {0: 31}
[41/100] Processing: MVI_40201 (925 frames)
XML parsed successfully: MVI_40201.xml
Vehicle counts: {'car': 10141, 'bus': 0, 'van': 784, 'others': 0}
[42/100] Processing: MVI_40204 (1225 frames)
XML parsed successfully: MVI_40204.xml
Vehicle counts: {'car': 19387, 'bus': 267, 'van': 2774, 'others': 0}
[43/100] Processing: MVI_40211 (1950 frames)
XML parsed successfully: MVI_40211.xml
Vehicle counts: {'car': 5792, 'bus': 92, 'van': 977, 'others': 31}
[44/100] Processing: MVI_40212 (1690 frames)
XML parsed successfully: MVI_40212.xml
Vehicle counts: {'car': 6800, 'bus': 188, 'van': 870, 'others': 21}
[45/100] Processing: MVI_40213 (1790 frames)
XML parsed successfully: MVI_40213.xml
Vehicle counts: {'car': 6301, 'bus': 133, 'van': 1124, 'others': 144}
[46/100] Processing: MVI_40241 (2320 frames)
XML parsed successfully: MVI_40241.xml
Vehicle counts: {'car': 18502, 'bus': 71, 'van': 2596, 'others': 178}
[47/100] Processing: MVI_40243 (1265 frames)
XML parsed successfully: MVI_40243.xml
Vehicle counts: {'car': 9047, 'bus': 171, 'van': 1259, 'others': 71}
[48/100] Processing: MVI_40244 (1345 frames)
XML parsed successfully: MVI_40244.xml
Vehicle counts: {'car': 7885, 'bus': 122, 'van': 804, 'others': 0}
⚠️ [49/100] MVI_40701: XML file not found, skipping
⚠️ [50/100] MVI_40711: XML file not found, skipping
⚠️ [51/100] MVI_40712: XML file not found, skipping
⚠️ [52/100] MVI_40714: XML file not found, skipping
[53/100] Processing: MVI_40732 (2120 frames)
XML parsed successfully: MVI_40732.xml
Vehicle counts: {'car': 10154, 'bus': 730, 'van': 276, 'others': 352}
⚠️ [54/100] MVI_40742: XML file not found, skipping
⚠️ [55/100] MVI_40743: XML file not found, skipping
[56/100] Processing: MVI_40751 (1145 frames)
XML parsed successfully: MVI_40751.xml
Vehicle counts: {'car': 5418, 'bus': 1836, 'van': 41, 'others': 91}
⚠️ [57/100] MVI_40761: XML file not found, skipping
⚠️ [58/100] MVI_40762: XML file not found, skipping
⚠️ [59/100] MVI_40763: XML file not found, skipping
⚠️ [60/100] MVI_40771: XML file not found, skipping
⚠️ [61/100] MVI_40772: XML file not found, skipping
⚠️ [62/100] MVI_40773: XML file not found, skipping
⚠️ [63/100] MVI_40774: XML file not found, skipping
⚠️ [64/100] MVI_40775: XML file not found, skipping
⚠️ [65/100] MVI_40792: XML file not found, skipping
⚠️ [66/100] MVI_40793: XML file not found, skipping
⚠️ [67/100] MVI_40851: XML file not found, skipping
⚠️ [68/100] MVI_40852: XML file not found, skipping
⚠️ [69/100] MVI_40853: XML file not found, skipping
⚠️ [70/100] MVI_40854: XML file not found, skipping
⚠️ [71/100] MVI_40855: XML file not found, skipping
⚠️ [72/100] MVI_40863: XML file not found, skipping
⚠️ [73/100] MVI_40864: XML file not found, skipping
[74/100] Processing: MVI_40871 (1720 frames)
XML parsed successfully: MVI_40871.xml
Vehicle counts: {'car': 29634, 'bus': 1720, 'van': 5271, 'others': 0}
⚠️ [75/100] MVI_40891: XML file not found, skipping
⚠️ [76/100] MVI_40901: XML file not found, skipping
⚠️ [77/100] MVI_40902: XML file not found, skipping
⚠️ [78/100] MVI_40903: XML file not found, skipping
⚠️ [79/100] MVI_40904: XML file not found, skipping
⚠️ [80/100] MVI_40905: XML file not found, skipping
[81/100] Processing: MVI_40962 (1875 frames)
XML parsed successfully: MVI_40962.xml
Vehicle counts: {'car': 7009, 'bus': 96, 'van': 475, 'others': 0}
[82/100] Processing: MVI_40963 (1820 frames)
XML parsed successfully: MVI_40963.xml
Vehicle counts: {'car': 9906, 'bus': 726, 'van': 1007, 'others': 0}
[83/100] Processing: MVI_40981 (1995 frames)
XML parsed successfully: MVI_40981.xml
Vehicle counts: {'car': 9248, 'bus': 111, 'van': 990, 'others': 0}
[84/100] Processing: MVI_40991 (1820 frames)
XML parsed successfully: MVI_40991.xml
Vehicle counts: {'car': 4482, 'bus': 0, 'van': 0, 'others': 0}
[85/100] Processing: MVI_40992 (2160 frames)
XML parsed successfully: MVI_40992.xml
Vehicle counts: {'car': 4926, 'bus': 0, 'van': 136, 'others': 0}
[86/100] Processing: MVI_41063 (1505 frames)
XML parsed successfully: MVI_41063.xml
Vehicle counts: {'car': 8447, 'bus': 76, 'van': 1338, 'others': 187}
[87/100] Processing: MVI_41073 (1825 frames)
XML parsed successfully: MVI_41073.xml
Vehicle counts: {'car': 9143, 'bus': 0, 'van': 1041, 'others': 111}
[88/100] Processing: MVI_63521 (2055 frames)
XML parsed successfully: MVI_63521.xml
Vehicle counts: {'car': 11806, 'bus': 2099, 'van': 962, 'others': 221}
[89/100] Processing: MVI_63525 (985 frames)
XML parsed successfully: MVI_63525.xml
Vehicle counts: {'car': 2097, 'bus': 1161, 'van': 212, 'others': 0}
[90/100] Processing: MVI_63544 (1160 frames)
XML parsed successfully: MVI_63544.xml
Vehicle counts: {'car': 1524, 'bus': 121, 'van': 318, 'others': 0}
Progress: 52 videos processed, label distribution: {0: 52}
[91/100] Processing: MVI_63552 (1150 frames)
XML parsed successfully: MVI_63552.xml
Vehicle counts: {'car': 5482, 'bus': 0, 'van': 1540, 'others': 91}
[92/100] Processing: MVI_63553 (1405 frames)
XML parsed successfully: MVI_63553.xml
Vehicle counts: {'car': 9208, 'bus': 167, 'van': 1338, 'others': 62}
[93/100] Processing: MVI_63554 (1445 frames)
XML parsed successfully: MVI_63554.xml
Vehicle counts: {'car': 8359, 'bus': 0, 'van': 1457, 'others': 31}
[94/100] Processing: MVI_63561 (1285 frames)
XML parsed successfully: MVI_63561.xml
Vehicle counts: {'car': 8578, 'bus': 0, 'van': 1007, 'others': 218}
[95/100] Processing: MVI_63562 (1185 frames)
XML parsed successfully: MVI_63562.xml
Vehicle counts: {'car': 5650, 'bus': 66, 'van': 877, 'others': 87}
[96/100] Processing: MVI_63563 (1390 frames)
XML parsed successfully: MVI_63563.xml
Vehicle counts: {'car': 8322, 'bus': 31, 'van': 1053, 'others': 158}
⚠️ [97/100] MVI_39361: XML file not found, skipping
[98/100] Processing: MVI_40172 (2635 frames)
XML parsed successfully: MVI_40172.xml
Vehicle counts: {'car': 16671, 'bus': 626, 'van': 644, 'others': 201}
[99/100] Processing: MVI_40752 (2025 frames)
XML parsed successfully: MVI_40752.xml
Vehicle counts: {'car': 14529, 'bus': 479, 'van': 1647, 'others': 197}
⚠️ [100/100] MVI_40892: XML file not found, skipping
============================================================
Done! Statistics:
============================================================
✅ Successfully processed 60 videos
📊 Label distribution:
car (label 0): 60
💾 Files saved:
Train set: /home/chenchenyang/pro26/VideoMamba-main/data/UA-DETRAC/annotations/train.csv (48 rows)
Val set: /home/chenchenyang/pro26/VideoMamba-main/data/UA-DETRAC/annotations/val.csv (12 rows)
Output directory: /home/chenchenyang/pro26/VideoMamba-main/data/UA-DETRAC/annotations
📋 First 5 rows of the train set:
path frames label
MVI_40201 925 0
MVI_20033 784 0
MVI_63552 1150 0
MVI_39811 1070 0
MVI_20061 800 0
6.2 train_detrac.py
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets.detrac_loader import DetracImageDataset
from models.videomamba import videomamba_tiny

def train():
    # --- Configuration ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # If ALLOC_FAILED still occurs, try setting batch_size to 1
    batch_size = 2
    epochs = 30
    lr = 1e-5
    # Paths
    BASE = "/home/chenchenyang/pro26/VideoMamba-main"
    train_csv = f"{BASE}/data/UA-DETRAC/annotations/train.csv"
    img_prefix = f"{BASE}/data/UA-DETRAC/DETRAC-Images"
    ckpt_path = f"{BASE}/checkpoints/videomamba_t16_k400_f8_res224.pth"
    output_dir = f"{BASE}/outputs/detrac_checkpoints"
    os.makedirs(output_dir, exist_ok=True)
    # --- Data preparation ---
    dataset = DetracImageDataset(train_csv, prefix=img_prefix, mode='train')
    # Automatically infer the number of classes
    num_classes = len(dataset.label_map)
    print(f"Detected Number of Classes: {num_classes}")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    # --- Model loading ---
    model = videomamba_tiny(num_classes=num_classes)
    if os.path.exists(ckpt_path):
        print(f"Loading weights from {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
        # Remove the classification-head weights
        for k in ['head.weight', 'head.bias']:
            if k in state_dict:
                del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        print(f"Checkpoint loaded: {msg}")
    model.to(device)
    # --- Optimizer ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # --- Training loop ---
    print("Start Training...")
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for i, (frames, labels) in enumerate(loader):
            # Keep GPU memory clean
            frames = frames.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            # Forward pass
            outputs = model(frames)
            loss = criterion(outputs, labels)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            if i % 5 == 0:
                print(f"Epoch [{epoch}] Batch [{i}/{len(loader)}] Loss: {loss.item():.4f}")
        # Save the model
        save_path = os.path.join(output_dir, f"mamba_detrac_epoch_{epoch}.pth")
        torch.save(model.state_dict(), save_path)
        print(f"Epoch {epoch} Average Loss: {epoch_loss / len(loader):.4f}")

if __name__ == "__main__":
    # Empty the CUDA cache before starting
    torch.cuda.empty_cache()
    train()
Console output:
(videomamba) chenchenyang@musk:~/pro26/VideoMamba-main/videomamba/video_sm$ python train_detrac.py
Dataset Loaded: 48 samples.
Label Mapping: {437: 0, 570: 1, 694: 2, 745: 3, 784: 4, 800: 5, 880: 6, 885: 7, 906: 8, 925: 9, 936: 10, 1070: 11, 1145: 12, 1150: 13, 1160: 14, 1185: 15, 1225: 16, 1265: 17, 1270: 18, 1285: 19, 1345: 20, 1405: 21, 1420: 22, 1490: 23, 1505: 24, 1600: 25, 1645: 26, 1690: 27, 1700: 28, 1720: 29, 1750: 30, 1765: 31, 1790: 32, 1820: 33, 1865: 34, 1875: 35, 1995: 36, 2025: 37, 2055: 38, 2120: 39, 2195: 40, 2495: 41, 2635: 42} (Original -> Processed)
Detected Number of Classes: 43
Use checkpoint: False
Checkpoint number: 0
Loading weights from /home/chenchenyang/pro26/VideoMamba-main/checkpoints/videomamba_t16_k400_f8_res224.pth
Checkpoint loaded: _IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])
Start Training...
Epoch [0] Batch [0/24] Loss: 3.6546
Epoch [0] Batch [5/24] Loss: 3.7955
Epoch [0] Batch [10/24] Loss: 4.0790
Epoch [0] Batch [15/24] Loss: 3.4785
Epoch [0] Batch [20/24] Loss: 3.7999
Epoch 0 Average Loss: 3.8162
Epoch [1] Batch [0/24] Loss: 3.3199
Epoch [1] Batch [5/24] Loss: 3.7698
Epoch [1] Batch [10/24] Loss: 3.9962
Epoch [1] Batch [15/24] Loss: 3.6377
Epoch [1] Batch [20/24] Loss: 3.6275
Epoch 1 Average Loss: 3.7062
Epoch [2] Batch [0/24] Loss: 3.3535
Epoch [2] Batch [5/24] Loss: 3.3764
Epoch [2] Batch [10/24] Loss: 3.5017
Epoch [2] Batch [15/24] Loss: 3.8081
Epoch [2] Batch [20/24] Loss: 3.0009
Epoch 2 Average Loss: 3.5204
Epoch [3] Batch [0/24] Loss: 3.7942
Epoch [3] Batch [5/24] Loss: 3.9798
Epoch [3] Batch [10/24] Loss: 3.7204
Epoch [3] Batch [15/24] Loss: 3.1437
Epoch [3] Batch [20/24] Loss: 3.2389
Epoch 3 Average Loss: 3.4369
Epoch [4] Batch [0/24] Loss: 4.0305
Epoch [4] Batch [5/24] Loss: 3.2542
Epoch [4] Batch [10/24] Loss: 2.9418
Epoch [4] Batch [15/24] Loss: 3.3940
Epoch [4] Batch [20/24] Loss: 3.8373
Epoch 4 Average Loss: 3.3070
Epoch [5] Batch [0/24] Loss: 3.4846
Epoch [5] Batch [5/24] Loss: 2.6126
Epoch [5] Batch [10/24] Loss: 3.4506
Epoch [5] Batch [15/24] Loss: 3.2055
Epoch [5] Batch [20/24] Loss: 2.9002
Epoch 5 Average Loss: 3.1427
Epoch [6] Batch [0/24] Loss: 2.3080
Epoch [6] Batch [5/24] Loss: 3.2808
Epoch [6] Batch [10/24] Loss: 2.9535
Epoch [6] Batch [15/24] Loss: 3.8067
Epoch [6] Batch [20/24] Loss: 3.3424
Epoch 6 Average Loss: 3.0210
Epoch [7] Batch [0/24] Loss: 2.7884
Epoch [7] Batch [5/24] Loss: 3.4619
Epoch [7] Batch [10/24] Loss: 3.0369
Epoch [7] Batch [15/24] Loss: 3.2042
Epoch [7] Batch [20/24] Loss: 3.1106
Epoch 7 Average Loss: 2.8892
Epoch [8] Batch [0/24] Loss: 2.5874
Epoch [8] Batch [5/24] Loss: 3.1701
Epoch [8] Batch [10/24] Loss: 2.8044
Epoch [8] Batch [15/24] Loss: 2.2516
Epoch [8] Batch [20/24] Loss: 1.8578
Epoch 8 Average Loss: 2.7995
Epoch [9] Batch [0/24] Loss: 3.2771
Epoch [9] Batch [5/24] Loss: 3.1957
Epoch [9] Batch [10/24] Loss: 2.7424
Epoch [9] Batch [15/24] Loss: 2.4377
Epoch [9] Batch [20/24] Loss: 1.7674
Epoch 9 Average Loss: 2.7260
Epoch [10] Batch [0/24] Loss: 2.8704
Epoch [10] Batch [5/24] Loss: 2.9154
Epoch [10] Batch [10/24] Loss: 2.6658
Epoch [10] Batch [15/24] Loss: 3.4208
Epoch [10] Batch [20/24] Loss: 2.3008
Epoch 10 Average Loss: 2.5322
Epoch [11] Batch [0/24] Loss: 2.8190
Epoch [11] Batch [5/24] Loss: 2.5571
Epoch [11] Batch [10/24] Loss: 1.9972
Epoch [11] Batch [15/24] Loss: 3.4188
Epoch [11] Batch [20/24] Loss: 2.6478
Epoch 11 Average Loss: 2.6028
Epoch [12] Batch [0/24] Loss: 1.8262
Epoch [12] Batch [5/24] Loss: 2.8447
Epoch [12] Batch [10/24] Loss: 3.2256
Epoch [12] Batch [15/24] Loss: 1.9309
Epoch [12] Batch [20/24] Loss: 2.4237
Epoch 12 Average Loss: 2.3503
Epoch [13] Batch [0/24] Loss: 2.7680
Epoch [13] Batch [5/24] Loss: 2.1569
Epoch [13] Batch [10/24] Loss: 1.7677
Epoch [13] Batch [15/24] Loss: 2.8003
Epoch [13] Batch [20/24] Loss: 1.0333
Epoch 13 Average Loss: 2.3477
Epoch [14] Batch [0/24] Loss: 2.1494
Epoch [14] Batch [5/24] Loss: 1.3228
Epoch [14] Batch [10/24] Loss: 2.5616
Epoch [14] Batch [15/24] Loss: 1.6309
Epoch [14] Batch [20/24] Loss: 3.3164
Epoch 14 Average Loss: 2.2747
Epoch [15] Batch [0/24] Loss: 2.3707
Epoch [15] Batch [5/24] Loss: 2.6382
Epoch [15] Batch [10/24] Loss: 1.8941
Epoch [15] Batch [15/24] Loss: 2.5973
Epoch [15] Batch [20/24] Loss: 2.5639
Epoch 15 Average Loss: 2.0765
Epoch [16] Batch [0/24] Loss: 1.7345
Epoch [16] Batch [5/24] Loss: 2.1229
Epoch [16] Batch [10/24] Loss: 2.0360
Epoch [16] Batch [15/24] Loss: 0.9865
Epoch [16] Batch [20/24] Loss: 2.6630
Epoch 16 Average Loss: 2.0696
Epoch [17] Batch [0/24] Loss: 2.3109
Epoch [17] Batch [5/24] Loss: 1.2731
Epoch [17] Batch [10/24] Loss: 0.9292
Epoch [17] Batch [15/24] Loss: 2.5292
Epoch [17] Batch [20/24] Loss: 1.6114
Epoch 17 Average Loss: 1.9946
Epoch [18] Batch [0/24] Loss: 2.1993
Epoch [18] Batch [5/24] Loss: 2.2475
Epoch [18] Batch [10/24] Loss: 2.0214
Epoch [18] Batch [15/24] Loss: 1.9476
Epoch [18] Batch [20/24] Loss: 1.9298
Epoch 18 Average Loss: 1.8699
Epoch [19] Batch [0/24] Loss: 2.0948
Epoch [19] Batch [5/24] Loss: 2.0984
Epoch [19] Batch [10/24] Loss: 2.4701
Epoch [19] Batch [15/24] Loss: 2.6715
Epoch [19] Batch [20/24] Loss: 2.2770
Epoch 19 Average Loss: 1.8093
Epoch [20] Batch [0/24] Loss: 1.9324
Epoch [20] Batch [5/24] Loss: 2.2716
Epoch [20] Batch [10/24] Loss: 1.9846
Epoch [20] Batch [15/24] Loss: 1.7885
Epoch [20] Batch [20/24] Loss: 2.3410
Epoch 20 Average Loss: 1.7327
Epoch [21] Batch [0/24] Loss: 2.1571
Epoch [21] Batch [5/24] Loss: 1.7383
Epoch [21] Batch [10/24] Loss: 2.2885
Epoch [21] Batch [15/24] Loss: 1.5238
Epoch [21] Batch [20/24] Loss: 2.1028
Epoch 21 Average Loss: 1.7622
Epoch [22] Batch [0/24] Loss: 1.5601
Epoch [22] Batch [5/24] Loss: 1.4831
Epoch [22] Batch [10/24] Loss: 2.4486
Epoch [22] Batch [15/24] Loss: 1.4114
Epoch [22] Batch [20/24] Loss: 1.1914
Epoch 22 Average Loss: 1.6845
Epoch [23] Batch [0/24] Loss: 1.7578
Epoch [23] Batch [5/24] Loss: 1.6144
Epoch [23] Batch [10/24] Loss: 1.9124
Epoch [23] Batch [15/24] Loss: 1.9735
Epoch [23] Batch [20/24] Loss: 1.9652
Epoch 23 Average Loss: 1.5833
Epoch [24] Batch [0/24] Loss: 1.3146
Epoch [24] Batch [5/24] Loss: 1.3311
Epoch [24] Batch [10/24] Loss: 1.9232
Epoch [24] Batch [15/24] Loss: 1.8452
Epoch [24] Batch [20/24] Loss: 1.6920
Epoch 24 Average Loss: 1.6171
Epoch [25] Batch [0/24] Loss: 1.2637
Epoch [25] Batch [5/24] Loss: 1.8589
Epoch [25] Batch [10/24] Loss: 1.8562
Epoch [25] Batch [15/24] Loss: 1.7846
Epoch [25] Batch [20/24] Loss: 1.0790
Epoch 25 Average Loss: 1.5976
Epoch [26] Batch [0/24] Loss: 2.0882
Epoch [26] Batch [5/24] Loss: 1.6817
Epoch [26] Batch [10/24] Loss: 1.4747
Epoch [26] Batch [15/24] Loss: 1.2525
Epoch [26] Batch [20/24] Loss: 0.7680
Epoch 26 Average Loss: 1.5485
Epoch [27] Batch [0/24] Loss: 1.6124
Epoch [27] Batch [5/24] Loss: 0.6368
Epoch [27] Batch [10/24] Loss: 1.2857
Epoch [27] Batch [15/24] Loss: 2.3080
Epoch [27] Batch [20/24] Loss: 1.5885
Epoch 27 Average Loss: 1.4144
Epoch [28] Batch [0/24] Loss: 1.5449
Epoch [28] Batch [5/24] Loss: 1.9299
Epoch [28] Batch [10/24] Loss: 1.0669
Epoch [28] Batch [15/24] Loss: 0.9294
Epoch [28] Batch [20/24] Loss: 1.6587
Epoch 28 Average Loss: 1.4043
Epoch [29] Batch [0/24] Loss: 1.5119
Epoch [29] Batch [5/24] Loss: 0.9596
Epoch [29] Batch [10/24] Loss: 0.7105
Epoch [29] Batch [15/24] Loss: 1.7968
Epoch [29] Batch [20/24] Loss: 1.9556
Epoch 29 Average Loss: 1.3969
6.3 val_detrac.py
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets.detrac_loader import DetracImageDataset
from models.videomamba import videomamba_tiny

def inference():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # --- 1. Paths (must match training) ---
    BASE = "/home/chenchenyang/pro26/VideoMamba-main"
    test_csv = f"{BASE}/data/UA-DETRAC/annotations/train.csv"  # test on the train set first to check for overfitting
    img_prefix = f"{BASE}/data/UA-DETRAC/DETRAC-Images"
    # Load the epoch-29 weights just produced by training
    model_path = f"{BASE}/outputs/detrac_checkpoints/mamba_detrac_epoch_29.pth"
    # --- 2. Data ---
    # mode='val' uses center sampling, which gives more stable results
    dataset = DetracImageDataset(test_csv, prefix=img_prefix, mode='val')
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    # Inverse mapping table (ID -> original label)
    inv_label_map = {v: k for k, v in dataset.label_map.items()}
    num_classes = len(dataset.label_map)
    # --- 3. Initialize the model and load the weights ---
    model = videomamba_tiny(num_classes=num_classes)
    if os.path.exists(model_path):
        print(f"Loading trained model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
    else:
        print("Model file not found!")
        return
    model.to(device)
    model.eval()
    # --- 4. Inference ---
    correct = 0
    total = 0
    print(f"\n{'Sample Index':<15} | {'True Label':<12} | {'Pred Label':<12} | {'Status'}")
    print("-" * 60)
    with torch.no_grad():
        for i, (frames, label) in enumerate(loader):
            frames = frames.to(device)
            label = label.to(device)
            outputs = model(frames)
            prob = F.softmax(outputs, dim=1)
            pred = torch.argmax(prob, dim=1)
            true_origin = inv_label_map[label.item()]
            pred_origin = inv_label_map[pred.item()]
            status = "✅" if label == pred else "❌"
            if label == pred:
                correct += 1
            total += 1
            print(f"{i:<15} | {true_origin:<12} | {pred_origin:<12} | {status}")
    acc = 100 * correct / total
    print("-" * 60)
    print(f"Final Accuracy: {acc:.2f}% ({correct}/{total})")

if __name__ == "__main__":
    inference()
Console output:
(videomamba) chenchenyang@musk:~/pro26/VideoMamba-main/videomamba/video_sm$ python val_detrac.py
Dataset Loaded: 48 samples.
Label Mapping: {437: 0, 570: 1, 694: 2, 745: 3, 784: 4, 800: 5, 880: 6, 885: 7, 906: 8, 925: 9, 936: 10, 1070: 11, 1145: 12, 1150: 13, 1160: 14, 1185: 15, 1225: 16, 1265: 17, 1270: 18, 1285: 19, 1345: 20, 1405: 21, 1420: 22, 1490: 23, 1505: 24, 1600: 25, 1645: 26, 1690: 27, 1700: 28, 1720: 29, 1750: 30, 1765: 31, 1790: 32, 1820: 33, 1865: 34, 1875: 35, 1995: 36, 2025: 37, 2055: 38, 2120: 39, 2195: 40, 2495: 41, 2635: 42} (Original -> Processed)
Use checkpoint: False
Checkpoint number: 0
Loading trained model from /home/chenchenyang/pro26/VideoMamba-main/outputs/detrac_checkpoints/mamba_detrac_epoch_29.pth
Sample Index | True Label | Pred Label | Status
------------------------------------------------------------
0 | 925 | 1225 | ❌
1 | 784 | 784 | ✅
2 | 1150 | 1405 | ❌
3 | 1070 | 1070 | ✅
4 | 800 | 800 | ✅
5 | 906 | 694 | ❌
6 | 1145 | 1145 | ✅
7 | 800 | 800 | ✅
8 | 1820 | 1820 | ✅
9 | 1420 | 1420 | ✅
10 | 1690 | 1690 | ✅
11 | 2635 | 2635 | ✅
12 | 1490 | 1765 | ❌
13 | 1185 | 1185 | ✅
14 | 1865 | 1865 | ✅
15 | 1150 | 1150 | ✅
16 | 800 | 800 | ✅
17 | 2195 | 2195 | ✅
18 | 1765 | 1765 | ✅
19 | 885 | 885 | ✅
20 | 1750 | 1750 | ✅
21 | 1285 | 1285 | ✅
22 | 800 | 800 | ✅
23 | 1225 | 1225 | ✅
24 | 1405 | 1405 | ✅
25 | 1720 | 1720 | ✅
26 | 1265 | 1345 | ❌
27 | 2495 | 2495 | ✅
28 | 1995 | 1995 | ✅
29 | 936 | 936 | ✅
30 | 1270 | 1270 | ✅
31 | 437 | 437 | ✅
32 | 1505 | 1505 | ✅
33 | 2120 | 2120 | ✅
34 | 1790 | 1690 | ❌
35 | 1600 | 1600 | ✅
36 | 2055 | 2055 | ✅
37 | 800 | 800 | ✅
38 | 1645 | 800 | ❌
39 | 880 | 880 | ✅
40 | 2025 | 2025 | ✅
41 | 745 | 745 | ✅
42 | 694 | 694 | ✅
43 | 1875 | 1820 | ❌
44 | 570 | 570 | ✅
45 | 1700 | 1700 | ✅
46 | 1160 | 1160 | ✅
47 | 1345 | 1345 | ✅
------------------------------------------------------------
Final Accuracy: 83.33% (40/48)
6.4 detrac_loader.py
import os
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd

class DetracImageDataset(Dataset):
    def __init__(self, anno_path, prefix='', split=' ', mode='train',
                 clip_len=8, frame_sample_rate=4, filename_tmpl='img{:05}.jpg'):
        self.anno_path = anno_path
        self.prefix = prefix
        self.split = split
        self.mode = mode
        self.clip_len = clip_len
        self.frame_sample_rate = frame_sample_rate
        self.filename_tmpl = filename_tmpl
        # Read the CSV.
        # Note: based on the earlier error, the delimiter may be a space or a comma;
        # pd.read_csv with sep=None sniffs it automatically.
        cleaned = pd.read_csv(self.anno_path, header=None, sep=None, engine='python')
        self.dataset_samples = list(cleaned.values[:, 0])
        # ⚠️ The CSV columns are [path, frames, label]; column 1 is the frame count.
        # Using it as the label is why the console above shows a 43-class mapping of
        # frame counts; to train on the 4 vehicle classes this should read column 2,
        # i.e. cleaned.values[:, 2].
        raw_labels = list(cleaned.values[:, 1])
        # --- Automatic label mapping: ensure labels lie in [0, num_classes-1] ---
        unique_labels = sorted(list(set(raw_labels)))
        self.label_map = {val: i for i, val in enumerate(unique_labels)}
        self.label_array = [self.label_map[l] for l in raw_labels]
        print(f"Dataset Loaded: {len(self.dataset_samples)} samples.")
        print(f"Label Mapping: {self.label_map} (Original -> Processed)")

    def __len__(self):
        return len(self.dataset_samples)

    def __getitem__(self, index):
        folder_name = str(self.dataset_samples[index])
        label = int(self.label_array[index])
        folder_path = os.path.join(self.prefix, folder_name)
        # 1. List the images
        try:
            all_imgs = sorted([img for img in os.listdir(folder_path) if img.endswith('.jpg')])
        except FileNotFoundError:
            print(f"Error: Folder not found {folder_path}")
            return torch.zeros(3, self.clip_len, 224, 224), 0
        total_frames = len(all_imgs)
        # 2. Frame-sampling logic
        converted_len = self.clip_len * self.frame_sample_rate
        if total_frames <= converted_len:
            indices = np.linspace(0, total_frames - 1, num=self.clip_len).astype(np.int64)
        else:
            if self.mode == 'train':
                start_idx = np.random.randint(0, total_frames - converted_len)
            else:
                start_idx = (total_frames - converted_len) // 2
            indices = np.arange(start_idx, start_idx + converted_len, self.frame_sample_rate)[:self.clip_len]
        # 3. Read and resize
        frames = []
        for i in indices:
            img_name = self.filename_tmpl.format(i + 1)
            img_path = os.path.join(folder_path, img_name)
            if not os.path.exists(img_path):
                img_path = os.path.join(folder_path, all_imgs[i])
            with Image.open(img_path) as img:
                img = img.convert('RGB').resize((224, 224))
                img_tensor = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
                frames.append(img_tensor)
        # 4. Stack into (C, T, H, W) and normalize
        frames = torch.stack(frames).permute(1, 0, 2, 3)
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)
        frames = (frames - mean) / std
        return frames, label
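A quick sanity check of the loader (hypothetical relative paths; adjust to your layout) to confirm it yields the expected 5D batch:
from torch.utils.data import DataLoader

ds = DetracImageDataset("data/UA-DETRAC/annotations/train.csv",
                        prefix="data/UA-DETRAC/DETRAC-Images", mode='train')
dl = DataLoader(ds, batch_size=2, shuffle=True)
frames, labels = next(iter(dl))
print(frames.shape, labels.shape)  # torch.Size([2, 3, 8, 224, 224]) torch.Size([2])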