# Reshape
Configuration:
# Build a small ONNX model containing a constant-shape Reshape node, then
# run it through TVM and onnxruntime and capture both outputs for comparison.
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
model_path = f"{temp_dir}/Reshape.onnx"  # model save path
in_shape = (4, 3, 3, 4)
ref_shape = (6, 2, 4, 3)  # same element count (144) as in_shape
# NOTE: the ONNX Reshape operator requires its "shape" input to be
# tensor(int64). Using INT32 here is what caused onnxruntime to reject the
# graph with INVALID_GRAPH ("Type 'tensor(int32)' ... is invalid").
ref_array = np.array(ref_shape, dtype=np.int64)
ref_node = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=["ref_in"],
    value=onnx.helper.make_tensor(
        name="const_tensor",
        data_type=onnx.TensorProto.INT64,  # was INT32: invalid for Reshape's shape input
        dims=ref_array.shape,
        vals=ref_array.flatten().astype(np.int64),
    ),
)
reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
graph = helper.make_graph(
    [ref_node, reshape_node],
    "reshape_test",
    inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))],
)
graph_def = helper.make_model(graph, producer_name="reshape_test")
onnx.save(graph_def, model_path)  # persist the model to disk
target = "llvm"
dev = tvm.device(target)
# The graph input "in" is declared FLOAT above, so the test data must be
# float32 — the original astype("int32") was a second dtype mismatch.
inputs = np.random.uniform(size=in_shape).astype("float32")
tvm_result = get_tvm_output(graph_def,
                            inputs,
                            target,
                            dev
                            )
ort_out = get_onnxruntime_output(graph_def, inputs)
---------------------------------------------------------------------------
InvalidGraph Traceback (most recent call last)
Cell In[5], line 9
3 inputs = np.random.uniform(size=in_shape).astype("int32")
4 tvm_result = get_tvm_output(graph_def,
5 inputs,
6 target,
7 dev
8 )
----> 9 ort_out = get_onnxruntime_output(graph_def, inputs)
Cell In[2], line 38, in get_onnxruntime_output(graph_def, inputs)
36 """Generic function to generate onnxruntime output"""
37 # rep = onnxruntime.backend.prepare(graph_def.SerializeToString(), 'CPU', providers=['CPUExecutionProvider'])
---> 38 sess = onnxruntime.InferenceSession(
39 graph_def.SerializeToString(), providers=['CPUExecutionProvider']
40 )
41 for x, data in zip(sess.get_inputs(), inputs):
42 input_names[x.name] = data
File /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:465, in InferenceSession.__init__(self, path_or_bytes, sess_options, providers, provider_options, **kwargs)
462 disabled_optimizers = kwargs.get("disabled_optimizers")
464 try:
--> 465 self._create_inference_session(providers, provider_options, disabled_optimizers)
466 except (ValueError, RuntimeError) as e:
467 if self._enable_fallback:
File /media/pc/data/lxw/envs/anaconda3x/envs/xxx/lib/python3.12/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:528, in InferenceSession._create_inference_session(self, providers, provider_options, disabled_optimizers)
526 sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
527 else:
--> 528 sess = C.InferenceSession(session_options, self._model_bytes, False, self._read_config_from_model)
530 if disabled_optimizers is None:
531 disabled_optimizers = set()
InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int32)' of input parameter (ref_in) of operator (Reshape) in node (Reshape_0) is invalid.