【学习】使用MMPretrain训练和测试CUB-200-2011数据集
什么是MMPretrain
MMPretrain 是一个全新升级的预训练开源算法框架,旨在提供各种强大的预训练主干网络, 并支持了不同的预训练策略。MMPretrain 源自著名的开源项目 MMClassification 和 MMSelfSup,并开发了许多令人兴奋的新功能。 目前,预训练阶段对于视觉识别至关重要,凭借丰富而强大的预训练模型,我们能够改进各种下游视觉任务。
中文教程:欢迎来到 MMPretrain 中文教程! — MMPretrain 1.1.1 文档
github:https://github.com/open-mmlab/mmpretrain
如何安装MMPretrain
和MMDetection很像,我直接装在了MMDetection那个环境里面,MMDetection安装可参考:【学习】使用mmdetection,在自定义数据集上进行训练。-CSDN博客
官方教程很详细:依赖环境 — MMPretrain 1.1.1 文档
注意 -e 后面有个点。
mim install -e .
准备数据集
CUB-200-2011我是提前下载好的。
下载链接:Perona Lab - CUB-200-2011 (caltech.edu)
如何准备配置文件
配置文件基本不用自己写,都给写好了,“这你受得了吗!”
官方教程左侧
选择模型库,找到SWIN-TRANSFORMER
或者直接上链接:Swin-Transformer — MMPretrain 1.1.1 文档
拉到底层
config里面,即配置文件。
也可以在源代码按照如下路径查找
一般先在根目录(或者某个夹子里),新建一个py文件,命名为swin-large_8xb8_cub-384px_test.py,前面是原config名称,最后加一个你的任务标记。
你新建的config文件里,需要根据位置,改一下_base_里的文件路径。
配置文件最上面 _base_ 是继承了文件中的配置,凡是这些文件里的配置,在这里都能重写,别在源文件里面改。
_base_ = [
'../_base_/models/swin_transformer/large_384.py',
'../_base_/datasets/cub_bs8_384.py',
'../_base_/schedules/cub_bs64.py',
'../_base_/default_runtime.py',
]
# model settings
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth' # noqa
model = dict(
type='ImageClassifier',
backbone=dict(
init_cfg=dict(
type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
head=dict(num_classes=200, ))
# schedule settings
optim_wrapper = dict(
optimizer=dict(
_delete_=True,
type='AdamW',
lr=5e-6,
weight_decay=0.0005,
eps=1e-8,
betas=(0.9, 0.999)),
paramwise_cfg=dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
custom_keys={
'.absolute_pos_embed': dict(decay_mult=0.0),
'.relative_position_bias_table': dict(decay_mult=0.0)
}),
clip_grad=dict(max_norm=5.0),
)
default_hooks = dict(
# log every 20 intervals
logger=dict(type='LoggerHook', interval=20),
# save last three checkpoints
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
我改了一下路径,你需要改什么,就重写一下就ok。
重写:把继承的4个.py文件找到,打开后看看你要改的参数在哪里,把那一小段代码(比如train_dataloader)复制到该配置文件,修改一下,不用改的删了就可以了。
train_dataloader = dict(
batch_size=8,
num_workers=2,
dataset=dict(
data_root='data/CUB-200-2011/CUB_200_2011',
split='train'),
)
val_dataloader = dict(
batch_size=8,
num_workers=2,
dataset=dict(
data_root='data/CUB-200-2011/CUB_200_2011'),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))
train_cfg = dict(by_epoch=True, max_epochs=30, val_interval=10)
如何训练
配置文件做好后,即可开始训练。
python tools/train.py ${CONFIG_FILE} [ARGS]
${CONFIG_FILE} [ARGS]是配置文件的相对路径。
如何测试
python tools/test.py swin-large_8xb8_cub-384px_test.py work_dirs/swin-large_8xb8_cub-384px_test/epoch_30.pth --show
会进行可视化测试
若想要画出 accuracy/top1或loss曲线
--keys 绘制参数,画loss则改为loss --out 输出文件名、地址
python tools/analysis_tools/analyze_logs.py plot_curve work_dirs/swin-large_8xb8_cub-384px_test/20231215_103600/vis_data/20231215_103600.json --keys accuracy/top1 --out accuracy1.pdf
更多推荐
所有评论(0)