一、前言

前面学习 PyTorch 时,我们已经逐步掌握了很多核心内容,比如:

  • 如何构造模型

  • 如何管理参数

  • 如何自定义层

  • 如何让模型完成前向传播

但当真正开始训练模型时,很快就会遇到一个非常实际的问题:

模型辛辛苦苦训练了很久,训练完之后怎么保存?
下次还想继续用,怎么重新加载?
不只是模型,普通张量、参数、结果文件又该怎么存?

这就是“读写文件”这一节要解决的问题。

这一部分内容非常重要,因为在真实项目里,我们几乎不可能每次都从头重新训练模型。
很多时候我们都需要:

  • 保存训练好的参数

  • 下次直接加载继续用

  • 保存中间实验结果

  • 保存数据张量

  • 保存模型状态,避免训练中断后全部白费

所以这一节虽然不像卷积层那样“看起来很炫”,但它属于非常实用、非常高频的基础能力。


二、什么是读写文件

所谓“读写文件”,简单来说就是两件事:

1. 写文件

把程序中的数据保存到磁盘中。

2. 读文件

把磁盘中的数据重新加载回程序中。

在 PyTorch 语境下,常见需要保存的内容包括:

  • 张量 Tensor

  • 模型参数

  • 优化器状态

  • 训练中间结果

  • 字典、列表等 Python 对象

所以这一节的核心问题可以概括成一句话:

如何把训练得到的重要内容保存下来,并在需要时重新恢复。


三、为什么读写文件很重要

1. 模型训练成本高

有些模型训练一次要花很长时间。
如果不保存,程序一关、环境一断,结果就没了。

2. 需要复现实验

很多时候我们希望:

  • 记录某次训练得到的最好模型

  • 对比不同参数下的实验结果

  • 重新加载最佳模型进行测试

这都离不开文件读写。

3. 部署和推理都需要加载模型

真实项目中,模型通常是“训练一次,使用很多次”。
训练完后必须把结果保存起来,部署时再加载。

4. 断点续训很常见

如果训练一半中断,下次继续训练,也需要读取之前保存的状态。


四、PyTorch 中最常用的两个函数

在 PyTorch 中,最常见的文件读写函数是:

torch.save()
torch.load()

你可以先把它们理解成:

  • torch.save():保存对象

  • torch.load():加载对象

这两个函数几乎贯穿整个 PyTorch 实战过程。


五、先从最简单的:保存和读取张量

最基础的情况,就是直接保存一个张量。

例如:

import torch

x = torch.arange(4)
torch.save(x, 'x-file')

这里做的事情很简单:

  • 创建一个张量 x

  • torch.save() 把它保存到当前目录下名为 x-file 的文件中

然后我们就可以把它再读回来:

x2 = torch.load('x-file')
print(x2)

输出就会是原来的张量内容。


六、torch.save()torch.load() 怎么理解

1. torch.save(obj, path)

表示把对象 obj 保存到 path 对应的文件中。

2. torch.load(path)

表示从 path 对应文件中读回之前保存的对象。

你可以把它理解成:

save 是“打包存起来”,
load 是“拆包读出来”。

而且这里保存的不一定只是张量,也可以是更复杂的数据结构。


七、可以保存多个张量吗

当然可以。

如果要保存多个张量,最常见的做法是把它们放进一个列表或者字典里一起保存。

例如保存一个列表:

x = torch.arange(4)
y = torch.zeros(4)
torch.save([x, y], 'x-files')

然后加载:

x2, y2 = torch.load('x-files')
print(x2)
print(y2)

这样就能一次保存和恢复多个张量。


八、为什么常用字典保存而不是列表

虽然列表可以保存多个对象,但实际项目里更常用字典,因为:

  • 更清晰

  • 更方便按名字取

  • 不容易搞混顺序

例如:

mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')

读取时:

mydict2 = torch.load('mydict')
print(mydict2['x'])
print(mydict2['y'])

这种方式很适合保存:

  • 训练损失

  • 验证结果

  • 参数记录

  • 多组实验输出

所以在真实项目中,字典是非常常见的保存格式


九、模型能不能直接保存

可以,但这里要区分两种思路:

1. 保存整个模型对象

2. 保存模型参数

虽然两种方式都能实现,但在 PyTorch 中,更推荐保存模型参数,而不是整个模型对象。

这是这一节最重要的一个点。


十、为什么更推荐保存模型参数

因为直接保存整个模型对象虽然方便,但存在一些问题:

1. 依赖模型类定义

加载时必须保证当时定义模型的代码结构没有变化。

2. 可移植性差

不同环境、不同文件组织方式下,容易出问题。

3. 不够灵活

有时我们只想加载参数到一个同结构模型中,而不是整个对象原样恢复。

所以 PyTorch 更常见、更规范的做法是:

保存 state_dict(),也就是模型参数字典。


十一、什么是 state_dict()

前面“参数管理”已经接触过它,这里再强化一下。

state_dict() 可以理解成:

模型当前参数状态的字典。

例如:

from torch import nn

net = nn.Sequential(nn.Linear(2, 2))
print(net.state_dict())

你会看到类似:

  • 权重 weight

  • 偏置 bias

它们都被组织在一个字典中。

所以保存模型最常见的方式就是:

torch.save(net.state_dict(), 'mlp.params')

十二、怎么加载保存好的模型参数

加载参数的流程通常分成三步:

第一步:先重新定义同样结构的模型

clone = nn.Sequential(nn.Linear(2, 2))

第二步:读取保存的参数文件

clone.load_state_dict(torch.load('mlp.params'))

第三步:使用模型

clone.eval()

其中:

  • load_state_dict() 用来把读进来的参数装载到模型中

  • eval() 表示切换到推理/测试模式

完整示例:

import torch
from torch import nn

net = nn.Sequential(nn.Linear(2, 2))
torch.save(net.state_dict(), 'mlp.params')

clone = nn.Sequential(nn.Linear(2, 2))
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()

十三、为什么加载前要先定义同结构模型

这是很多初学者第一次会困惑的地方。

因为我们保存的是:

参数

而不是:

模型结构本身

所以加载时,PyTorch 需要先知道:

  • 这些参数应该放进哪个模型

  • 每个参数对应哪一层

  • 对应张量形状是什么

因此必须先有一个同样结构的模型“壳子”,然后再把参数装进去。

可以把它理解成:

  • 模型结构 = 房子的框架

  • 参数 = 家具和物品

你保存的是家具,但加载时先得有房子,才能把家具摆进去。


十四、eval() 为什么重要

模型加载完成后,经常会写一句:

clone.eval()

这表示把模型切换到评估模式

为什么要这样做?

因为某些层在训练模式和测试模式下行为不一样,比如:

  • Dropout

  • BatchNorm

训练时它们会有随机性或统计行为,测试时则需要稳定输出。

因此:

  • 训练时:model.train()

  • 推理/测试时:model.eval()

虽然在最简单的 Linear 模型里不一定看出差别,但养成这个习惯很重要。


十五、保存整个模型对象怎么做

虽然更推荐保存 state_dict(),但你也可以直接保存整个模型:

torch.save(net, 'entire-model')

加载时:

net2 = torch.load('entire-model')

这样写看起来很方便,但前面说过,它依赖环境和类定义,不够稳健。

所以一般来说:

  • 学习演示可以知道这种方式

  • 真正项目里更推荐保存参数字典


十六、实际项目中通常保存什么

在真实训练任务中,往往不只是保存模型参数,还会保存一个更完整的字典。

例如:

torch.save({
    'epoch': 10,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.123
}, 'checkpoint.pth')

这叫做一个 checkpoint(检查点)

它通常包含:

  • 当前训练到第几轮

  • 模型参数

  • 优化器状态

  • 当前损失值

这样下次不只是能加载模型,还能继续训练。


十七、为什么优化器状态也要保存

很多人一开始只想到保存模型参数,但在继续训练时,优化器状态也很重要。

因为像 Adam、Momentum SGD 这类优化器,不只是简单记录一个学习率,它们内部还维护了历史状态,例如:

  • 动量

  • 二阶矩估计

  • 历史梯度信息

如果不保存优化器状态,那么继续训练时相当于“优化器重新开始”,效果可能和真正断点续训不同。

所以如果只是做推理,可以只保存模型参数;
如果要恢复训练,最好连优化器状态一起保存。


十八、如何读取 checkpoint

例如前面保存了:

torch.save({
    'epoch': 10,
    'model_state_dict': net.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.123
}, 'checkpoint.pth')

那么加载时通常这样写:

checkpoint = torch.load('checkpoint.pth')

net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

这样就能把训练状态比较完整地恢复出来。


十九、文件后缀一定要 .pth

不一定。

PyTorch 对文件名后缀没有强制要求,你写成:

  • .pt

  • .pth

  • .params

  • 甚至没有后缀

技术上都可以。

但约定俗成中,常见的是:

  • .pth

  • .pt

例如:

  • model.pth

  • checkpoint.pt

这样别人一看就知道这是 PyTorch 保存的模型文件。


二十、保存路径和当前目录要注意什么

这一点虽然简单,但很实用。

例如:

torch.save(x, 'x-file')

这里表示保存在当前工作目录下。

如果你不确定文件存到哪里了,可以:

  • 看运行脚本所在目录

  • 打印当前工作路径

  • 使用绝对路径保存

例如:

import os
print(os.getcwd())

这在调试“为什么找不到保存文件”时很有帮助。


二十一、一个适合 CSDN 展示的完整示例

下面给你一份适合博客使用的完整示例代码,基本覆盖本节核心内容。

import torch
from torch import nn

# 1. 保存和加载张量
x = torch.arange(4)
torch.save(x, 'x-file')

x2 = torch.load('x-file')
print("加载后的张量 x2:", x2)

# 2. 保存和加载多个张量
y = torch.zeros(4)
torch.save({'x': x, 'y': y}, 'xy-dict')

data = torch.load('xy-dict')
print("data['x'] =", data['x'])
print("data['y'] =", data['y'])

# 3. 定义模型
net = nn.Sequential(nn.Linear(2, 2))

# 4. 保存模型参数
torch.save(net.state_dict(), 'mlp.params')

# 5. 加载模型参数
clone = nn.Sequential(nn.Linear(2, 2))
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()

print("原模型参数:")
print(net.state_dict())

print("加载后的模型参数:")
print(clone.state_dict())

这份代码很适合你放在 CSDN 里讲解,因为它把:

  • 张量保存

  • 字典保存

  • 模型参数保存与加载

串到一起了。


二十二、这一节最容易混淆的几个点

1. 保存模型和保存参数不是一回事

  • 保存模型:把整个对象存起来

  • 保存参数:只保存 state_dict()

通常更推荐后者。

2. 加载参数前必须先定义模型结构

因为参数本身不知道应该装进哪个模型。

3. eval() 不是“训练完成”的意思

它表示切换到评估模式,影响某些层的行为。

4. 推理和断点续训保存内容不同

  • 推理:保存模型参数就够了

  • 续训:最好连优化器状态一起保存

5. torch.load() 返回的就是之前保存的对象

如果之前保存的是字典,读出来还是字典;
如果保存的是张量,读出来就是张量。


二十三、这一节的核心思想

如果把“读写文件”这一节压缩成一句话,我觉得最核心的是:

训练结果必须被持久化保存,否则模型训练成果无法复用。

前面我们一直在学习:

  • 如何定义模型

  • 如何管理参数

  • 如何做前向传播

而这一节则让这些工作真正具备“落地价值”。

因为只有学会保存和加载,我们才能:

  • 复用训练成果

  • 做实验对比

  • 实现部署推理

  • 实现断点续训

这一步其实是从“会写模型”走向“会做项目”的关键环节。


二十四、我对这一节的理解

学这一节之前,我对 PyTorch 的理解还更多停留在“代码运行时”的层面,感觉模型只存在于程序运行过程中。
但学完“读写文件”之后,我会更清楚地意识到:

模型不仅是运行时对象,它还应该是可以被保存、恢复、迁移和复用的实验资产。

这其实很重要。
因为真实项目里,我们不可能每次都从头开始。一个训练好的模型,往往本身就是最有价值的成果之一。

所以这一节虽然语法不难,但在实际工作流中非常关键。


二十五、结语

“读写文件”是 PyTorch 基础中非常实用的一节。
它不像前面的模型构造那样偏向“定义网络”,而是更贴近真实任务中的模型管理流程。

通过这一节,我们要真正掌握的是:

  • 如何保存张量

  • 如何保存多个对象

  • 如何保存模型参数

  • 如何重新加载模型

  • 为什么推荐保存 state_dict()

  • 如何保存 checkpoint 用于断点续训

如果把这些内容弄明白,后面无论是做课程作业、比赛项目,还是更复杂的训练任务,都会顺手很多。


二十六、重点速记版

1. PyTorch 中最常用的文件读写函数是什么

torch.save()torch.load()

2. 可以保存哪些内容

张量、列表、字典、模型参数、优化器状态等。

3. 保存模型时推荐什么方式

推荐保存 state_dict(),也就是模型参数字典。

4. 加载模型参数前要做什么

先定义同样结构的模型。

5. 为什么加载后常写 eval()

切换到评估模式,保证推理行为正确。

6. checkpoint 一般保存什么

模型参数、优化器状态、epoch、loss 等。

7. 只做推理和继续训练有什么区别

推理只需模型参数;继续训练最好连优化器状态一起保存。

以上就是我对《动手学深度学习》中 PyTorch 基础——读写文件 这一节的学习整理。
这一节虽然不涉及复杂模型结构,但它在真实项目中非常重要,因为它直接决定了我们能不能把训练成果保存下来并重复利用。

对于刚开始学 PyTorch 的同学来说,这一节一定要自己动手敲一遍。尤其是“保存 state_dict() 再加载”的流程,后面做项目时几乎一定会反复用到。


Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐