# Inference

from typing import Literal, Optional

import numpy as np
import onnx
import onnxruntime
from onnx import ModelProto, TensorProto, helper

import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script import ir as I

# Seeded RNG so that the randomly generated test inputs are reproducible.
bg = np.random.MT19937(0)
rg = np.random.Generator(bg)

def generate_random_inputs(
    model: ModelProto, inputs: Optional[dict[str, np.ndarray]] = None
) -> dict[str, np.ndarray]:
    input_values = {}
    # Iterate through model inputs and extract their shape.
    for i in model.graph.input:
        if inputs is not None and i.name in inputs and inputs[i.name] is not None:
            input_values[i.name] = inputs[i.name]
            continue
        shape = []
        for dim in i.type.tensor_type.shape.dim:
            # Note: dim_value is 0 for symbolic (dim_param) dimensions, so models
            # with dynamic input shapes need their inputs supplied explicitly.
            shape.append(dim.dim_value)

        input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type)

    return input_values


def generate_random_value(shape, elem_type) -> np.ndarray:
    # Extract datatype for the input.
    if elem_type:
        dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
    else:
        dtype = "float32"

    # Generate random inputs for each input.
    if dtype == "bool":
        # random_value = np.random.choice(a=[False, True], size=shape)
        random_value = rg.choice(a=[False, True], size=shape)
    elif dtype.startswith("int"):
        # Shift non-positive values down by one so that zero never appears
        # (e.g. to avoid division-by-zero in the ops under test).
        random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
        random_value[random_value <= 0] -= 1
    else:
        random_value = rg.standard_normal(size=shape).astype(dtype)

    return random_value


def check_correctness(
    model: ModelProto,
    inputs: Optional[dict[str, np.ndarray]] = None,
    ir_version: int = 8,
    opset: int = 14,
    rtol: float = 1e-7,
    atol: float = 1e-5,
    check_dtypes: bool = False,
) -> None:
    """Run an onnx model in both onnxruntime and TVM through our importer
       confirm that the results match. Otherwise, an exception will be raised.

    Parameters
    ----------
    model: ModelProto
        The input onnx model that should be tested.
    inputs: Optional[Dict[str, np.ndarray]]
        An optional dictionary containing values for each input in the onnx model.
    ir_version: int
        Which version of the onnx IR to use.
    opset: int
        The opset version to use for the onnx importer.
    rtol: float
        Relative tolerance for correctness checking.
    atol: float
        Absolute tolerance for correctness checking. Some ops may show more
        arithmetic variance than others.
    check_dtypes: bool
        Check if data types are the same.
    """
    # Configure model format.
    if ir_version is not None:
        model.ir_version = ir_version
    if opset is not None:
        model.opset_import[0].version = opset

    # If inputs are not provided, extract them from the onnx graph and produce random
    # values that we'll use for testing.
    inputs = generate_random_inputs(model, inputs)

    # Run the model through onnx to get the expected result.
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)

    # Convert the onnx model into relax through the onnx importer.
    tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True)
    # Convert operators for inference mode.
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    # Legalize any relax ops into tensorir.
    tvm_model = relax.transform.LegalizeOps()(tvm_model)

    # Separate model from parameters.
    tvm_model, params = relax.frontend.detach_params(tvm_model)
    # Compile the relax graph into a VM then run.
    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="llvm")
        vm = tvm.runtime.vm.VirtualMachine(ex, tvm.cpu())
    # Prepare inputs.
    input_list = [
        inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
    ]
    if params:
        input_list += params["main"]

    # Run model and check outputs.
    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main")
    # Wrap as a list if there is only one output.
    if len(ort_output) == 1:
        # Do not compare output counts here: for a sequence output TVM returns
        # a Tuple while ONNX reports a single output that is a list.
        tvm_output = [tvm_output]

    def _get_numpy_subdtype(narray):
        if np.issubdtype(narray.dtype, np.integer):
            return "integer"
        elif np.issubdtype(narray.dtype, np.floating):
            return "floating"
        elif np.issubdtype(narray.dtype, np.bool_):
            return "bool"
        elif np.issubdtype(narray.dtype, np.complexfloating):
            return "complexfloating"
        else:
            return "other"

    def _check_output(tvm_out, ort_out):
        if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)):
            assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
            for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
                _check_output(tvm_out_i, ort_out_i)
        elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
            if check_dtypes:
                assert tvm_out.numpy().dtype == ort_out.dtype
            tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
        elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
            shape_out = tvm.nd.array([int(i) for i in tvm_out])
            if check_dtypes:
                assert _get_numpy_subdtype(shape_out.numpy()) == _get_numpy_subdtype(ort_out)
            tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
        elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
            if check_dtypes:
                assert _get_numpy_subdtype(np.array(tvm_out)) == _get_numpy_subdtype(ort_out)
            tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
        else:
            raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")

    # Check that number of outputs match.
    assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
    for tvm_out, ort_out in zip(tvm_output, ort_output):
        if ort_out is not None:
            _check_output(tvm_out, ort_out)
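
As a quick usage sketch (a hypothetical single-op model, not part of the original test suite): build a one-node ONNX graph with onnx.helper and push it through the whole pipeline above.

# Minimal sketch: one Add node with fixed shapes, default ir_version/opset.
node = helper.make_node("Add", inputs=["a", "b"], outputs=["sum"])
graph = helper.make_graph(
    [node],
    "add_test",
    inputs=[
        helper.make_tensor_value_info("a", TensorProto.FLOAT, [4, 4]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [4, 4]),
    ],
    outputs=[helper.make_tensor_value_info("sum", TensorProto.FLOAT, [4, 4])],
)
check_correctness(helper.make_model(graph))
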
import torch
import torch.nn as nn

# ===== Case 1: 2-D input, normalized over the last dimension =====
x = torch.randn(32, 32)  # simulates the [32, 32] input in ONNX
scale = torch.randn(32)  # γ
bias = torch.randn(32)   # β

# Create LayerNorm; normalized_shape matches the size of the last dimension.
layer_norm = nn.LayerNorm(normalized_shape=32, eps=1e-12, elementwise_affine=False)

# Apply γ and β manually (since elementwise_affine=False).
y = layer_norm(x) * scale + bias
print("2-D input, output shape:", y.shape)

# ===== Case 2: 4-D input, normalized starting from axis 1 =====
x_img = torch.randn(1, 3, 4, 4)  # simulates a [1, 3, 4, 4] input

class LayerNorm(torch.nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape
        self.weight = torch.nn.Parameter(torch.randn(shape))  # γ
        self.bias = torch.nn.Parameter(torch.randn(shape))    # β

    def forward(self, x):
        return torch.nn.functional.layer_norm(x, self.shape, self.weight, self.bias, 1e-5)

layer_norm_nd = LayerNorm((3, 4, 4))
y_img = layer_norm_nd(x_img)
print("四维输入输出形状:", y_img.shape)
二维输入输出形状: torch.Size([32, 32])
四维输入输出形状: torch.Size([1, 3, 4, 4])
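
To make the semantics concrete, here is a sketch (assuming the standard LayerNorm definition) that recomputes Case 2 in numpy: the mean and biased variance are taken jointly over the normalized axes (-3, -2, -1), then γ and β are applied elementwise.

# Numpy re-derivation of layer_norm_nd(x_img); should agree to float tolerance.
x_np = x_img.numpy()
mean = x_np.mean(axis=(-3, -2, -1), keepdims=True)
var = x_np.var(axis=(-3, -2, -1), keepdims=True)  # biased variance, as in torch
y_np = (x_np - mean) / np.sqrt(var + 1e-5)
y_np = y_np * layer_norm_nd.weight.detach().numpy() + layer_norm_nd.bias.detach().numpy()
np.testing.assert_allclose(y_np, y_img.detach().numpy(), rtol=1e-4, atol=1e-5)
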
from pathlib import Path
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
torch.onnx.export(
    layer_norm_nd,
    (x_img,),
    temp_dir/'demo.onnx',
    input_names=["x"],
    dynamo=True
)


[torch.onnx] Obtain model graph for `LayerNorm()` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `LayerNorm()` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'': 20},
            producer_name='pytorch',
            producer_version='2.9.0.dev20250725+cpu',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"x"<FLOAT,[1,3,4,4]>
            ),
            outputs=(
                %"layer_norm"<FLOAT,[1,3,4,4]>
            ),
            initializers=(
                %"weight"<FLOAT,[3,4,4]>{TorchTensor(...)},
                %"bias"<FLOAT,[3,4,4]>{TorchTensor(...)}
            ),
        ) {
            0 |  # node_layer_norm
                 %"layer_norm"<FLOAT,[1,3,4,4]>, %""<?,?>, %""<?,?> ⬅️ ::LayerNormalization(%"x", %"weight"{...}, %"bias"{...}) {axis=-3, epsilon=1e-05, stash_type=1}
            return %"layer_norm"<FLOAT,[1,3,4,4]>
        }


    ,
    exported_program=
        ExportedProgram:
            class GraphModule(torch.nn.Module):
                def forward(self, p_weight: "f32[3, 4, 4]", p_bias: "f32[3, 4, 4]", x: "f32[1, 3, 4, 4]"):
                     # File: /tmp/ipykernel_1344764/1027669049.py:27 in forward, code: return torch.nn.functional.layer_norm(x, self.shape, self.weight, self.bias, 1e-5)
                    layer_norm: "f32[1, 3, 4, 4]" = torch.ops.aten.layer_norm.default(x, [3, 4, 4], p_weight, p_bias);  x = p_weight = p_bias = None
                    return (layer_norm,)
            
        Graph signature: 
            # inputs
            p_weight: PARAMETER target='weight'
            p_bias: PARAMETER target='bias'
            x: USER_INPUT
    
            # outputs
            layer_norm: USER_OUTPUT
    
        Range constraints: {}

)
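Note the exported graph above: a single ::LayerNormalization node with axis=-3, i.e. normalization over the last three dimensions, matching normalized_shape=(3, 4, 4). Before importing the file into TVM, a quick structural check can be done with the ONNX checker (an optional sketch):

m = onnx.load(temp_dir/'demo.onnx')
onnx.checker.check_model(m)
print([n.op_type for n in m.graph.node])  # expected: ['LayerNormalization']
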
check_correctness(
    model=onnx.load(temp_dir/'demo.onnx'),
    inputs={"x": x_img.numpy()},
)
/media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3949: UserWarning: No Op registered for LayerNormalization with domain_version of 14

==> Context: Bad node spec for node. Name: node_layer_norm OpType: LayerNormalization
  warnings.warn(str(exception))
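
This warning is expected: check_correctness rewrites the model's opset_import to its default opset=14, but ONNX only introduced LayerNormalization in opset 17, so shape inference cannot resolve the node. Passing an explicit opset of 17 or higher (the exported model itself declares opset 20) should avoid it, e.g.:

check_correctness(
    model=onnx.load(temp_dir/'demo.onnx'),
    inputs={"x": x_img.numpy()},
    opset=20,
)
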
tvm_model = from_onnx(onnx.load(temp_dir/'demo.onnx'), opset=20, keep_params_in_input=True)
tvm_model.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 3, 4, 4), dtype="float32"), weight: R.Tensor((3, 4, 4), dtype="float32"), bias: R.Tensor((3, 4, 4), dtype="float32")) -> R.Tensor((1, 3, 4, 4), dtype="float32"):
        R.func_attr({"num_input": 1, "params": [metadata["ffi.NDArray"][0], metadata["ffi.NDArray"][1]]})
        with R.dataflow():
            lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.layer_norm(x, weight, bias, axes=[-3, -2, -1], epsilon=9.9999997473787516e-06, center=True, scale=True)
            gv: R.Tensor((1, 3, 4, 4), dtype="float32") = lv
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
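
Finally, the imported module can be executed directly, reusing the same pipeline that check_correctness applies internally (a sketch; the reference output y_img comes from the PyTorch module above):

# Decompose, legalize, detach params, compile, and run on CPU.
mod = relax.transform.DecomposeOpsForInference()(tvm_model)
mod = relax.transform.LegalizeOps()(mod)
mod, params = relax.frontend.detach_params(mod)
with tvm.transform.PassContext(opt_level=3):
    ex = tvm.compile(mod, target="llvm")
    vm = tvm.runtime.vm.VirtualMachine(ex, tvm.cpu())
vm.set_input("main", tvm.nd.array(x_img.numpy()), *params["main"])
vm.invoke_stateful("main")
tvm_out = vm.get_outputs("main")
# The Relax result should match the PyTorch LayerNorm up to float tolerance.
np.testing.assert_allclose(tvm_out.numpy(), y_img.detach().numpy(), rtol=1e-4, atol=1e-5)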