# 测试 ONNX 到 Relax 的转换

创建用于存放导出模型的缓存目录：

from pathlib import Path

# Local cache directory for the exported ONNX files; an existing directory is reused.
temp_dir = Path() / ".temp"
temp_dir.mkdir(mode=0o777, parents=False, exist_ok=True)

## 构建 ONNX 模型

构建一个只做最近邻下采样（`F.interpolate`，缩放因子 0.5）的模型，并分别以 opset 11 和 opset 19 导出为 ONNX：

import torch
import torch.nn.functional as F
class M(torch.nn.Module):
    """Toy model that halves the two spatial dims of a 4-D input with nearest sampling.

    ``conv`` and ``conv2`` are registered submodules (so they show up in
    ``parameters()`` / ``state_dict()``), but ``forward`` does not use them.
    """

    def __init__(self):
        super().__init__()
        # NOTE: not used by forward(); kept so the module's parameter set is unchanged.
        self.conv = torch.nn.Conv2d(3, 16, 3, bias=False)
        self.conv2 = torch.nn.Conv2d(16, 32, 1, bias=False)

    def forward(self, x):
        # Output size is derived from scale_factor (size left at its default None).
        return F.interpolate(x, scale_factor=(0.5, 0.5), mode="nearest")


torch_model = M()
input_tensor = torch.randn(1, 3, 10, 10)

# Export the same model twice so the importer can be compared across opsets:
# opset 11 -> test.onnx, opset 19 -> test19.onnx.
for _opset, _file_name in ((11, "test.onnx"), (19, "test19.onnx")):
    torch.onnx.export(
        torch_model,
        (input_tensor,),
        temp_dir / _file_name,
        input_names=["x"],
        opset_version=_opset,
    )

## 将 ONNX 模型转换为 Relax 模型

以 opset 11 导出的模型（并强制 `opset=20`）在转换 `Resize` 算子时会报错；以 opset 19 导出的模型则未报错。

import onnx
from tvm.relax.frontend.onnx import from_onnx
# The model was exported with opset 11. Forcing opset=20 makes the frontend
# dispatch Resize to its v18 converter (Resize._impl_v18 in the traceback),
# and the conversion currently raises a TVMError.
model = onnx.load(str(temp_dir / "test.onnx"))
tvm_model = from_onnx(
    model,
    keep_params_in_input=True,
    opset=20,
)
Error converting operator Resize, with inputs: [x, metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it., metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.]
---------------------------------------------------------------------------
TVMError                                  Traceback (most recent call last)
Cell In[5], line 4
      2 from tvm.relax.frontend.onnx import from_onnx
      3 model = onnx.load(temp_dir/"test.onnx")
----> 4 tvm_model = from_onnx(model,  keep_params_in_input=True, opset=20)

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3690, in from_onnx(model, shape_dict, dtype_dict, opset, keep_params_in_input, sanitize_input_names)
   3683     warnings.warn(
   3684         ""
   3685         f"You are overwritting original opset ver = {opset_in_model} by lower ver = {opset}. "
   3686         f"That might cause model conversion errors."
   3687     )
   3689 # Use the graph proto as a scope so that ops can access other nodes if needed.
-> 3690 return g.from_onnx(graph, opset)

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3321, in ONNXGraphImporter.from_onnx(self, graph, opset)
   3319 self._parse_graph_input(graph)
   3320 self._check_for_unsupported_ops(graph)
-> 3321 self._construct_nodes(graph)
   3323 # now return the outputs
   3324 outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3501, in ONNXGraphImporter._construct_nodes(self, graph)
   3499 except TVMError as err:
   3500     print(f"Error converting operator {op_name}, with inputs: {inputs}")
-> 3501     raise err
   3503 if op_name in return_tuple_ops:
   3504     outputs_num = 1

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3496, in ONNXGraphImporter._construct_nodes(self, graph)
   3494         raise ValueError(f"Node {node.name} cannot handle ShapeExpr inputs.")
   3495 try:
-> 3496     op = self._convert_operator(op_name, inputs, attr, self.opset)
   3497     # Create struct information for the new operator.
   3498     op = self.bb.normalize(op)

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:3596, in ONNXGraphImporter._convert_operator(self, op_name, inputs, attrs, opset)
   3594     convert_class = convert_map[op_name]
   3595     op_function = convert_class.get_converter(opset)
-> 3596     sym = op_function(self.bb, inputs, attrs, [self._nodes, self._params])
   3597 else:
   3598     raise NotImplementedError("Operator {} not implemented.".format(op_name))

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py:2146, in Resize._impl_v18(cls, bb, inputs, attr, params)
   2141     assert isinstance(
   2142         sizes, relax.Constant
   2143     ), "Only constant output size currently supported."
   2144     sizes = sizes.data.numpy().astype("int64").tolist()[2:]
-> 2146 return relax.op.image.resize2d(
   2147     x,
   2148     size=relax.ShapeExpr(sizes),
   2149     roi=roi,
   2150     layout="NCHW",
   2151     method=mode,
   2152     coordinate_transformation_mode=coord_mode,
   2153     rounding_method=rounding_method,
   2154     cubic_alpha=cubic_coeff_a,
   2155     cubic_exclude=exclude_outside,
   2156     extrapolation_value=extrapolation_value,
   2157 )

File /media/pc/data/lxw/ai/tvm/python/tvm/relax/op/image/image.py:116, in resize2d(data, size, roi, layout, method, coordinate_transformation_mode, rounding_method, cubic_alpha, cubic_exclude, extrapolation_value, out_dtype)
    113     else:
    114         size = ShapeExpr(size)
--> 116 return _ffi_api.resize2d(  # type: ignore
    117     data,
    118     size,
    119     roi,
    120     layout,
    121     method,
    122     coordinate_transformation_mode,
    123     rounding_method,
    124     cubic_alpha,
    125     cubic_exclude,
    126     extrapolation_value,
    127     out_dtype,
    128 )

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:339, in tvm._ffi._cy3.core.PackedFuncBase.__call__()

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/packed_func.pxi:284, in tvm._ffi._cy3.core.FuncCall()

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_cython/base.pxi:185, in tvm._ffi._cy3.core.CHECK_CALL()

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py:468, in raise_last_ffi_error()
    462 # The exception PyObject may contain a large amount of state,
    463 # including all stack frames that may be inspected in a later
    464 # PDB post-mortem.  Therefore, we must make sure to remove the
    465 # underlying PyObject* from the C++ side after we retrieve it.
    466 _LIB.TVMDropLastPythonError()
--> 468 raise py_err

TVMError: Traceback (most recent call last):
  File "/media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h", line 924
TVMError: In function relax.op.image.resize2d(0: RelaxExpr, 1: RelaxExpr, 2: Array<FloatImm>, 3: runtime.String, 4: runtime.String, 5: runtime.String, 6: runtime.String, 7: double, 8: int, 9: double, 10: DataType) -> RelaxExpr: error while converting argument 2: [17:25:38] /media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h:2274: InternalError: Check failed: (!checked_type.defined()) is false: Expected Array[runtime.Object], but got relax.expr.Call
import onnx
from tvm.relax.frontend.onnx import from_onnx
model = onnx.load(str(temp_dir / "test19.onnx"))
# No opset override here: the importer relies on the opset auto-detected from the model.
tvm_model = from_onnx(
    model,
    keep_params_in_input=True,
)
from_onnx?
Signature:
from_onnx(
    model: onnx.onnx_ml_pb2.GraphProto,
    shape_dict: Optional[Dict[str, List]] = None,
    dtype_dict: Union[str, Dict[str, str], NoneType] = 'float32',
    opset: int = None,
    keep_params_in_input: bool = False,
    sanitize_input_names: bool = True,
) -> tvm.ir.module.IRModule
Docstring:
Convert a ONNX model into an equivalent Relax Function.
ONNX graphs are represented as Python Protobuf objects.

The current implementation assumes that the input model is after ONNX v1.1.0.

Parameters
----------
model : protobuf object
    ONNX ModelProto after ONNX v1.1.0
shape_dict : dict of str to tuple, optional
    The input shape to the graph
dtype_dict : str or dict of str to str, optional
    The input types to the graph
opset : int, optional
    Override to autodetected opset.
    This can be helpful for some testing.
keep_params_in_input : bool
    If True, parameters will be treated as input variables. If false,
    parameters are treated as constant and folded directly into the graph.
sanitize_input_names : bool, optional
    Whether to sanitize the input names to ensure they are valid Relax identifiers.

Returns
-------
mod : tvm.IRModule
    The relax module for compilation
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py
Type:      function
tvm_model  # Notebook display: echo the IRModule returned by from_onnx above.
from io import StringIO
from contextlib import redirect_stdout, redirect_stderr
import tempfile
import torch
import torch.nn.functional as F
import onnx
from tvm.relax.frontend.onnx import from_onnx

def test_resize():
    """Reproduce a Resize conversion failure in the Relax ONNX frontend.

    A PyTorch ``F.interpolate`` (nearest, scale 0.5) is exported to ONNX with
    opset 11 and imported with ``from_onnx``. If the import raises, the
    exception is printed and the diagnostic line the importer wrote to stdout
    must match the known message exactly. If the import succeeds, nothing is
    asserted.
    """

    class Resize(torch.nn.Module):
        def forward(self, x):
            return F.interpolate(x, size=None, scale_factor=(0.5, 0.5), mode="nearest")

    # Exact stdout the importer prints before re-raising the TVMError.
    expected_stdout = (
        'Error converting operator Resize, with inputs: [x, metadata["relax.expr.Constant"][0]\n# Metadata omitted. '
        'Use show_meta=True in script() method to show it., metadata["relax.expr.Constant"][0]\n# Metadata omitted. '
        'Use show_meta=True in script() method to show it.]\n'
    )

    module = Resize()
    example_input = torch.randn(1, 3, 10, 10)
    with tempfile.TemporaryDirectory() as workdir:
        onnx_file = f"{workdir}/test.onnx"
        torch.onnx.export(
            module,
            (example_input,),
            onnx_file,
            input_names=["x"],
            opset_version=11,
        )
        onnx_model = onnx.load(onnx_file)

        # TODO: the frontend cannot convert this Resize node yet ("need fix").
        captured = StringIO()
        try:
            with redirect_stdout(captured):
                from_onnx(onnx_model, keep_params_in_input=True)
        except Exception as exc:
            print(f"Exception: {exc}")
            assert captured.getvalue() == expected_stdout
from io import StringIO
from contextlib import redirect_stdout
import numpy as np
from onnx import helper, TensorProto
from onnxscript import script
from onnxscript import FLOAT
from onnxscript import opset11 as op
from tvm.relax.frontend.onnx import from_onnx

def test_resize():
    """Reproduce the Resize conversion failure with an onnxscript-built model.

    The script function resizes a (1, 3, 20, 20) float input with constant
    scales [1, 1, 0.5, 0.5] and a scalar ``roi`` constant. If ``from_onnx``
    raises, the exception is printed and the importer's stdout diagnostic must
    match the known message exactly. If the import succeeds, nothing is asserted.
    """
    # Body kept verbatim: onnxscript parses this function's source.
    @script()
    def Resize(X: FLOAT[1, 3, 20, 20]):
        scales = op.Constant(value=helper.make_tensor("scales", TensorProto.FLOAT, (4,), [1, 1, 0.5, 0.5]))
        roi = op.Constant(value=helper.make_tensor("roi", TensorProto.FLOAT, (), [10]))
        return op.Resize(X, roi=roi, scales=scales,)

    # Exact stdout the importer prints before re-raising.
    expected_stdout = (
        'Error converting operator Resize, with inputs: [X, R.const(10.0, "float32"), '
        'metadata["relax.expr.Constant"][0]\n# Metadata omitted. '
        'Use show_meta=True in script() method to show it.]\n'
    )

    # Calling the script function evaluates it on random data (presumably via
    # onnxscript's eager mode); the result itself is not used.
    Resize(X=np.random.randn(1, 3, 20, 20).astype("float32"))
    model_proto = Resize.to_model_proto()  # an onnx.ModelProto

    # TODO: the frontend cannot convert this Resize node yet ("need fix").
    captured = StringIO()
    try:
        with redirect_stdout(captured):
            from_onnx(model_proto, keep_params_in_input=True)
    except Exception as exc:
        print(f"Exception: {exc}")
        assert captured.getvalue() == expected_stdout
# Run the reproduction; this invokes the most recent test_resize definition (the onnxscript variant).
test_resize()