Object Tracking: Training the Pytracking Series (ATOM/DiMP/PrDiMP/KYS/Super_DiMP/KeepTrack)
Pytracking repository: https://github.com/visionml/pytracking
Reference: https://github.com/visionml/pytracking/blob/master/ltr/README.md#ATOM
For setting up the pytracking environment, refer to an experienced blogger's post.
I. Training a model
Training dataset paths: edit ltr/admin/local.py
Note: according to the ATOM paper, ATOM was trained on three datasets (LaSOT, TrackingNet and COCO); the code additionally uses GOT-10k.
self.tensorboard_dir = self.workspace_dir + '/tensorboard/'  # Directory for tensorboard log files (not the path of the installed tensorboard package)
self.pretrained_networks = self.workspace_dir + '/pretrained_networks/'
self.lasot_dir = '/data3/publicData/Datasets/LaSOT/'
self.got10k_dir = '/data3/publicData/Datasets/GOT-10k/train/'
self.trackingnet_dir = '/data3/publicData/Datasets/TrackingNet/'
self.coco_dir = '/data3/publicData/Datasets/COCO/' # default version is COCO2014
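Before launching training, it can help to check that every directory configured in local.py actually exists. The sketch below is a hypothetical helper, not part of pytracking; it takes a plain dict of setting name to path, so you would copy in the values from your own local.py.

```python
import os


def find_missing_dirs(paths):
    """Return the names of configured dataset directories that do not exist.

    `paths` maps a setting name (e.g. 'lasot_dir') to a filesystem path.
    Empty strings are treated as "not configured" and skipped.
    """
    missing = []
    for name, path in paths.items():
        if path and not os.path.isdir(path):
            missing.append(name)
    return missing


if __name__ == '__main__':
    # Values copied from ltr/admin/local.py; adjust to your machine.
    dataset_dirs = {
        'lasot_dir': '/data3/publicData/Datasets/LaSOT/',
        'got10k_dir': '/data3/publicData/Datasets/GOT-10k/train/',
        'trackingnet_dir': '/data3/publicData/Datasets/TrackingNet/',
        'coco_dir': '/data3/publicData/Datasets/COCO/',
    }
    for name in find_missing_dirs(dataset_dirs):
        print(f'missing: {name} -> {dataset_dirs[name]}')
```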
Once the dataset paths are set, run python run_training.py bbreg atom from the ltr directory, and training should start normally.
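II. Training on a new dataset
To train on your own dataset, create your own dataset file YourDataset.py under ltr/dataset, modeled on existing files such as lasot.py, and register it in ltr/dataset/__init__.py. Because the training settings below read settings.env.yourdataset_dir, also add a self.yourdataset_dir entry pointing to your dataset root in ltr/admin/local.py.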
from .lasot import Lasot
from .got10k import Got10k
from .tracking_net import TrackingNet
from .YourDataset import YourDataset
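As a starting point for YourDataset.py, here is a minimal, framework-free sketch of the annotation-handling part. It assumes a hypothetical layout loosely modeled on LaSOT: root/<split>/<sequence>/groundtruth.txt with one "x,y,w,h" box per line, and frames stored as root/<split>/<sequence>/img/00000001.jpg. The method names get_name, get_num_sequences and get_sequence_info follow what pytracking's existing dataset classes expose, but a real implementation should subclass the base dataset class in ltr/dataset and return torch tensors; get_frame_paths is a hypothetical helper that returns paths instead of decoded images.

```python
import os


class YourDataset:
    """Sketch of a custom tracking dataset (assumed, hypothetical layout).

    root/<split>/<sequence>/groundtruth.txt   one 'x,y,w,h' line per frame
    root/<split>/<sequence>/img/00000001.jpg  frames, 1-indexed, 8 digits
    Plain lists are used instead of torch tensors to stay self-contained.
    """

    def __init__(self, root, split='train'):
        self.root = os.path.join(root, split)
        self.sequence_list = sorted(
            d for d in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, d)))

    def get_name(self):
        return 'yourdataset'

    def get_num_sequences(self):
        return len(self.sequence_list)

    def _read_bb_anno(self, seq_path):
        # Parse one 'x,y,w,h' box per line; blank lines are ignored.
        boxes = []
        with open(os.path.join(seq_path, 'groundtruth.txt')) as f:
            for line in f:
                line = line.strip()
                if line:
                    boxes.append([float(v) for v in line.split(',')])
        return boxes

    def get_sequence_info(self, seq_id):
        seq_path = os.path.join(self.root, self.sequence_list[seq_id])
        bbox = self._read_bb_anno(seq_path)
        # A frame is usable only if its box has positive width and height.
        valid = [w > 0 and h > 0 for _, _, w, h in bbox]
        return {'bbox': bbox, 'valid': valid, 'visible': list(valid)}

    def get_frame_paths(self, seq_id, frame_ids):
        # Hypothetical helper: 0-based frame ids -> 1-indexed image paths.
        seq_path = os.path.join(self.root, self.sequence_list[seq_id])
        return [os.path.join(seq_path, 'img', f'{i + 1:08d}.jpg')
                for i in frame_ids]
```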
Next, modify ltr/train_settings/bbreg/atom.py (for DiMP or KYS, modify the corresponding file under ltr/train_settings instead):
# Train datasets
# lasot_train = Lasot(settings.env.lasot_dir, split='train')
# got10k_train = Got10k(settings.env.got10k_dir, split='vottrain')
# trackingnet_train = TrackingNet(settings.env.trackingnet_dir, set_ids=list(range(4)))
# coco_train = MSCOCOSeq(settings.env.coco_dir)
yourdataset_train = YourDataset(settings.env.yourdataset_dir, split='train')
# Validation datasets
yourdataset_val = YourDataset(settings.env.yourdataset_dir, split='val')
# got10k_val = Got10k(settings.env.got10k_dir, split='votval')
# The joint augmentation transform, that is applied to the pairs jointly
transform_joint = tfm.Transform(tfm.ToGrayscale(probability=0.05))
# The augmentation transform applied to the training set (individually to each image in the pair)
transform_train = tfm.Transform(tfm.ToTensorAndJitter(0.2),
                                tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std))
# The augmentation transform applied to the validation set (individually to each image in the pair)
transform_val = tfm.Transform(tfm.ToTensor(),
                              tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std))
# Data processing to do on the training pairs
proposal_params = {'min_iou': 0.1, 'boxes_per_frame': 16, 'sigma_factor': [0.01, 0.05, 0.1, 0.2, 0.3]}
data_processing_train = processing.ATOMProcessing(search_area_factor=settings.search_area_factor,
                                                  output_sz=settings.output_sz,
                                                  center_jitter_factor=settings.center_jitter_factor,
                                                  scale_jitter_factor=settings.scale_jitter_factor,
                                                  mode='sequence',
                                                  proposal_params=proposal_params,
                                                  transform=transform_train,
                                                  joint_transform=transform_joint)
# Data processing to do on the validation pairs
data_processing_val = processing.ATOMProcessing(search_area_factor=settings.search_area_factor,
                                                output_sz=settings.output_sz,
                                                center_jitter_factor=settings.center_jitter_factor,
                                                scale_jitter_factor=settings.scale_jitter_factor,
                                                mode='sequence',
                                                proposal_params=proposal_params,
                                                transform=transform_val,
                                                joint_transform=transform_joint)
# The sampler for training. With multiple datasets, change the weight list [1] to [1, 1, 1, ...] (one weight per dataset)
dataset_train = sampler.ATOMSampler([yourdataset_train], [1],
                                    samples_per_epoch=1000*settings.batch_size, max_gap=50, processing=data_processing_train)
# dataset_train = sampler.ATOMSampler([lasot_train, got10k_train, trackingnet_train, coco_train], [1,1,1,1], samples_per_epoch=1000*settings.batch_size, max_gap=50, processing=data_processing_train)
# The loader for training
loader_train = LTRLoader('train', dataset_train, training=True, batch_size=settings.batch_size, num_workers=settings.num_workers,
                         shuffle=True, drop_last=True, stack_dim=1)
# The sampler for validation
dataset_val = sampler.ATOMSampler([yourdataset_val], [1], samples_per_epoch=500*settings.batch_size, max_gap=50,
                                  processing=data_processing_val)
# The loader for validation
loader_val = LTRLoader('val', dataset_val, training=False, batch_size=settings.batch_size, num_workers=settings.num_workers,
                       shuffle=False, drop_last=True, epoch_interval=5, stack_dim=1)
# Create network and actor. The default backbone is ResNet-18; it can be changed to ResNet-50
net = atom_models.atom_resnet18(backbone_pretrained=True)
objective = nn.MSELoss()
actor = actors.AtomActor(net=net, objective=objective)
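The second argument of ATOMSampler is a list of relative sampling weights, one per dataset. As I understand pytracking's sampler, these weights are normalized into probabilities and a dataset is drawn for each training sample. The sketch below only mimics that idea with Python's random.choices; normalize_weights and pick_datasets are hypothetical names, not pytracking functions.

```python
import random
from collections import Counter


def normalize_weights(p_datasets):
    # Turn raw weights such as [1, 1, 1, 1] into probabilities that sum to 1.
    total = sum(p_datasets)
    return [p / total for p in p_datasets]


def pick_datasets(names, p_datasets, n, seed=0):
    # Choose a dataset for each of n samples, proportionally to its weight.
    rng = random.Random(seed)
    return rng.choices(names, weights=normalize_weights(p_datasets), k=n)


if __name__ == '__main__':
    names = ['lasot', 'got10k', 'trackingnet', 'coco']
    # Equal weights: each dataset should be drawn roughly 2500 times.
    print(Counter(pick_datasets(names, [1, 1, 1, 1], n=10000)))
    # Doubling LaSOT's weight makes it about twice as likely as each other set.
    print(Counter(pick_datasets(names, [2, 1, 1, 1], n=10000)))
```

With a single custom dataset the list is simply [1]; the weights only matter when several datasets are mixed.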
Then run python run_training.py bbreg atom to start training.
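The proposal_params in the config control how ATOM's IoU predictor is trained: ground-truth boxes are jittered with Gaussian noise to produce boxes_per_frame proposals, each proposal's IoU with the ground truth becomes the regression target, and nn.MSELoss compares the predicted IoU against it. The sketch below is a simplified, self-contained re-implementation of that idea under my own assumptions (noise scaled by sqrt(w*h) times a randomly chosen sigma_factor, sigma shrunk by 0.9 until IoU exceeds min_iou); it is not pytracking's actual proposal code.

```python
import math
import random


def iou(a, b):
    # Boxes are (x, y, w, h) with (x, y) the top-left corner.
    iw = max(0.0, min(a[0] + a[2], b[0] + b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[1] + a[3], b[1] + b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union > 0 else 0.0


def jitter_box(box, sigma_factors, min_iou, rng, max_tries=100):
    # Perturb centre and size with Gaussian noise scaled by sqrt(w * h);
    # shrink sigma and retry until the proposal overlaps the box enough.
    x, y, w, h = box
    cx, cy = x + w / 2, y + h / 2
    sigma = rng.choice(sigma_factors)
    for _ in range(max_tries):
        scale = math.sqrt(w * h) * sigma
        pw = max(1.0, rng.gauss(w, scale))
        ph = max(1.0, rng.gauss(h, scale))
        pcx, pcy = rng.gauss(cx, scale), rng.gauss(cy, scale)
        prop = (pcx - pw / 2, pcy - ph / 2, pw, ph)
        overlap = iou(box, prop)
        if overlap > min_iou:
            return prop, overlap
        sigma *= 0.9
    # Fallback: use the ground-truth box itself (IoU = 1).
    return tuple(box), 1.0


def generate_proposals(box, min_iou=0.1, boxes_per_frame=16,
                       sigma_factor=(0.01, 0.05, 0.1, 0.2, 0.3), seed=0):
    rng = random.Random(seed)
    pairs = [jitter_box(box, list(sigma_factor), min_iou, rng)
             for _ in range(boxes_per_frame)]
    proposals = [p for p, _ in pairs]
    targets = [t for _, t in pairs]  # IoU targets regressed with MSE loss
    return proposals, targets


if __name__ == '__main__':
    props, ious = generate_proposals((100.0, 80.0, 40.0, 60.0))
    for p, t in zip(props, ious):
        print([round(v, 1) for v in p], round(t, 3))
```

Small sigma values give near-perfect proposals and large ones give loosely overlapping boxes, so the predictor sees IoU targets spread across the range above min_iou.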
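III. Training KeepTrack
KeepTrack is a paper recently accepted at ICCV 2021. Unlike the other trackers in the pytracking series, KeepTrack requires the training data to be preprocessed first. KeepTrack introduces a target candidate association network that associates distractor objects across frames, which allows it to suppress background distractors. For training, KeepTrack therefore selects from LaSOT a subset of video sequences containing many distractors, and it uses the Super_DiMP model directly as its base tracker.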
Go to the pytracking/util_scripts directory and run python create_distractor_dataset.py dimp_simple super_dimp_simple lasot_train ./
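This generates the file target_candidates_dataset_dimp_simple_super_dimp_simple.json.
In effect, the script runs the Super_DiMP model over the whole dataset and picks out suitable frames to build the training set.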
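To make "target candidates" concrete: as I understand the KeepTrack paper, candidates are local maxima of the base tracker's score map whose score exceeds a threshold, and a frame with more than one candidate contains potential distractors. The sketch below illustrates that peak extraction in pure Python; find_candidates, threshold and radius are hypothetical names and values, and this is neither pytracking's implementation nor a description of the JSON file's format.

```python
def find_candidates(score_map, threshold=0.05, radius=1):
    """Return (row, col, score) tuples for local maxima of a 2D score map.

    A cell is kept if its score is at least `threshold` and no cell within
    `radius` (in both row and column) has a strictly larger score.
    Results are sorted from highest to lowest score.
    """
    rows, cols = len(score_map), len(score_map[0])
    peaks = []
    for r in range(rows):
        for c in range(cols):
            s = score_map[r][c]
            if s < threshold:
                continue
            neighbours = (
                score_map[rr][cc]
                for rr in range(max(0, r - radius), min(rows, r + radius + 1))
                for cc in range(max(0, c - radius), min(cols, c + radius + 1))
            )
            if all(v <= s for v in neighbours):
                peaks.append((r, c, s))
    peaks.sort(key=lambda p: p[2], reverse=True)
    return peaks


if __name__ == '__main__':
    score_map = [
        [0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.9, 0.1, 0.0, 0.0],
        [0.0, 0.1, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.4, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0],
    ]
    candidates = find_candidates(score_map)
    print(candidates)
    # More than one candidate means the frame contains potential distractors.
    print('distractor frame:', len(candidates) > 1)
```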
Training
In ltr/admin/local.py, set self.lasot_candidate_matching_dataset_path = 'Path to target_candidates_dataset_dimp_simple_super_dimp_simple.json', then start training.