025.模型导出:从PyTorch到ONNX/TorchScript的实战与踩坑手记
一、深夜报警:模型在训练端跑得好好的,部署端直接崩了
上周三凌晨两点,手机突然狂震——生产环境的目标检测服务挂了。日志里赫然一行:
RuntimeError: Expected tensor for argument #1 'input' to have the same dimension
训练时用PyTorch 1.8,部署环境却是C++推理服务。问题就出在模型导出这一步:训练脚本里动态尺寸输入跑得欢,导出成TorchScript后死活不认变长输入。这个坑让我熬了个通宵,也让我彻底明白:模型导出不是点一下torch.onnx.export()就完事的玄学操作,而是连接训练与部署的关键桥梁。
今天咱们就聊聊PyTorch模型导出的那些实战细节,特别是ONNX和TorchScript这两个主流格式。我会把踩过的坑、绕过的路都摊开来写,你跟着做能省下不少调试时间。
二、TorchScript:PyTorch亲儿子的序列化方案
先看TorchScript,这是PyTorch自家的部署格式,兼容性最好。导出有两种方式:追踪(Tracing) 和脚本化(Scripting)。
2.1 追踪模式:简单但有限制
import torch
from models.yolo import Model # 你的YOLO模型类
model = Model('yolov5s.yaml') # 加载配置
model.load_state_dict(torch.load('best.pt')['model']) # 加载权重
model.eval()
# 关键在这里:准备一个示例输入
example_input = torch.rand(1, 3, 640, 640) # 固定尺寸!
# 追踪导出
traced_script = torch.jit.trace(model, example_input)
traced_script.save("yolo_traced.pt")
# 测试一下
with torch.no_grad():
output = traced_script(torch.rand(1, 3, 640, 640))
print(output.shape) # 正常
# output2 = traced_script(torch.rand(1, 3, 320, 320)) # 这个会报错!尺寸必须和example_input一致
踩坑点1:追踪模式会记录下example_input这个具体张量在模型里的流动路径。如果你的模型有动态控制流(比如if-else分支依赖输入值),追踪只会记录当时走的那条路,其他分支就丢了。
踩坑点2:输入尺寸被写死。上面代码里用(1,3,640,640)导出的模型,推理时就必须是这个尺寸。想支持动态尺寸?得用脚本化。
2.2 脚本化模式:支持动态逻辑
# 在模型类里加装饰器
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
# 你的层定义
@torch.jit.export # 显式标记要导出的方法
def forward(self, x):
# 你的前向逻辑
if x.mean() > 0: # 动态控制流,追踪模式处理不了
return self.path_a(x)
else:
return self.path_b(x)
# 脚本化导出
scripted_model = torch.jit.script(model)
scripted_model.save("yolo_scripted.pt")
脚本化会真正解析Python代码,所以能处理条件分支、循环。但代价是:你的模型代码必须符合TorchScript的语法子集。这意味着:
- 不能有复杂的Python类型注解(用
List[Tensor]别用list) - 不能调用外部Python函数(除非也用
@torch.jit.script装饰) - 列表推导、字典操作受限
个人习惯:我通常先用追踪模式快速验证,如果模型简单且输入尺寸固定,这就够了。遇到动态逻辑或需要多尺寸支持时,再忍痛改代码适配脚本化。
三、ONNX:生态更广的开放格式
ONNX的优势在于跨框架:PyTorch导出,可以用TensorRT、OpenVINO、ONNX Runtime等各种后端推理。但导出过程更像走钢丝,平衡不好就掉坑里。
3.1 基础导出:一堆参数要看准
import torch.onnx
# 还是那个模型和示例输入
model.eval()
example_input = torch.rand(1, 3, 640, 640)
# 导出核心调用
torch.onnx.export(
model,
example_input,
"yolo.onnx",
export_params=True, # 把模型参数也导进去
opset_version=13, # 这个很重要!版本不对算子可能不支持
do_constant_folding=True, # 常量折叠优化,一般开着
input_names=["images"], # 输入节点名,后面推理用
output_names=["output"], # 输出节点名
dynamic_axes={
"images": {0: "batch", 2: "height", 3: "width"}, # 动态轴!支持变尺寸
"output": {0: "batch"}
}
)
关键参数解读:
opset_version:ONNX算子集版本。YOLOv5/v7用的Focus层需要opset>=11,某些新算子需要更高版本。先查清楚你的模型用了什么特殊算子。dynamic_axes:这是实现动态尺寸的关键。上面配置表示第0维(batch)、第2维(高)、第3维(宽)可变。如果你训练时用了多尺度,这里必须设对。
3.2 验证:别等到部署才发现问题
import onnx
import onnxruntime as ort
# 1. 检查模型格式是否正确
onnx_model = onnx.load("yolo.onnx")
onnx.checker.check_model(onnx_model) # 语法检查
print(onnx.helper.printable_graph(onnx_model.graph)) # 看一眼计算图
# 2. 用ONNX Runtime推理测试
ort_session = ort.InferenceSession("yolo.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: example_input.numpy()}
ort_outputs = ort_session.run(None, ort_inputs)
# 3. 和PyTorch结果对比
with torch.no_grad():
torch_output = model(example_input)
import numpy as np
print("输出差值:", np.max(np.abs(ort_outputs[0] - torch_output.numpy()))) # 应该很小(1e-7量级)
常见坑:验证通过但推理结果不对?大概率是预处理/后处理没对齐。PyTorch里可能做了归一化(/255.0),导出时如果没包含进计算图,部署端就得自己补上。
四、YOLO模型导出的特殊处理
YOLO系列模型导出时有几个高频坑点:
4.1 后处理别丢
# 错误示范:只导出主干网络
class Detector(torch.nn.Module):
def forward(self, x):
features = self.backbone(x)
# 这里少了检测头的解码、NMS
return features # 这样导出的模型输出是原始特征图,不是最终检测框
# 正确做法:把后处理打包(如果推理引擎支持)
class DetectorWithPostprocess(torch.nn.Module):
def forward(self, x):
pred = self.model(x) # 原始输出
boxes, scores, classes = self.non_max_suppression(pred) # 包含NMS
return boxes, scores, classes
但注意:很多部署框架(如TensorRT)有自己优化过的NMS算子。我通常做法是导出时不带NMS,在部署端用框架的NMS实现,性能更好。
4.2 动态尺寸与批处理
# 如果你想支持批量推理和变尺寸
dynamic_axes={
"images": {
0: "batch_size",
2: "height",
3: "width"
},
"output": {
0: "batch_size"
}
}
实际部署时,如果用了TensorRT,动态尺寸会显著增加引擎构建时间。生产环境如果尺寸固定,最好导出固定尺寸模型。
4.3 自定义算子处理
YOLOv5的Focus层(切片+拼接)在opset 11以下不支持,要么:
- 升级opset_version到11+
- 或者把Focus层替换为等价的Conv层(YOLOv5官方提供了替换脚本)
# 替换Focus层的技巧
from models.common import Focus
class MyModel(nn.Module):
def __init__(self):
super().__init__()
# 训练时用Focus
self.focus = Focus(...)
def forward(self, x):
if self.training:
return self.focus(x)
else:
# 导出时用等效卷积
return self.equivalent_conv(x)
训练和导出用不同路径,这是常见技巧。
五、调试:当导出失败时怎么办
- 看错误栈:ONNX导出失败会打印不支持的算子或操作,照着改。
- 简化模型:从单层开始导,逐步增加复杂度,定位问题层。
- 用
torch.onnx.export(verbose=True):打印计算图,看哪里断了。 - 查算子支持表:ONNX的算子文档和PyTorch的torch.onnx文档都列了支持情况。
我电脑里有个“导出失败记录.md”,里面记着:
torch.tensor.tolist()在脚本化里不行,得用torch.unbind()torch.arange()的步长参数在ONNX opset 11前后语法变了- LSTM的
hidden_size和input_size参数顺序容易搞反
六、经验性建议
-
导出前先冻住模型:
model.eval()是必须的,但别忘了还有torch.no_grad()上下文。BatchNorm和Dropout层在训练和评估模式行为不同。 -
版本对齐要死磕:PyTorch版本、ONNX版本、推理框架版本(TensorRT等)的兼容性矩阵,先查清楚再动手。我吃过亏:本地1.8导出的ONNX,生产环境1.7解析不了。
-
留个PyTorch备份:ONNX/TorchScript模型导出后,一定保留原始的PyTorch模型权重(
.pt文件)。哪天导出格式不兼容了,还能用新工具重新导。 -
动态尺寸是双刃剑:开发阶段图方便开了动态尺寸,生产环境如果尺寸固定,建议重新导出静态模型,推理速度能快20%以上。
-
验证要全面:别只测一张图。准备个小型测试集(10-20张),覆盖各种尺寸、宽高比,对比PyTorch和导出模型的mAP差异。我见过导出后精度掉5个点的,原因是某个激活函数量化异常。
-
文档随代码走:在导出脚本里用注释写明:“此模型用opset 14导出,支持动态高宽,但不支持批量可变”。三个月后你自己都记不清。
模型导出这事,就像给训练好的模型做一次“脱水处理”——去掉训练时的冗余,保留推理必需的骨架。刚开始会觉得束手束脚,多踩几次坑就摸出门道了。记住:一次成功的导出,始于训练代码时就考虑部署约束。下次写模型时,不妨先想想:“这代码能顺利导出吗?” 这会省去你无数个调试的深夜。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐

所有评论(0)