import torch
import torch.nn as nn
import os
from allennlp.nn import util

torch.manual_seed(1)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


class Test(nn.Module):
    def __init__(self):
        super().__init__()
        pass

    def forward(self, mix, offsets):
        offsets2d = util.combine_initial_dims(offsets)
        # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
        range_vector = util.get_range_vector(
            offsets2d.size(0), device=util.get_device_of(mix)
        ).unsqueeze(1)
        # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
        selected_embeddings = mix[range_vector, offsets2d]

        return util.uncombine_initial_dims(selected_embeddings, offsets.size())

model = Test()
model.cuda()
mix = torch.randn([2, 5, 3]).cuda()
print(mix)
offsets = torch.tensor([[1, 3, 0], [1, 2, 4]]).cuda()
out = model(mix, offsets)
print(out)
# 转成onnx模型
ONNX_FILE_PATH = "./test.onnx"
torch.onnx.export(model,
                  (mix, offsets),
                  ONNX_FILE_PATH, opset_version=12, verbose=True, input_names=["input_ids", "offsets"],
                  output_names=["output"],
                  dynamic_axes={
                      'input_ids': {
                          0: 'batch_size',
                          1: 'seq_len',
                      },
                      'offsets': {
                          0: 'batch_size',
                          1: 'word_len',
                      }
                  },
                  export_params=True)


# 运行图
import onnxruntime as ort
ONNX_FILE_PATH = "./test.onnx"
ort_session = ort.InferenceSession(ONNX_FILE_PATH)
print(ort.get_device())

mix = torch.randn([3, 5, 3]).cuda()
print(mix)
offsets = torch.tensor([[1, 3, 0], [1, 3, 0], [1, 2, 4]]).cuda()

ort_inputs = {
    ort_session.get_inputs()[0].name: mix.cpu().numpy(),
    ort_session.get_inputs()[1].name: offsets.cpu().numpy(),
}
outputs = ort_session.run(None, ort_inputs)
print(outputs[0])

将一段简单的处理(对张量的第二个维度,按照索引取数,出现错误),存为onnx模型后,运行时出现错误:

onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Add node. Name:‘Add_9’ Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:487 void onnxruntime::BroadcastIterator::Append(int64_t, int64_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 2 by 3

替换按照索引取数的操作

    def forward(self, mix, offsets):
        B, S, D = mix.size()
        new_mix = mix.view(-1, D)
        _, W = offsets.size()
        right_add = torch.arange(0, B).unsqueeze(-1).cuda()
        right_add = right_add * S
        right_add.expand([B, W])
        new_offsets = right_add + offsets
        new_offsets = new_offsets.view(-1)
        out1 = new_mix.index_select(0, new_offsets).view(B, W, -1)
        return out1

问题解决。

GitHub 加速计划 / on / onnxruntime
17
3
下载
microsoft/onnxruntime: 是一个用于运行各种机器学习模型的开源库。适合对机器学习和深度学习有兴趣的人,特别是在开发和部署机器学习模型时需要处理各种不同框架和算子的人。特点是支持多种机器学习框架和算子,包括 TensorFlow、PyTorch、Caffe 等,具有高性能和广泛的兼容性。
最近提交(Master分支:3 个月前 )
ebdbbb75 ### Description <!-- Describe your changes. --> 1. Add support for throwing error when hardware is not supported for VitisAI. 2. Add support for unloading VitisAI EP. 3. Add API for Win25. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> This is requirement for Win25 8 小时前
68061740 ### Description This change fixes the WebGPU delay load test. <details> <summary>Fix UB in macro</summary> The following C++ code outputs `2, 1` in MSVC, while it outputs `1, 1` in GCC: ```c++ #include <iostream> #define A 1 #define B 1 #define ENABLE defined(A) && defined(B) #if ENABLE int x = 1; #else int x = 2; #endif #if defined(A) && defined(B) int y = 1; #else int y = 2; #endif int main() { std::cout << x << ", " << y << "\n"; } ``` Clang reports `macro expansion producing 'defined' has undefined behavior [-Wexpansion-to-defined]`. </details> <details> <summary>Fix condition of build option onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS</summary> Delay load is explicitly disabled when python binding is being built. modifies the condition. </details> 17 小时前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐