PyTorch学习笔记:使用state_dict来保存和加载模型
1. state_dict简介
state_dict是Python的字典对象,可用于保存模型参数、超参数以及优化器(torch.optim)的状态信息。需要注意的是,只有具有可学习参数的层(如卷积层、线性层等)才有state_dict。
下面就拿官方教程中的一个小示例来说明state_dict的使用:
import torch
import torch.nn as nn
import torch.optim as optim
# 定义模型
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 初始化模型
model = TheModelClass()
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 打印模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 打印优化器的状态字典
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
让我们来运行一下以上代码:
从以上代码及运行结果可知,state_dict将模型的每一层映射到一个参数张量。在Python中,可以对state_dict进行保存、加载、更新、修改等操作。
下面我们就来看一下PyTorch如何通过state_dict来保存和加载模型。
2. 保存和加载state_dict
可以通过torch.save()来保存模型的state_dict,即只保存学习到的模型参数,并通过load_state_dict()来加载并恢复模型参数。PyTorch中最常见的模型保存扩展名为'.pt'或'.pth'。
下面我们就将上个例子中构造的简单模型TheModelClass的参数保存在state_dict,然后通过load_state_dict()来加载模型参数。
......
# 将模型保存到当前路径,名称为test_state_dict.pth
PATH = './test_state_dict.pth'
torch.save(model.state_dict(), PATH)
model = TheModelClass() # 首先通过代码获取模型结构
model.load_state_dict(torch.load(PATH)) # 然后加载模型的state_dict
model.eval()
注意:load_state_dict()函数只接受字典对象,不可直接传入模型路径,所以需要先使用torch.load()反序列化已保存的state_dict。
另外,在使用模型做推理之前,需要调用model.eval()函数将dropout和batch normalization层设置为评估模式,否则会导致模型推理结果不一致。
当然,除了保存state_dict,PyTorch还支持保存和加载整个模型。
3. 保存和加载完整模型
保存和加载整个模型的代码如下:
# 保存完整模型
torch.save(model, PATH)
# 加载完整模型
model = torch.load(PATH)
model.eval()
这种方式虽然代码看起来较state_dict方式要简洁,但是灵活性会差一些。因为torch.save()函数使用Python的pickle
模块进行序列化,但pickle无法保存模型本身,而是保存包含类的文件路径,该文件会在模型加载时使用。所以当在其他项目对模型进行重构之后,就可能会出现意想不到的错误。
4. 保存和加载checkpoint用于继续训练或推理
除了以上两种保存模型的方式,PyTorch还支持以checkpoint方式保存模型训练的中间结果,以实现模型的继续训练或者推理。这种方式下,保存的内容不仅包含模型的state_dict,还会保存优化器的state_dict,以及其他参数如loss、epoch等。
保存:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
加载:
model = TheModelClass()
optimizer = TheOptimizerClass()
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
model.train()
checkpoint在PyTorch中常保存为.tar的文件扩展名。
注:以上checkpoint保存和加载的代码未经本人测试。
5. 迁移学习下的热启动模式
我们在工程中,常常用到迁移学习,利用训练好的模型在新的数据集上进行迁移训练,可达到使用少量数据进行快速训练的目的。
在迁移学习中,我们常常需要对预训练模型进行部分加载的需要,这个时候我们就要用到热启动模式,可通过在load_state_dict()函数中将strict参数设置为False来忽略非匹配键的参数。
# 保存模型state_dict
torch.save(modelA.state_dict(), PATH)
# 热加载模型
modelB = TheModelBClass()
modelB.load_state_dict(torch.load(PATH), strict=False)
更多推荐
所有评论(0)