首先用一张测试图片测试下pytorch模型,然后将pytorch模型转换为onnx模型,然后利用onnxruntime测试转换后的onnx模型.

#!usr/bin/env python
# -*- coding:utf-8 _*-

import torch
import torch.nn as nn
from torchvision import models
import os
from torchvision import transforms
import time
import datetime
import cv2
from PIL import Image
import onnx
import onnxruntime
import torch.nn.functional as F


class MobileNet(nn.Module):
    def __init__(self, num_classes=4):
        super(MobileNet, self).__init__()
        net = models.mobilenet_v2(pretrained=True)
        self.features = net.features
        self.classifier = nn.Sequential(
            nn.Linear(1280, 1000),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(1000, num_classes))

    def forward(self, x):
        x = self.features(x)
        x = x.mean([2, 3])  #输入和输出格式均为(n, c, h, w)mean([2, 3])将宽高加和后取平均值,特征经过处理后由(1, 1280, 7, 7)变为(1, 1280, 1, 1)
        logit = self.classifier(x)
        return logit

#算法初始化
num_class = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = MobileNet(num_class).to(device)
model_path = "./MobileNet_clothes.pth"
net.load_state_dict(torch.load(model_path))
net.to(device)
net.eval()
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


#算法推理
img_input = cv2.imread("./clothes.jpg")
img_input = Image.fromarray(img_input.astype('uint8')[:,:,::-1], mode='RGB')
img_input = transform(img_input)
img_input = img_input.unsqueeze(0).to(device)
logist = net(img_input)
prob = torch.softmax(logist, dim = 1)
score = prob.max(dim = 1).values.item()
pred = torch.argmax(logist, dim = 1)
print("score:", score)
print("int(pred):", int(pred))


#导出onnx模型
torch.onnx.export(net, img_input, "./MobileNet_clothes.onnx", input_names=['input'], output_names=['output'])


#测试onnx模型
# load onnx model
onnx_model = onnx.load("./MobileNet_clothes.onnx")
# check onnx model
onnx.checker.check_model(onnx_model)
ort_session = onnxruntime.InferenceSession("./MobileNet_clothes.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_input)}
ort_outs = ort_session.run(None, ort_inputs)

# softmax
tensor_ort_out = torch.from_numpy(ort_outs[0])
onnx_test_out = F.softmax(tensor_ort_out, dim=1)
onnx_score = onnx_test_out.max(dim = 1).values.item()
onnx_pred = torch.argmax(onnx_test_out, dim = 1)
print("onnx_score:", onnx_score)
print("onnx_pred:", int(onnx_pred))

print("the onnx result is {}".format(onnx_test_out))

参考文献:

深度学习之格式转换笔记(一):模型文件pt转换成onnx格式详解   https://haxibiao.com/article/99521/

记mobilenet_v2的pytorch模型转onnx模型再转ncnn模型一段不堪回首的历程  记mobilenet_v2的pytorch模型转onnx模型再转ncnn模型一段不堪回首的历程_半路出家的猿人的博客-CSDN博客

mobilenetv2的Pytorch模型转onnx模型再转ncnn模型  mobilenetv2的Pytorch模型转onnx模型再转ncnn模型_逮仔的博客-CSDN博客

GitHub 加速计划 / on / onnxruntime
17
3
下载
microsoft/onnxruntime: 是一个用于运行各种机器学习模型的开源库。适合对机器学习和深度学习有兴趣的人,特别是在开发和部署机器学习模型时需要处理各种不同框架和算子的人。特点是支持多种机器学习框架和算子,包括 TensorFlow、PyTorch、Caffe 等,具有高性能和广泛的兼容性。
最近提交(Master分支:3 个月前 )
ebdbbb75 ### Description <!-- Describe your changes. --> 1. Add support for throwing error when hardware is not supported for VitisAI. 2. Add support for unloading VitisAI EP. 3. Add API for Win25. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> This is requirement for Win25 7 小时前
68061740 ### Description This change fixes the WebGPU delay load test. <details> <summary>Fix UB in macro</summary> The following C++ code outputs `2, 1` in MSVC, while it outputs `1, 1` in GCC: ```c++ #include <iostream> #define A 1 #define B 1 #define ENABLE defined(A) && defined(B) #if ENABLE int x = 1; #else int x = 2; #endif #if defined(A) && defined(B) int y = 1; #else int y = 2; #endif int main() { std::cout << x << ", " << y << "\n"; } ``` Clang reports `macro expansion producing 'defined' has undefined behavior [-Wexpansion-to-defined]`. </details> <details> <summary>Fix condition of build option onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS</summary> Delay load is explicitly disabled when python binding is being built. modifies the condition. </details> 16 小时前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐