一文看明白PyTorch 模型设计训练保存加载预测
·
需求
代码样例
包含训练 → 保存 → 加载 → 预测,代码可以直接运行:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# -----------------------------
# 1. 定义模型
# -----------------------------
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc1 = nn.Linear(128, 96)
self.fc2 = nn.Linear(96, 64)
self.fc3 = nn.Linear(64, 32)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.2)
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.dropout(x)
x = self.relu(self.fc2(x))
x = self.dropout(x)
out = self.fc3(x)
return out
# -----------------------------
# 2. 准备数据 (示例随机数据)
# -----------------------------
X = torch.randn(1000, 128)
y = torch.randn(1000, 32)
dataset = TensorDataset(X, y)
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# -----------------------------
# 3. 定义损失函数和优化器
# MSELoss Mean Squared Error(均方误差)
# -----------------------------
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# -----------------------------
# 4. 训练循环
# -----------------------------
num_epochs = 20
for epoch in range(num_epochs):
model.train() # 训练模式
epoch_loss = 0
for batch_X, batch_y in dataloader:
optimizer.zero_grad()
outputs = model(batch_X)
loss = criterion(outputs, batch_y)
loss.backward()
optimizer.step()
epoch_loss += loss.item() * batch_X.size(0)
epoch_loss /= len(dataset)
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
# -----------------------------
# 5. 保存训练好的模型参数
# -----------------------------
torch.save(model.state_dict(), "simple_model.pth")
print("模型参数已保存到 simple_model.pth")
# -----------------------------
# 6. 加载模型进行预测
# -----------------------------
# 重新创建模型对象
model_loaded = SimpleModel()
# 加载保存的参数
model_loaded.load_state_dict(torch.load("simple_model.pth"))
# 切换到评估模式
model_loaded.eval()
# 假设有新样本 x_new
x_new = torch.randn(5, 128)
with torch.no_grad(): # 推理时禁用梯度
y_pred = model_loaded(x_new)
print("加载模型预测结果形状:", y_pred.shape) # [5, 32]
✅ 特点
- 训练完成后保存权重,
simple_model.pth可以随时加载。 - 加载模型时必须重新创建类,然后
load_state_dict。 - 推理时切换到
eval()模式,保证 Dropout 不随机失活。 - 使用
torch.no_grad()提升预测效率,减少显存占用。

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)