# Inference
from typing import Literal, Optional
import numpy as np
import onnx
import onnxruntime
from onnx import ModelProto, TensorProto, helper
import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script import ir as I
bg = np.random.MT19937(0)
rg = np.random.Generator(bg)
def generate_random_inputs(
model: ModelProto, inputs: Optional[dict[str, np.ndarray]] = None
) -> dict[str, np.ndarray]:
input_values = {}
# Iterate through model inputs and extract their shape.
for i in model.graph.input:
if inputs is not None and i.name in inputs and inputs[i.name] is not None:
input_values[i.name] = inputs[i.name]
continue
shape = []
for dim in i.type.tensor_type.shape.dim:
shape.append(dim.dim_value)
input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type)
return input_values
def generate_random_value(shape, elem_type) -> np.ndarray:
# Extract datatype for the input.
if elem_type:
dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
else:
dtype = "float32"
# Generate random inputs for each input.
if dtype == "bool":
random_value = rg.choice(a=[False, True], size=shape)
elif dtype.startswith("int"):
        # Avoid zeros: shift non-positive values down by one so every integer input is non-zero.
random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
random_value[random_value <= 0] -= 1
else:
random_value = rg.standard_normal(size=shape).astype(dtype)
return random_value
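# A quick sanity check of the helper above (a sketch; the shapes and dtypes
# below are chosen arbitrarily for illustration). Note that integer inputs are
# guaranteed non-zero, which avoids division-by-zero in ops under test.
v_float = generate_random_value([2, 3], TensorProto.FLOAT)
v_int = generate_random_value([8], TensorProto.INT32)
v_bool = generate_random_value([2], TensorProto.BOOL)
print(v_float.dtype, v_int.dtype, v_bool.dtype)  # float32 int32 bool
assert (v_int != 0).all()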
def check_correctness(
model: ModelProto,
inputs: Optional[dict[str, np.ndarray]] = None,
ir_version: int = 8,
opset: int = 14,
rtol: float = 1e-7,
atol: float = 1e-5,
check_dtypes: bool = False,
) -> None:
"""Run an onnx model in both onnxruntime and TVM through our importer
confirm that the results match. Otherwise, an exception will be raised.
Parameters
----------
model: ModelProto
The input onnx model that should be tested.
    inputs: Optional[dict[str, np.ndarray]]
An optional dictionary containing values for each input in the onnx model.
ir_version: int
Which version of the onnx IR to use.
opset: int
The opset version to use for the onnx importer.
    rtol: float
        The relative tolerance for correctness checking.
    atol: float
        The absolute tolerance for correctness checking. Some ops may show more
        arithmetic variance than others.
check_dtypes: bool
Check if data types are the same.
"""
# Configure model format.
if ir_version is not None:
model.ir_version = ir_version
if opset is not None:
model.opset_import[0].version = opset
# If inputs are not provided, extract them from the onnx graph and produce random
# values that we'll use for testing.
inputs = generate_random_inputs(model, inputs)
# Run the model through onnx to get the expected result.
ort_session = onnxruntime.InferenceSession(
model.SerializeToString(), providers=["CPUExecutionProvider"]
)
ort_output = ort_session.run([], inputs)
# Convert the onnx model into relax through the onnx importer.
tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True)
# Convert operators for inference mode.
tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
# Legalize any relax ops into tensorir.
tvm_model = relax.transform.LegalizeOps()(tvm_model)
# Separate model from parameters.
tvm_model, params = relax.frontend.detach_params(tvm_model)
# Compile the relax graph into a VM then run.
with tvm.transform.PassContext(opt_level=3):
ex = tvm.compile(tvm_model, target="llvm")
vm = tvm.runtime.vm.VirtualMachine(ex, tvm.cpu())
# Prepare inputs.
input_list = [
inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
]
if params:
input_list += params["main"]
# Run model and check outputs.
vm.set_input("main", *input_list)
vm.invoke_stateful("main")
tvm_output = vm.get_outputs("main")
# Wrap as a list if there is only one output.
if len(ort_output) == 1:
        # Do not compare output counts here: for a sequence output, TVM returns
        # a tuple while ONNX counts it as a single output (a list).
tvm_output = [tvm_output]
def _get_numpy_subdtype(narray):
if np.issubdtype(narray.dtype, np.integer):
return "integer"
elif np.issubdtype(narray.dtype, np.floating):
return "floating"
elif np.issubdtype(narray.dtype, np.bool_):
return "bool"
elif np.issubdtype(narray.dtype, np.complexfloating):
return "complexfloating"
else:
return "other"
def _check_output(tvm_out, ort_out):
if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)):
assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
_check_output(tvm_out_i, ort_out_i)
elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
if check_dtypes:
assert tvm_out.numpy().dtype == ort_out.dtype
tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
shape_out = tvm.nd.array([int(i) for i in tvm_out])
if check_dtypes:
assert _get_numpy_subdtype(shape_out.numpy()) == _get_numpy_subdtype(ort_out)
tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
if check_dtypes:
assert _get_numpy_subdtype(np.array(tvm_out)) == _get_numpy_subdtype(ort_out)
tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
else:
raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
# Check that number of outputs match.
assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
for tvm_out, ort_out in zip(tvm_output, ort_output):
if ort_out is not None:
_check_output(tvm_out, ort_out)
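# A minimal smoke test of check_correctness (a sketch; the graph and tensor
# names below are illustrative): a single Add node, with random inputs
# generated by the helpers above and verified against onnxruntime.
add_node = helper.make_node("Add", ["a", "b"], ["sum"])
add_graph = helper.make_graph(
    [add_node],
    "add_test",
    inputs=[
        helper.make_tensor_value_info("a", TensorProto.FLOAT, [4, 4]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [4, 4]),
    ],
    outputs=[helper.make_tensor_value_info("sum", TensorProto.FLOAT, [4, 4])],
)
check_correctness(helper.make_model(add_graph))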
import torch
import torch.nn as nn
# ===== Case 1: 2-D input, normalized over the last dimension =====
x = torch.randn(32, 32)  # stands in for the [32, 32] input of the ONNX model
scale = torch.randn(32)  # γ
bias = torch.randn(32)  # β
# Create a LayerNorm whose normalized_shape matches the size of the last dimension.
layer_norm = nn.LayerNorm(normalized_shape=32, eps=1e-12, elementwise_affine=False)
# Apply γ and β manually (since elementwise_affine=False).
y = layer_norm(x) * scale + bias
print("2-D input, output shape:", y.shape)
# ===== Case 2: 4-D input, normalized starting from axis=1 =====
x_img = torch.randn(1, 3, 4, 4)  # stands in for a [1, 3, 4, 4] input
class LayerNorm(torch.nn.Module):
def __init__(self, shape):
super().__init__()
self.shape = shape
self.weight = torch.nn.Parameter(torch.randn(shape)) # γ
self.bias = torch.nn.Parameter(torch.randn(shape)) # β
def forward(self, x):
return torch.nn.functional.layer_norm(x, self.shape, self.weight, self.bias, 1e-5)
layer_norm_nd = LayerNorm((3, 4, 4))
y_img = layer_norm_nd(x_img)
print("四维输入输出形状:", y_img.shape)
2-D input, output shape: torch.Size([32, 32])
4-D input, output shape: torch.Size([1, 3, 4, 4])
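# Sanity check (a sketch): F.layer_norm with normalized_shape=(3, 4, 4) is just
# y = (x - mean) / sqrt(var + eps) * γ + β, with the statistics taken over the
# trailing three dimensions and a biased variance estimate.
with torch.no_grad():
    mean = x_img.mean(dim=(1, 2, 3), keepdim=True)
    var = x_img.var(dim=(1, 2, 3), unbiased=False, keepdim=True)
    manual = (x_img - mean) / torch.sqrt(var + 1e-5) * layer_norm_nd.weight + layer_norm_nd.bias
    torch.testing.assert_close(manual, layer_norm_nd(x_img), rtol=1e-5, atol=1e-5)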
from pathlib import Path
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
torch.onnx.export(
layer_norm_nd,
(x_img,),
temp_dir/'demo.onnx',
input_names=["x",],
dynamo=True
)
check_correctness(
    model=onnx.load(temp_dir/'demo.onnx'),
    inputs={"x": x_img.numpy()},
)
/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3949: UserWarning: No Op registered for LayerNormalization with domain_version of 14
==> Context: Bad node spec for node. Name: node_layer_norm OpType: LayerNormalization
warnings.warn(str(exception))
# LayerNormalization was only added in ONNX opset 17, hence the warning above
# under check_correctness's default opset of 14; re-import with a newer opset.
tvm_model = from_onnx(onnx.load(temp_dir/'demo.onnx'), opset=20, keep_params_in_input=True)
tvm_model.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def main(x: R.Tensor((1, 3, 4, 4), dtype="float32"), weight: R.Tensor((3, 4, 4), dtype="float32"), bias: R.Tensor((3, 4, 4), dtype="float32")) -> R.Tensor((1, 3, 4, 4), dtype="float32"):
R.func_attr({"num_input": 1, "params": [metadata["ffi.NDArray"][0], metadata["ffi.NDArray"][1]]})
with R.dataflow():
lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.layer_norm(x, weight, bias, axes=[-3, -2, -1], epsilon=9.9999997473787516e-06, center=True, scale=True)
gv: R.Tensor((1, 3, 4, 4), dtype="float32") = lv
R.output(gv)
return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
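# End-to-end run of the imported module (a sketch, mirroring the pipeline used
# inside check_correctness above), compared against the PyTorch reference.
mod = relax.transform.DecomposeOpsForInference()(tvm_model)
mod = relax.transform.LegalizeOps()(mod)
mod, params = relax.frontend.detach_params(mod)
ex = tvm.compile(mod, target="llvm")
vm = tvm.runtime.vm.VirtualMachine(ex, tvm.cpu())
vm.set_input("main", x_img.numpy(), *params["main"])
vm.invoke_stateful("main")
tvm.testing.assert_allclose(
    vm.get_outputs("main").numpy(), y_img.detach().numpy(), rtol=1e-5, atol=1e-5
)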