一、深夜报警:模型在训练端跑得好好的,部署端直接崩了

上周三凌晨两点,手机突然狂震——生产环境的目标检测服务挂了。日志里赫然一行:

RuntimeError: Expected tensor for argument #1 'input' to have the same dimension

训练时用PyTorch 1.8,部署环境却是C++推理服务。问题就出在模型导出这一步:训练脚本里动态尺寸输入跑得欢,导出成TorchScript后死活不认变长输入。这个坑让我熬了个通宵,也让我彻底明白:模型导出不是点一下torch.onnx.export()就完事的玄学操作,而是连接训练与部署的关键桥梁

今天咱们就聊聊PyTorch模型导出的那些实战细节,特别是ONNX和TorchScript这两个主流格式。我会把踩过的坑、绕过的路都摊开来写,你跟着做能省下不少调试时间。


二、TorchScript:PyTorch亲儿子的序列化方案

先看TorchScript,这是PyTorch自家的部署格式,兼容性最好。导出有两种方式:追踪(Tracing)脚本化(Scripting)

2.1 追踪模式:简单但有限制

import torch
from models.yolo import Model  # 你的YOLO模型类

model = Model('yolov5s.yaml')  # 加载配置
model.load_state_dict(torch.load('best.pt')['model'])  # 加载权重
model.eval()

# 关键在这里:准备一个示例输入
example_input = torch.rand(1, 3, 640, 640)  # 固定尺寸!

# 追踪导出
traced_script = torch.jit.trace(model, example_input)
traced_script.save("yolo_traced.pt")

# 测试一下
with torch.no_grad():
    output = traced_script(torch.rand(1, 3, 640, 640))
    print(output.shape)  # 正常
    # output2 = traced_script(torch.rand(1, 3, 320, 320))  # 这个会报错!尺寸必须和example_input一致

踩坑点1:追踪模式会记录下example_input这个具体张量在模型里的流动路径。如果你的模型有动态控制流(比如if-else分支依赖输入值),追踪只会记录当时走的那条路,其他分支就丢了。

踩坑点2:输入尺寸被写死。上面代码里用(1,3,640,640)导出的模型,推理时就必须是这个尺寸。想支持动态尺寸?得用脚本化。

2.2 脚本化模式:支持动态逻辑

# 在模型类里加装饰器
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 你的层定义
    
    @torch.jit.export  # 显式标记要导出的方法
    def forward(self, x):
        # 你的前向逻辑
        if x.mean() > 0:  # 动态控制流,追踪模式处理不了
            return self.path_a(x)
        else:
            return self.path_b(x)

# 脚本化导出
scripted_model = torch.jit.script(model)
scripted_model.save("yolo_scripted.pt")

脚本化会真正解析Python代码,所以能处理条件分支、循环。但代价是:你的模型代码必须符合TorchScript的语法子集。这意味着:

  • 不能有复杂的Python类型注解(用List[Tensor]别用list
  • 不能调用外部Python函数(除非也用@torch.jit.script装饰)
  • 列表推导、字典操作受限

个人习惯:我通常先用追踪模式快速验证,如果模型简单且输入尺寸固定,这就够了。遇到动态逻辑或需要多尺寸支持时,再忍痛改代码适配脚本化。


三、ONNX:生态更广的开放格式

ONNX的优势在于跨框架:PyTorch导出,可以用TensorRT、OpenVINO、ONNX Runtime等各种后端推理。但导出过程更像走钢丝,平衡不好就掉坑里。

3.1 基础导出:一堆参数要看准

import torch.onnx

# 还是那个模型和示例输入
model.eval()
example_input = torch.rand(1, 3, 640, 640)

# 导出核心调用
torch.onnx.export(
    model,
    example_input,
    "yolo.onnx",
    export_params=True,      # 把模型参数也导进去
    opset_version=13,        # 这个很重要!版本不对算子可能不支持
    do_constant_folding=True, # 常量折叠优化,一般开着
    input_names=["images"],   # 输入节点名,后面推理用
    output_names=["output"],  # 输出节点名
    dynamic_axes={
        "images": {0: "batch", 2: "height", 3: "width"},  # 动态轴!支持变尺寸
        "output": {0: "batch"}
    }
)

关键参数解读

  • opset_version:ONNX算子集版本。YOLOv5/v7用的Focus层需要opset>=11,某些新算子需要更高版本。先查清楚你的模型用了什么特殊算子。
  • dynamic_axes:这是实现动态尺寸的关键。上面配置表示第0维(batch)、第2维(高)、第3维(宽)可变。如果你训练时用了多尺度,这里必须设对。

3.2 验证:别等到部署才发现问题

import onnx
import onnxruntime as ort

# 1. 检查模型格式是否正确
onnx_model = onnx.load("yolo.onnx")
onnx.checker.check_model(onnx_model)  # 语法检查
print(onnx.helper.printable_graph(onnx_model.graph))  # 看一眼计算图

# 2. 用ONNX Runtime推理测试
ort_session = ort.InferenceSession("yolo.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: example_input.numpy()}
ort_outputs = ort_session.run(None, ort_inputs)

# 3. 和PyTorch结果对比
with torch.no_grad():
    torch_output = model(example_input)

import numpy as np
print("输出差值:", np.max(np.abs(ort_outputs[0] - torch_output.numpy())))  # 应该很小(1e-7量级)

常见坑:验证通过但推理结果不对?大概率是预处理/后处理没对齐。PyTorch里可能做了归一化(/255.0),导出时如果没包含进计算图,部署端就得自己补上。


四、YOLO模型导出的特殊处理

YOLO系列模型导出时有几个高频坑点:

4.1 后处理别丢

# 错误示范:只导出主干网络
class Detector(torch.nn.Module):
    def forward(self, x):
        features = self.backbone(x)
        # 这里少了检测头的解码、NMS
        return features  # 这样导出的模型输出是原始特征图,不是最终检测框

# 正确做法:把后处理打包(如果推理引擎支持)
class DetectorWithPostprocess(torch.nn.Module):
    def forward(self, x):
        pred = self.model(x)  # 原始输出
        boxes, scores, classes = self.non_max_suppression(pred)  # 包含NMS
        return boxes, scores, classes

但注意:很多部署框架(如TensorRT)有自己优化过的NMS算子。我通常做法是导出时不带NMS,在部署端用框架的NMS实现,性能更好。

4.2 动态尺寸与批处理

# 如果你想支持批量推理和变尺寸
dynamic_axes={
    "images": {
        0: "batch_size",
        2: "height",
        3: "width"
    },
    "output": {
        0: "batch_size"
    }
}

实际部署时,如果用了TensorRT,动态尺寸会显著增加引擎构建时间。生产环境如果尺寸固定,最好导出固定尺寸模型。

4.3 自定义算子处理

YOLOv5的Focus层(切片+拼接)在opset 11以下不支持,要么:

  1. 升级opset_version到11+
  2. 或者把Focus层替换为等价的Conv层(YOLOv5官方提供了替换脚本)
# 替换Focus层的技巧
from models.common import Focus

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 训练时用Focus
        self.focus = Focus(...)
    
    def forward(self, x):
        if self.training:
            return self.focus(x)
        else:
            # 导出时用等效卷积
            return self.equivalent_conv(x)

训练和导出用不同路径,这是常见技巧。


五、调试:当导出失败时怎么办

  1. 看错误栈:ONNX导出失败会打印不支持的算子或操作,照着改。
  2. 简化模型:从单层开始导,逐步增加复杂度,定位问题层。
  3. torch.onnx.export(verbose=True):打印计算图,看哪里断了。
  4. 查算子支持表:ONNX的算子文档和PyTorch的torch.onnx文档都列了支持情况。

我电脑里有个“导出失败记录.md”,里面记着:

  • torch.tensor.tolist()在脚本化里不行,得用torch.unbind()
  • torch.arange()的步长参数在ONNX opset 11前后语法变了
  • LSTM的hidden_sizeinput_size参数顺序容易搞反

六、经验性建议

  1. 导出前先冻住模型model.eval()是必须的,但别忘了还有torch.no_grad()上下文。BatchNorm和Dropout层在训练和评估模式行为不同。

  2. 版本对齐要死磕:PyTorch版本、ONNX版本、推理框架版本(TensorRT等)的兼容性矩阵,先查清楚再动手。我吃过亏:本地1.8导出的ONNX,生产环境1.7解析不了。

  3. 留个PyTorch备份:ONNX/TorchScript模型导出后,一定保留原始的PyTorch模型权重(.pt文件)。哪天导出格式不兼容了,还能用新工具重新导。

  4. 动态尺寸是双刃剑:开发阶段图方便开了动态尺寸,生产环境如果尺寸固定,建议重新导出静态模型,推理速度能快20%以上。

  5. 验证要全面:别只测一张图。准备个小型测试集(10-20张),覆盖各种尺寸、宽高比,对比PyTorch和导出模型的mAP差异。我见过导出后精度掉5个点的,原因是某个激活函数量化异常。

  6. 文档随代码走:在导出脚本里用注释写明:“此模型用opset 14导出,支持动态高宽,但不支持批量可变”。三个月后你自己都记不清。


模型导出这事,就像给训练好的模型做一次“脱水处理”——去掉训练时的冗余,保留推理必需的骨架。刚开始会觉得束手束脚,多踩几次坑就摸出门道了。记住:一次成功的导出,始于训练代码时就考虑部署约束。下次写模型时,不妨先想想:“这代码能顺利导出吗?” 这会省去你无数个调试的深夜。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐