
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torchvision import models
import os
from torchvision import transforms
import time
import datetime
import cv2
from PIL import Image
import onnx
import onnxruntime
import torch.nn.functional as F

class MobileNet(nn.Module):
    """MobileNetV2 feature extractor followed by a two-layer linear head.

    The convolutional ``features`` stack is taken from torchvision's
    ``mobilenet_v2`` (created with ``pretrained=True``); torchvision's own
    classifier is replaced by ``Linear(1280, 1000) -> Linear(1000, num_classes)``.

    Args:
        num_classes: Number of output logits produced by the head.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        backbone = models.mobilenet_v2(pretrained=True)
        # Only the convolutional part of the backbone is reused.
        self.features = backbone.features
        hidden = nn.Linear(1280, 1000)
        output = nn.Linear(1000, num_classes)
        self.classifier = nn.Sequential(hidden, output)

    def forward(self, x):
        """Return the raw class logits for a batch of images.

        Args:
            x: Image batch laid out as (n, c, h, w).

        Returns:
            Tensor of logits with one row per input image.
        """
        feature_map = self.features(x)
        # Global average pooling: averaging over the height and width axes
        # (dims 2 and 3) collapses (n, c, h, w) to (n, c), e.g. a
        # (1, 1280, 7, 7) feature map becomes (1, 1280).
        pooled = feature_map.mean(dim=[2, 3])
        return self.classifier(pooled)

# Build the classifier, run one PyTorch inference on a sample image and
# export the network to ONNX.
num_class = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = MobileNet(num_class).to(device)
model_path = "./MobileNet_clothes.pth"

# Restore the fine-tuned weights; without them the linear head is randomly
# initialised and both the scores below and the exported model are meaningless.
# NOTE(review): assumes the checkpoint stores a plain state_dict for
# MobileNet -- confirm against the training script.
if os.path.isfile(model_path):
    net.load_state_dict(torch.load(model_path, map_location=device))
else:
    print("warning: checkpoint not found, classifier head is untrained:", model_path)

# Switch to inference behaviour (running BatchNorm statistics, no dropout).
# torch.onnx.export exports in eval mode by default, so the PyTorch reference
# scores must be computed in eval mode as well to be comparable.
net.eval()

transform = transforms.Compose([
    # NOTE(review): 224x224 is assumed from the 7x7 feature-map size mentioned
    # in the model's comments -- confirm it matches the training resolution.
    transforms.Resize((224, 224)),
    # Normalize operates on tensors, so the PIL image must be converted first.
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

image_path = "./clothes.jpg"
img_input = cv2.imread(image_path)
if img_input is None:
    # cv2.imread signals a missing or unreadable file by returning None.
    raise FileNotFoundError("cannot read image: {}".format(image_path))
# OpenCV decodes to BGR channel order; PIL and the transforms expect RGB.
img_input = Image.fromarray(cv2.cvtColor(img_input, cv2.COLOR_BGR2RGB))
img_input = transform(img_input)
# Add the batch dimension: (c, h, w) -> (1, c, h, w).
img_input = img_input.unsqueeze(0).to(device)

# No gradients are needed for a single forward pass.
with torch.no_grad():
    logist = net(img_input)
prob = torch.softmax(logist, dim=1)
score = prob.max(dim=1).values.item()
pred = torch.argmax(logist, dim=1)
print("score:", score)
print("int(pred):", int(pred))

# The sample input fixes the input shape recorded in the exported graph.
torch.onnx.export(net, img_input, "./MobileNet_clothes.onnx", input_names=['input'], output_names=['output'])

# Load the exported ONNX model.
onnx_model = onnx.load("./MobileNet_clothes.onnx")
# Validate the model structure; raises if the exported graph is malformed.
onnx.checker.check_model(onnx_model)
# Create the ONNX Runtime session. The providers are passed explicitly because
# ONNX Runtime >= 1.9 refuses to create a session without them on builds that
# ship more than the CPU execution provider.
ort_session = onnxruntime.InferenceSession(
    "./MobileNet_clothes.onnx",
    providers=onnxruntime.get_available_providers(),
)

def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    A tensor that tracks gradients is detached first, because
    ``Tensor.numpy()`` cannot convert a tensor that requires grad.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()

# Run the exported model in ONNX Runtime on the same preprocessed image that
# was fed to the PyTorch network.
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: to_numpy(img_input)}
# Passing None as the output list asks the session for all model outputs.
ort_outs = ort_session.run(None, ort_inputs)

# Convert the first ONNX output (the logits) into class probabilities.
tensor_ort_out = torch.from_numpy(ort_outs[0])
onnx_test_out = F.softmax(tensor_ort_out, dim=1)

# Most likely class and its probability.
onnx_pred = torch.argmax(onnx_test_out, dim=1)
onnx_score = onnx_test_out.max(dim=1).values.item()
print("onnx_score:", onnx_score)
print("onnx_pred:", int(onnx_pred))

print("the onnx result is {}".format(onnx_test_out))


