# 测试 ONNX Relax

复现 TVM Relax ONNX 前端（`from_onnx`）在转换 `Resize` 算子（由 PyTorch `F.interpolate` 导出）时的报错。

创建缓存目录：

In [1]:
from pathlib import Path

# Scratch directory for the exported .onnx files. exist_ok=True makes the cell
# safe to re-run when the directory is already there.
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)

## 构建 ONNX 模型

In [2]:
import torch
import torch.nn.functional as F
class M(torch.nn.Module):
    """Toy module whose forward pass is a single nearest-neighbour downsample.

    ``conv`` and ``conv2`` are still constructed, so the module owns their
    weights, but ``forward`` does not use either of them.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, bias=False)
        self.conv2 = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1, bias=False)

    def forward(self, x):
        # Halve the two spatial dims with nearest sampling. torch.onnx.export
        # turns this call into the ONNX Resize node under test.
        return F.interpolate(x, scale_factor=(0.5, 0.5), mode="nearest")


# Export the same module twice. The calls are identical except for the ONNX opset:
# opset 11 goes to test.onnx and opset 19 to test19.onnx. Those are the file names
# the conversion cells below load.
torch_model = M()
input_tensor = torch.randn(1, 3, 10, 10)
for opset_version, onnx_name in ((11, "test.onnx"), (19, "test19.onnx")):
    torch.onnx.export(
        torch_model,
        (input_tensor,),
        temp_dir / onnx_name,
        input_names=["x"],
        opset_version=opset_version,
    )

## 将 ONNX 模型转换为 Relax 模型

In [5]:
import onnx
from tvm.relax.frontend.onnx import from_onnx
# Convert the opset-11 export. The call explicitly passes opset=20 to the frontend,
# even though the file was exported with opset_version=11.
# With the TVM build used here, this raises inside relax.op.image.resize2d
# (see the recorded traceback). The error is caught and shown, rather than raised,
# so that "Restart & Run All" still reaches the following cells.
model = onnx.load(temp_dir / "test.onnx")
try:
    tvm_model = from_onnx(model, keep_params_in_input=True, opset=20)
except Exception as err:  # known Resize conversion failure: displayed, not hidden
    print(f"{type(err).__name__}: {err}")

Error converting operator Resize, with inputs: [x, metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it., metadata["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.]


TVMError: Traceback (most recent call last):
  File "/media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h", line 924
TVMError: In function relax.op.image.resize2d(0: RelaxExpr, 1: RelaxExpr, 2: Array<FloatImm>, 3: runtime.String, 4: runtime.String, 5: runtime.String, 6: runtime.String, 7: double, 8: int, 9: double, 10: DataType) -> RelaxExpr: error while converting argument 2: [17:25:38] /media/pc/data/lxw/ai/tvm/include/tvm/runtime/packed_func.h:2274: InternalError: Check failed: (!checked_type.defined()) is false: Expected Array[runtime.Object], but got relax.expr.Call


In [None]:
import onnx
from tvm.relax.frontend.onnx import from_onnx
# Opset-19 export of the same module. No opset argument is passed, so from_onnx
# uses its default (opset=None).
model = onnx.load(temp_dir / "test19.onnx")
tvm_model = from_onnx(model=model, keep_params_in_input=True)

In [4]:
from_onnx?

Signature:
from_onnx(
    model: onnx.onnx_ml_pb2.GraphProto,
    shape_dict: Optional[Dict[str, List]] = None,
    dtype_dict: Union[str, Dict[str, str], NoneType] = 'float32',
    opset: int = None,
    keep_params_in_input: bool = False,
    sanitize_input_names: bool = True

In [None]:
# Display whatever `tvm_model` the most recent from_onnx cell above assigned
# without raising. Presumably this is a Relax tvm.IRModule; confirm.
tvm_model

In [None]:
from io import StringIO
from contextlib import redirect_stdout, redirect_stderr
import tempfile
import torch
import torch.nn.functional as F
import onnx
from tvm.relax.frontend.onnx import from_onnx

def test_resize():
    """Reproduce the Relax ONNX frontend failure on a PyTorch-exported Resize.

    Steps:

    1. Export a module whose forward is only
       ``F.interpolate(scale_factor=(0.5, 0.5), mode="nearest")`` with opset 11,
       into a temporary directory.
    2. Convert it with ``from_onnx``.

    The conversion is currently expected to raise. In that case the test
    asserts that the frontend printed exactly the "Error converting operator
    Resize" line, captured from stdout. If the conversion succeeds instead, a
    note is printed, because the bug may have been fixed and this check should
    then be replaced.

    NOTE(review): a second ``def test_resize`` later in this notebook rebinds
    this name. Once that cell has run, this variant is no longer reachable as
    ``test_resize``.
    """
    class Resize(torch.nn.Module):
        def forward(self, x):
            return F.interpolate(x, size=None, scale_factor=(0.5, 0.5), mode="nearest")

    torch_model = Resize()
    input_tensor = torch.randn(1, 3, 10, 10)
    expected_log = (
        'Error converting operator Resize, with inputs: [x, metadata["relax.expr.Constant"][0]\n# Metadata omitted. '
        'Use show_meta=True in script() method to show it., metadata["relax.expr.Constant"][0]\n# Metadata omitted. '
        'Use show_meta=True in script() method to show it.]\n'
    )
    with tempfile.TemporaryDirectory() as temp_dir:
        onnx_path = f"{temp_dir}/test.onnx"
        torch.onnx.export(
            torch_model,
            (input_tensor,),
            onnx_path,
            input_names=["x"],
            opset_version=11,
        )
        model = onnx.load(onnx_path)
        # need fix: conversion currently raises for this Resize.
        # Bind the capture buffer before `try` so the except branch can always read it.
        captured = StringIO()
        try:
            with redirect_stdout(captured):
                from_onnx(model, keep_params_in_input=True)
        except Exception as e:
            print(f"Exception: {e}")
            actual_log = captured.getvalue()
            assert actual_log == expected_log, f"unexpected frontend output:\n{actual_log}"
        else:
            # Previously a successful conversion passed silently; make it visible.
            print("from_onnx converted Resize without error; the frontend issue may be fixed -- update this test.")

In [None]:
from io import StringIO
from contextlib import redirect_stdout
import numpy as np
from onnx import helper, TensorProto
from onnxscript import script
from onnxscript import FLOAT
from onnxscript import opset11 as op
from tvm.relax.frontend.onnx import from_onnx

def test_resize():
    """Reproduce the Relax ONNX frontend failure on a hand-built opset-11 Resize.

    The model is written with onnxscript. ``Resize`` receives constant
    ``scales`` ``[1, 1, 0.5, 0.5]`` and a constant scalar ``roi``. The model is
    run once through onnxscript, then converted with ``from_onnx``.

    The conversion is currently expected to raise. In that case the test
    asserts the exact "Error converting operator Resize" line the frontend
    printed to stdout. If the conversion succeeds instead, a note is printed,
    because the bug may have been fixed and this check should then be replaced.
    """
    @script()
    def Resize(X: FLOAT[1, 3, 20, 20]):
        scales = op.Constant(value=helper.make_tensor("scales", TensorProto.FLOAT, (4,), [1, 1, 0.5, 0.5]))
        # NOTE(review): the ONNX spec defines roi as a 1-D tensor of length
        # 2 * rank(X) that only takes effect for tf_crop_and_resize. This scalar
        # roi is presumably deliberate, to hit the frontend's constant-roi
        # path. Confirm.
        roi = op.Constant(value=helper.make_tensor("roi", TensorProto.FLOAT, (), [10]))
        return op.Resize(X, roi=roi, scales=scales,)

    # NOTE(review): result currently unused. Presumably kept as the reference
    # output for a numerical comparison once conversion works.
    onnx_result = Resize(X=np.random.randn(1, 3, 20, 20).astype("float32"))
    model = Resize.to_model_proto()  # returns an onnx.ModelProto
    expected_log = (
        'Error converting operator Resize, with inputs: [X, R.const(10.0, "float32"), '
        'metadata["relax.expr.Constant"][0]\n# Metadata omitted. '
        'Use show_meta=True in script() method to show it.]\n'
    )
    # need fix: conversion currently raises for this Resize.
    # Bind the capture buffer before `try` so the except branch can always read it.
    captured = StringIO()
    try:
        with redirect_stdout(captured):
            from_onnx(model, keep_params_in_input=True)
    except Exception as e:
        print(f"Exception: {e}")
        actual_log = captured.getvalue()
        assert actual_log == expected_log, f"unexpected frontend output:\n{actual_log}"
    else:
        # Previously a successful conversion passed silently; make it visible.
        print("from_onnx converted Resize without error; the frontend issue may be fixed -- update this test.")

In [None]:
# When the cells run in order, this calls the onnxscript-based test_resize.
# Its later `def test_resize` rebinds the name, so the torch-export variant
# defined earlier is not the one executed here.
test_resize()