pytorch训练中断后,如何在之前的断点处继续训练
·
我们在训练模型的时候经常出现各种问题导致训练中断,比方说断电,或者关机之类的导致电脑系统关闭,从而将模型训练中断,那么如何在模型中断后,能够保留之前的训练结果不被丢失,同时又可以继续之前的断点处继续训练?
首先在代码离需要保存模型,比方说我们模型设置训练5000轮,那么我们可以选择每100轮保存一次模型,这样的话,在训练的过程中就能保存下100,200,300.。。。等轮数时候的模型,那么当模型训练到400轮的时候突然训练中断,那么我们就可以通过加载400轮的参数来进行继续训练,其实这个过程就类似在预训练模型的基础上进行训练。下面简单粗暴上代码:
1、保存模型
torch.save(checkpoint, checkpoint_path)
其中checkpoint其实保存的就是模型的一些参数,比方说下面这种字典形式的保存所需的模型参数:
checkpoint = { 'model': model_state_dict, 'generator': generator_state_dict, 'opt': model_opt, 'optim': optim, }
checkpoint_path则是表示保存的模型
checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)
save_checkpoint_steps是保存的间隔轮数,step是保存的轮数,比方说save_checkpoint_steps=100,那么step的取值就是100,200,300,400等,下面的代码解释step的取值由来。
if step % self.save_checkpoint_steps != 0: return chkpt, chkpt_name = self._save(step)
其中_save函数就是实现了前面checkpoint的内容的保存。
模型的保存设置就此结束。
2、模型的加载
假如此时模型训练中断了,我们得在代码里设置一个参数,这个参数用来查找确定当前路径下是否有已存在得模型。
# 如果有保存的模型,则加载模型,并在其基础上继续训练
if os.path.exists(log_dir):
checkpoint = torch.load(log_dir)
model.load_state_dict(checkpoint['model'])
generator.load_state_dict(checkpoint['generator'])
start_epoch = checkpoint['model_opt']
optim=checkpoint['optim']
print('加载 epoch {} 成功!'.format(start_epoch))
else:
start_epoch = 0
print('无保存模型,将从头开始训练!')
或者设置一个变量train_from,若赋值已有模型得路径,则继续训练;若为None,那么从头训练。这块代码既可以用于训练中断,又可以用于使用预训练模型。
if opt.train_from:#是否存在预训练模型 logger.info('Loading checkpoint from %s' % opt.train_from) checkpoint = torch.load(opt.train_from)#加载预训练模型的检查点 model_opt = checkpoint['opt'] else: checkpoint = None model_opt = opt
加油,come on!
更多推荐
已为社区贡献5条内容
所有评论(0)