动手学深度学习——PyTorch 基础:读写文件详解
一、前言
前面学习 PyTorch 时,我们已经逐步掌握了很多核心内容,比如:
-
如何构造模型
-
如何管理参数
-
如何自定义层
-
如何让模型完成前向传播
但当真正开始训练模型时,很快就会遇到一个非常实际的问题:
模型辛辛苦苦训练了很久,训练完之后怎么保存?
下次还想继续用,怎么重新加载?
不只是模型,普通张量、参数、结果文件又该怎么存?
这就是“读写文件”这一节要解决的问题。
这一部分内容非常重要,因为在真实项目里,我们几乎不可能每次都从头重新训练模型。
很多时候我们都需要:
-
保存训练好的参数
-
下次直接加载继续用
-
保存中间实验结果
-
保存数据张量
-
保存模型状态,避免训练中断后全部白费
所以这一节虽然不像卷积层那样“看起来很炫”,但它属于非常实用、非常高频的基础能力。
二、什么是读写文件
所谓“读写文件”,简单来说就是两件事:
1. 写文件
把程序中的数据保存到磁盘中。
2. 读文件
把磁盘中的数据重新加载回程序中。
在 PyTorch 语境下,常见需要保存的内容包括:
-
张量
Tensor -
模型参数
-
优化器状态
-
训练中间结果
-
字典、列表等 Python 对象
所以这一节的核心问题可以概括成一句话:
如何把训练得到的重要内容保存下来,并在需要时重新恢复。
三、为什么读写文件很重要
1. 模型训练成本高
有些模型训练一次要花很长时间。
如果不保存,程序一关、环境一断,结果就没了。
2. 需要复现实验
很多时候我们希望:
-
记录某次训练得到的最好模型
-
对比不同参数下的实验结果
-
重新加载最佳模型进行测试
这都离不开文件读写。
3. 部署和推理都需要加载模型
真实项目中,模型通常是“训练一次,使用很多次”。
训练完后必须把结果保存起来,部署时再加载。
4. 断点续训很常见
如果训练一半中断,下次继续训练,也需要读取之前保存的状态。
四、PyTorch 中最常用的两个函数
在 PyTorch 中,最常见的文件读写函数是:
torch.save()
torch.load()
你可以先把它们理解成:
-
torch.save():保存对象 -
torch.load():加载对象
这两个函数几乎贯穿整个 PyTorch 实战过程。
五、先从最简单的:保存和读取张量
最基础的情况,就是直接保存一个张量。
例如:
import torch
x = torch.arange(4)
torch.save(x, 'x-file')
这里做的事情很简单:
-
创建一个张量
x -
用
torch.save()把它保存到当前目录下名为x-file的文件中
然后我们就可以把它再读回来:
x2 = torch.load('x-file')
print(x2)
输出就会是原来的张量内容。
六、torch.save() 和 torch.load() 怎么理解
1. torch.save(obj, path)
表示把对象 obj 保存到 path 对应的文件中。
2. torch.load(path)
表示从 path 对应文件中读回之前保存的对象。
你可以把它理解成:
save是“打包存起来”,load是“拆包读出来”。
而且这里保存的不一定只是张量,也可以是更复杂的数据结构。
七、可以保存多个张量吗
当然可以。
如果要保存多个张量,最常见的做法是把它们放进一个列表或者字典里一起保存。
例如保存一个列表:
x = torch.arange(4)
y = torch.zeros(4)
torch.save([x, y], 'x-files')
然后加载:
x2, y2 = torch.load('x-files')
print(x2)
print(y2)
这样就能一次保存和恢复多个张量。
八、为什么常用字典保存而不是列表
虽然列表可以保存多个对象,但实际项目里更常用字典,因为:
-
更清晰
-
更方便按名字取
-
不容易搞混顺序
例如:
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
读取时:
mydict2 = torch.load('mydict')
print(mydict2['x'])
print(mydict2['y'])
这种方式很适合保存:
-
训练损失
-
验证结果
-
参数记录
-
多组实验输出
所以在真实项目中,字典是非常常见的保存格式。
九、模型能不能直接保存
可以,但这里要区分两种思路:
1. 保存整个模型对象
2. 保存模型参数
虽然两种方式都能实现,但在 PyTorch 中,更推荐保存模型参数,而不是整个模型对象。
这是这一节最重要的一个点。
十、为什么更推荐保存模型参数
因为直接保存整个模型对象虽然方便,但存在一些问题:
1. 依赖模型类定义
加载时必须保证当时定义模型的代码结构没有变化。
2. 可移植性差
不同环境、不同文件组织方式下,容易出问题。
3. 不够灵活
有时我们只想加载参数到一个同结构模型中,而不是整个对象原样恢复。
所以 PyTorch 更常见、更规范的做法是:
保存
state_dict(),也就是模型参数字典。
十一、什么是 state_dict()
前面“参数管理”已经接触过它,这里再强化一下。
state_dict() 可以理解成:
模型当前参数状态的字典。
例如:
from torch import nn
net = nn.Sequential(nn.Linear(2, 2))
print(net.state_dict())
你会看到类似:
-
权重
weight -
偏置
bias
它们都被组织在一个字典中。
所以保存模型最常见的方式就是:
torch.save(net.state_dict(), 'mlp.params')
十二、怎么加载保存好的模型参数
加载参数的流程通常分成三步:
第一步:先重新定义同样结构的模型
clone = nn.Sequential(nn.Linear(2, 2))
第二步:读取保存的参数文件
clone.load_state_dict(torch.load('mlp.params'))
第三步:使用模型
clone.eval()
其中:
-
load_state_dict()用来把读进来的参数装载到模型中 -
eval()表示切换到推理/测试模式
完整示例:
import torch
from torch import nn
net = nn.Sequential(nn.Linear(2, 2))
torch.save(net.state_dict(), 'mlp.params')
clone = nn.Sequential(nn.Linear(2, 2))
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
十三、为什么加载前要先定义同结构模型
这是很多初学者第一次会困惑的地方。
因为我们保存的是:
参数
而不是:
模型结构本身
所以加载时,PyTorch 需要先知道:
-
这些参数应该放进哪个模型
-
每个参数对应哪一层
-
对应张量形状是什么
因此必须先有一个同样结构的模型“壳子”,然后再把参数装进去。
可以把它理解成:
-
模型结构 = 房子的框架
-
参数 = 家具和物品
你保存的是家具,但加载时先得有房子,才能把家具摆进去。
十四、eval() 为什么重要
模型加载完成后,经常会写一句:
clone.eval()
这表示把模型切换到评估模式。
为什么要这样做?
因为某些层在训练模式和测试模式下行为不一样,比如:
-
Dropout -
BatchNorm
训练时它们会有随机性或统计行为,测试时则需要稳定输出。
因此:
-
训练时:
model.train() -
推理/测试时:
model.eval()
虽然在最简单的 Linear 模型里不一定看出差别,但养成这个习惯很重要。
十五、保存整个模型对象怎么做
虽然更推荐保存 state_dict(),但你也可以直接保存整个模型:
torch.save(net, 'entire-model')
加载时:
net2 = torch.load('entire-model')
这样写看起来很方便,但前面说过,它依赖环境和类定义,不够稳健。
所以一般来说:
-
学习演示可以知道这种方式
-
真正项目里更推荐保存参数字典
十六、实际项目中通常保存什么
在真实训练任务中,往往不只是保存模型参数,还会保存一个更完整的字典。
例如:
torch.save({
'epoch': 10,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': 0.123
}, 'checkpoint.pth')
这叫做一个 checkpoint(检查点)。
它通常包含:
-
当前训练到第几轮
-
模型参数
-
优化器状态
-
当前损失值
这样下次不只是能加载模型,还能继续训练。
十七、为什么优化器状态也要保存
很多人一开始只想到保存模型参数,但在继续训练时,优化器状态也很重要。
因为像 Adam、Momentum SGD 这类优化器,不只是简单记录一个学习率,它们内部还维护了历史状态,例如:
-
动量
-
二阶矩估计
-
历史梯度信息
如果不保存优化器状态,那么继续训练时相当于“优化器重新开始”,效果可能和真正断点续训不同。
所以如果只是做推理,可以只保存模型参数;
如果要恢复训练,最好连优化器状态一起保存。
十八、如何读取 checkpoint
例如前面保存了:
torch.save({
'epoch': 10,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': 0.123
}, 'checkpoint.pth')
那么加载时通常这样写:
checkpoint = torch.load('checkpoint.pth')
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
这样就能把训练状态比较完整地恢复出来。
十九、文件后缀一定要 .pth 吗
不一定。
PyTorch 对文件名后缀没有强制要求,你写成:
-
.pt -
.pth -
.params -
甚至没有后缀
技术上都可以。
但约定俗成中,常见的是:
-
.pth -
.pt
例如:
-
model.pth -
checkpoint.pt
这样别人一看就知道这是 PyTorch 保存的模型文件。
二十、保存路径和当前目录要注意什么
这一点虽然简单,但很实用。
例如:
torch.save(x, 'x-file')
这里表示保存在当前工作目录下。
如果你不确定文件存到哪里了,可以:
-
看运行脚本所在目录
-
打印当前工作路径
-
使用绝对路径保存
例如:
import os
print(os.getcwd())
这在调试“为什么找不到保存文件”时很有帮助。
二十一、一个适合 CSDN 展示的完整示例
下面给你一份适合博客使用的完整示例代码,基本覆盖本节核心内容。
import torch
from torch import nn
# 1. 保存和加载张量
x = torch.arange(4)
torch.save(x, 'x-file')
x2 = torch.load('x-file')
print("加载后的张量 x2:", x2)
# 2. 保存和加载多个张量
y = torch.zeros(4)
torch.save({'x': x, 'y': y}, 'xy-dict')
data = torch.load('xy-dict')
print("data['x'] =", data['x'])
print("data['y'] =", data['y'])
# 3. 定义模型
net = nn.Sequential(nn.Linear(2, 2))
# 4. 保存模型参数
torch.save(net.state_dict(), 'mlp.params')
# 5. 加载模型参数
clone = nn.Sequential(nn.Linear(2, 2))
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
print("原模型参数:")
print(net.state_dict())
print("加载后的模型参数:")
print(clone.state_dict())
这份代码很适合你放在 CSDN 里讲解,因为它把:
-
张量保存
-
字典保存
-
模型参数保存与加载
串到一起了。
二十二、这一节最容易混淆的几个点
1. 保存模型和保存参数不是一回事
-
保存模型:把整个对象存起来
-
保存参数:只保存
state_dict()
通常更推荐后者。
2. 加载参数前必须先定义模型结构
因为参数本身不知道应该装进哪个模型。
3. eval() 不是“训练完成”的意思
它表示切换到评估模式,影响某些层的行为。
4. 推理和断点续训保存内容不同
-
推理:保存模型参数就够了
-
续训:最好连优化器状态一起保存
5. torch.load() 返回的就是之前保存的对象
如果之前保存的是字典,读出来还是字典;
如果保存的是张量,读出来就是张量。
二十三、这一节的核心思想
如果把“读写文件”这一节压缩成一句话,我觉得最核心的是:
训练结果必须被持久化保存,否则模型训练成果无法复用。
前面我们一直在学习:
-
如何定义模型
-
如何管理参数
-
如何做前向传播
而这一节则让这些工作真正具备“落地价值”。
因为只有学会保存和加载,我们才能:
-
复用训练成果
-
做实验对比
-
实现部署推理
-
实现断点续训
这一步其实是从“会写模型”走向“会做项目”的关键环节。
二十四、我对这一节的理解
学这一节之前,我对 PyTorch 的理解还更多停留在“代码运行时”的层面,感觉模型只存在于程序运行过程中。
但学完“读写文件”之后,我会更清楚地意识到:
模型不仅是运行时对象,它还应该是可以被保存、恢复、迁移和复用的实验资产。
这其实很重要。
因为真实项目里,我们不可能每次都从头开始。一个训练好的模型,往往本身就是最有价值的成果之一。
所以这一节虽然语法不难,但在实际工作流中非常关键。
二十五、结语
“读写文件”是 PyTorch 基础中非常实用的一节。
它不像前面的模型构造那样偏向“定义网络”,而是更贴近真实任务中的模型管理流程。
通过这一节,我们要真正掌握的是:
-
如何保存张量
-
如何保存多个对象
-
如何保存模型参数
-
如何重新加载模型
-
为什么推荐保存
state_dict() -
如何保存 checkpoint 用于断点续训
如果把这些内容弄明白,后面无论是做课程作业、比赛项目,还是更复杂的训练任务,都会顺手很多。
二十六、重点速记版
1. PyTorch 中最常用的文件读写函数是什么
torch.save() 和 torch.load()
2. 可以保存哪些内容
张量、列表、字典、模型参数、优化器状态等。
3. 保存模型时推荐什么方式
推荐保存 state_dict(),也就是模型参数字典。
4. 加载模型参数前要做什么
先定义同样结构的模型。
5. 为什么加载后常写 eval()
切换到评估模式,保证推理行为正确。
6. checkpoint 一般保存什么
模型参数、优化器状态、epoch、loss 等。
7. 只做推理和继续训练有什么区别
推理只需模型参数;继续训练最好连优化器状态一起保存。
以上就是我对《动手学深度学习》中 PyTorch 基础——读写文件 这一节的学习整理。
这一节虽然不涉及复杂模型结构,但它在真实项目中非常重要,因为它直接决定了我们能不能把训练成果保存下来并重复利用。
对于刚开始学 PyTorch 的同学来说,这一节一定要自己动手敲一遍。尤其是“保存 state_dict() 再加载”的流程,后面做项目时几乎一定会反复用到。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)