由于某些原因,需要用到通过python调用yolov5的onnx模型的情况,于是开始一天的踩雷!!!

onnx模型中是包含类别信息,不用跟网上的一些demo中非要有个class.txt文件

# 加载模型
session = ort.InferenceSession('yolov5n.onnx')

# 获取模型的元数据信息
metadata = session.get_modelmeta()

# 获取自定义元数据中的类别信息
custom_metadata_map = metadata.custom_metadata_map
name_class = custom_metadata_map.get("names")

fixed_json = '{' + ', '.join(f'"{k}": "{v}"' for k, v in eval(name_class.replace("'", '"')).items()) + '}'  

# 解析类别信息
names = json.loads(fixed_json)

# 获取输入信息
input_info = session.get_inputs()
node_info = input_info[0]  

# 获取输入的形状信息
net_shape = node_info.shape
count = net_shape[0]
channels = net_shape[1]
net_height = net_shape[2]
net_width = net_shape[3]

# 输出获取到的信息
print("Names:", names)
print("Count:", count)
print("Channels:", channels)
print("Net Height:", net_height)
print("Net Width:", net_width)

通过session进行推测时出现了个关键的错误!

# 运行模型推理
outputs = session.run(None, {node_info.name: input_image})

错误信息:[ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Unexpected input data type. Actual: (tensor(uint8)) , expected: (tensor(float16))

这里输入的Tensor只能是float16需要转一下,但后续一定要及时转换回来,不然计算的时候很容易出现np.inf,导致结果异常,如nms时,会出现很多结果框,明明视觉上nms超过0.5,仍存在重合框,就是过程结果出现inf导致NMS计算值为0,结果框无法去重。

# 转换数据类型为 float16
input_image_16 = input_image.astype(np.float16)
outputs = session.run(None, {node_info.name: input_image_16})
detections = outputs[0]
detections = detections.astype(np.float32)

GitHub 加速计划 / on / onnxruntime
18
3
下载
microsoft/onnxruntime: 是一个用于运行各种机器学习模型的开源库。适合对机器学习和深度学习有兴趣的人,特别是在开发和部署机器学习模型时需要处理各种不同框架和算子的人。特点是支持多种机器学习框架和算子,包括 TensorFlow、PyTorch、Caffe 等,具有高性能和广泛的兼容性。
最近提交(Master分支:4 个月前 )
aedb49be ### Description <!-- Describe your changes. --> Changed all support tensor type from ir 9 to ir 10. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> - See issue https://github.com/microsoft/onnxruntime/issues/23205 Co-authored-by: Yueqing Zhang <yueqingz@amd.com> 13 小时前
bc91f5c7 ### Description <!-- Describe your changes. --> For legacy jetson users who use jetpack 5.x, the latest TRT version is 8.5. Add version check to newer trt features to fix build on jetpack 5.x (cuda11.8+gcc11 are required) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> 1 天前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐