什么是MMPretrain

MMPretrain 是一个全新升级的预训练开源算法框架,旨在提供各种强大的预训练主干网络, 并支持了不同的预训练策略。MMPretrain 源自著名的开源项目 MMClassification 和 MMSelfSup,并开发了许多令人兴奋的新功能。 目前,预训练阶段对于视觉识别至关重要,凭借丰富而强大的预训练模型,我们能够改进各种下游视觉任务。

中文教程:欢迎来到 MMPretrain 中文教程! — MMPretrain 1.1.1 文档

github:https://github.com/open-mmlab/mmpretrain

如何安装MMPretrain

和MMDetection很像,我直接装在了MMDetection那个环境里面,MMDetection安装可参考:【学习】使用mmdetection,在自定义数据集上进行训练。-CSDN博客

官方教程很详细:依赖环境 — MMPretrain 1.1.1 文档

注意 -e 后面有个点。

mim install -e .

准备数据集

CUB-200-2011我是提前下载好的。

下载链接:Perona Lab - CUB-200-2011 (caltech.edu)

如何准备配置文件

配置文件基本不用自己写,都给写好了,“这你受得了吗!”

官方教程左侧

 选择模型库,找到SWIN-TRANSFORMER

 或者直接上链接:Swin-Transformer — MMPretrain 1.1.1 文档

 拉到底层 

 config里面,即配置文件。

也可以在源代码按照如下路径查找

一般先在根目录(或者某个夹子里),新建一个py文件,命名为swin-large_8xb8_cub-384px_test.py,前面是原config名称,最后加一个你的任务标记。

你新建的config文件里,需要根据位置,改一下_base_里的文件路径。

配置文件最上面 _base_ 是继承了文件中的配置,凡是这些文件里的配置,在这里都能重写,别在源文件里面改

_base_ = [
    '../_base_/models/swin_transformer/large_384.py',
    '../_base_/datasets/cub_bs8_384.py',
    '../_base_/schedules/cub_bs64.py',
    '../_base_/default_runtime.py',
]

# model settings
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/swin-transformer/convert/swin-large_3rdparty_in21k-384px.pth'  # noqa
model = dict(
    type='ImageClassifier',
    backbone=dict(
        init_cfg=dict(
            type='Pretrained', checkpoint=checkpoint, prefix='backbone')),
    head=dict(num_classes=200, ))

# schedule settings
optim_wrapper = dict(
    optimizer=dict(
        _delete_=True,
        type='AdamW',
        lr=5e-6,
        weight_decay=0.0005,
        eps=1e-8,
        betas=(0.9, 0.999)),
    paramwise_cfg=dict(
        norm_decay_mult=0.0,
        bias_decay_mult=0.0,
        custom_keys={
            '.absolute_pos_embed': dict(decay_mult=0.0),
            '.relative_position_bias_table': dict(decay_mult=0.0)
        }),
    clip_grad=dict(max_norm=5.0),
)

default_hooks = dict(
    # log every 20 intervals
    logger=dict(type='LoggerHook', interval=20),
    # save last three checkpoints
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))

我改了一下路径,你需要改什么,就重写一下就ok。

重写:把继承的4个.py文件找到,打开后看看你要改的参数在哪里,把那一小段代码(比如train_dataloader)复制到该配置文件,修改一下,不用改的删了就可以了。

train_dataloader = dict(
    batch_size=8,
    num_workers=2,
    dataset=dict(
        data_root='data/CUB-200-2011/CUB_200_2011',
        split='train'),
)

val_dataloader = dict(
    batch_size=8,
    num_workers=2,
    dataset=dict(
        data_root='data/CUB-200-2011/CUB_200_2011'),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))

train_cfg = dict(by_epoch=True, max_epochs=30, val_interval=10)

如何训练

配置文件做好后,即可开始训练。

python tools/train.py ${CONFIG_FILE} [ARGS]

${CONFIG_FILE} [ARGS]是配置文件的相对路径。 

如何测试

 python tools/test.py swin-large_8xb8_cub-384px_test.py work_dirs/swin-large_8xb8_cub-384px_test/epoch_30.pth --show

会进行可视化测试 

若想要画出 accuracy/top1或loss曲线

--keys 绘制参数,画loss则改为loss   --out 输出文件名、地址

 python tools/analysis_tools/analyze_logs.py plot_curve work_dirs/swin-large_8xb8_cub-384px_test/20231215_103600/vis_data/20231215_103600.json --keys accuracy/top1 --out accuracy1.pdf

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐