Pytracking工程路径:https://github.com/visionml/pytracking
参考链接:https://github.com/visionml/pytracking/blob/master/ltr/README.md#ATOM
pytracking代码环境配置参考大佬的博客

一、训练模型

待训练数据集路径:修改/ltr/admin/local.py
注:根据ATOM原文的描述,训练ATOM共用到了LaSOT、TrackingNet、COCO等三个数据集,而代码新增了GOT-10k数据集

在这里插入图片描述

self.tensorboard_dir = self.workspace_dir + '/home1/users/huangbo/anaconda3/envs/pytracking/lib/python3.7/site-packages/tensorboard/'    # Directory for tensorboard files.
self.pretrained_networks = self.workspace_dir + '/pretrained_networks/'
self.lasot_dir = '/data3/publicData/Datasets/LaSOT/'
self.got10k_dir = '/data3/publicData/Datasets/GOT-10k/train/'
self.trackingnet_dir = '/data3/publicData/Datasets/TrackingNet/'
self.coco_dir = '/data3/publicData/Datasets/COCO/'   # default version is COCO2014

修改好数据集路径,运行python run_training.py bbreg atom,即可正常训练
在这里插入图片描述

二、使用新数据集进行训练

如果是使用自己的训练集进行训练的话,主要是仿照ltr/dataset底下的lasot.py等文件建立一个自己的文件YourDataset.py,并在/ltr/dataset/__init__.py中声明

from .lasot import Lasot
from .got10k import Got10k
from .tracking_net import TrackingNet
from .YourDataset import YourDataset

之后修改ltr/train_settings/bbreg/atom.py(如果是DiMP,KYS则到train_settings文件底下修改对应文件)

    # Train datasets  训练集路径
    # lasot_train = Lasot(settings.env.lasot_dir, split='train')
    # got10k_train = Got10k(settings.env.got10k_dir, split='vottrain')
    # trackingnet_train = TrackingNet(settings.env.trackingnet_dir, set_ids=list(range(4)))
    # coco_train = MSCOCOSeq(settings.env.coco_dir)
    yourdataset_train = YourDataset(settings.env.yourdataset_dir, split='train')

    # Validation datasets 验证集路径
    yourdataset_val = YourDataset(settings.env.yourdataset_dir, split='val')
    # got10k_val = Got10k(settings.env.got10k_dir, split='votval')

    # The joint augmentation transform, that is applied to the pairs jointly
    transform_joint = tfm.Transform(tfm.ToGrayscale(probability=0.05))

    # The augmentation transform applied to the training set (individually to each image in the pair)
    transform_train = tfm.Transform(tfm.ToTensorAndJitter(0.2),
                                    tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std))

    # The augmentation transform applied to the validation set (individually to each image in the pair)
    transform_val = tfm.Transform(tfm.ToTensor(),
                                  tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std))

    # Data processing to do on the training pairs
    proposal_params = {'min_iou': 0.1, 'boxes_per_frame': 16, 'sigma_factor': [0.01, 0.05, 0.1, 0.2, 0.3]}
    data_processing_train = processing.ATOMProcessing(search_area_factor=settings.search_area_factor,
                                                      output_sz=settings.output_sz,
                                                      center_jitter_factor=settings.center_jitter_factor,
                                                      scale_jitter_factor=settings.scale_jitter_factor,
                                                      mode='sequence',
                                                      proposal_params=proposal_params,
                                                      transform=transform_train,
                                                      joint_transform=transform_joint)

    # Data processing to do on the validation pairs
    data_processing_val = processing.ATOMProcessing(search_area_factor=settings.search_area_factor,
                                                    output_sz=settings.output_sz,
                                                    center_jitter_factor=settings.center_jitter_factor,
                                                    scale_jitter_factor=settings.scale_jitter_factor,
                                                    mode='sequence',
                                                    proposal_params=proposal_params,
                                                    transform=transform_val,
                                                    joint_transform=transform_joint)

    # The sampler for training 如果有多个数据集 第二个参数[1]改成[1,1,1,...]
    dataset_train = sampler.ATOMSampler([yourdataset_train], [1],
                                        samples_per_epoch=1000*settings.batch_size, max_gap=50, processing=data_processing_train)
    # dataset_train = sampler.ATOMSampler([lasot_train, got10k_train, trackingnet_train, coco_train], [1,1,1,1], samples_per_epoch=1000*settings.batch_size, max_gap=50, processing=data_processing_train)
    # The loader for training
    loader_train = LTRLoader('train', dataset_train, training=True, batch_size=settings.batch_size, num_workers=settings.num_workers,
                             shuffle=True, drop_last=True, stack_dim=1)

    # The sampler for validation
    dataset_val = sampler.ATOMSampler([yourdataset_val], [1], samples_per_epoch=500*settings.batch_size, max_gap=50,
                                      processing=data_processing_val)

    # The loader for validation
    loader_val = LTRLoader('val', dataset_val, training=False, batch_size=settings.batch_size, num_workers=settings.num_workers,
                           shuffle=False, drop_last=True, epoch_interval=5, stack_dim=1)

    # Create network and actor 默认训练的模型为resnet18,可以修改成50
    net = atom_models.atom_resnet18(backbone_pretrained=True)
    objective = nn.MSELoss()
    actor = actors.AtomActor(net=net, objective=objective)

之后运行python run_training.py bbreg atom,即可正常训练

KeepTrack跟踪代码训练

KeepTrack是ICCV2021刚中的论文,与pytracking系列其他跟踪器的训练不同之处在于,KeepTrack需要对数据集重新清洗一下。KeepTrack设计了一个target candidate association network来关联不同帧之间的背景干扰物,从而实现抑制背景干扰物的作用。因此KeepTrack从Lasot数据集中挑选了部分存在大量干扰物的视频序列用于训练,而base tracker是直接调用super_DiMP模型。
进入到pytracking/util_scripts目录,运行python create_distractor_dataset.py dimp_simple super_dimp_simple lasot_train ./
最终生成target_candidates_dataset_dimp_simple_super_dimp_simple.json文件,内容如下:
在这里插入图片描述
相当是调用Super_DiMP模型跑了一遍数据集,挑出合适的帧来构建训练集。
训练
ltr/admin/local.py中修改self.lasot_candidate_matching_dataset_path=‘Path to target_candidates_dataset_dimp_simple_super_dimp_simple.json’,开始训练

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐