需求

输入x
128维

fc1
Linear 128→96

ReLU激活

Dropout 0.2

fc2
Linear 96→64

ReLU激活

Dropout 0.2

fc3
Linear 64→32

输出out
32维

代码样例

包含训练 → 保存 → 加载 → 预测,代码可以直接运行:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# -----------------------------
# 1. 定义模型
# -----------------------------
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(128, 96)
        self.fc2 = nn.Linear(96, 64)
        self.fc3 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        out = self.fc3(x)
        return out

# -----------------------------
# 2. 准备数据 (示例随机数据)
# -----------------------------
X = torch.randn(1000, 128)
y = torch.randn(1000, 32)

dataset = TensorDataset(X, y)
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# -----------------------------
# 3. 定义损失函数和优化器
# MSELoss Mean Squared Error(均方误差)
# -----------------------------
model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# -----------------------------
# 4. 训练循环
# -----------------------------
num_epochs = 20
for epoch in range(num_epochs):
    model.train()  # 训练模式
    epoch_loss = 0
    for batch_X, batch_y in dataloader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * batch_X.size(0)
    
    epoch_loss /= len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")

# -----------------------------
# 5. 保存训练好的模型参数
# -----------------------------
torch.save(model.state_dict(), "simple_model.pth")
print("模型参数已保存到 simple_model.pth")

# -----------------------------
# 6. 加载模型进行预测
# -----------------------------
# 重新创建模型对象
model_loaded = SimpleModel()

# 加载保存的参数
model_loaded.load_state_dict(torch.load("simple_model.pth"))

# 切换到评估模式
model_loaded.eval()

# 假设有新样本 x_new
x_new = torch.randn(5, 128)
with torch.no_grad():  # 推理时禁用梯度
    y_pred = model_loaded(x_new)

print("加载模型预测结果形状:", y_pred.shape)  # [5, 32]

✅ 特点

  1. 训练完成后保存权重simple_model.pth 可以随时加载。
  2. 加载模型时必须重新创建类,然后 load_state_dict
  3. 推理时切换到 eval() 模式,保证 Dropout 不随机失活。
  4. 使用 torch.no_grad() 提升预测效率,减少显存占用。

在这里插入图片描述

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐