ONNX 前端测试

ONNX 前端测试

# Presumably configures the environment (e.g. sys.path) so the local TVM build
# and helper packages are importable; it is kept first for that reason.
# NOTE(review): confirm what set_env does before reordering these imports.
import set_env
from d2py.utils.file import mkdir
# Scratch directory for artifacts produced by this notebook (the downloaded
# ONNX model is written here via tempfile's ``dir=root_dir``).
root_dir = ".temp"
mkdir(root_dir)
import platform
import pytest
import builtins
import importlib

import tvm
from unittest import mock
from tvm.ir.module import IRModule

from tvm.driver import tvmc
from tvm.driver.tvmc import TVMCException, TVMCImportError
from tvm.driver.tvmc.model import TVMCModel
from d2py.tools.sanstyle.github.file import lfs_url
import httpx
import tempfile
from tqdm.asyncio import tqdm

# Handle on the real importlib.import_module, presumably kept so tests that
# mock imports can delegate to / restore it; not used in this visible section.
orig_import = importlib.import_module

def verify_load_model__onnx(model, **kwargs):
    """Load ``model`` with the TVMC frontend loader and check the result's types.

    Extra keyword arguments (for example ``freeze_params``) are passed through
    unchanged to ``tvmc.frontends.load_model``.

    The checks use exact types (``type(...) is``), so subclasses are rejected:
    the loaded object must be a ``TVMCModel``, its ``mod`` an ``IRModule`` and
    its ``params`` a plain ``dict``. Returns the loaded ``TVMCModel``.
    """
    result = tvmc.frontends.load_model(model, **kwargs)
    assert type(result) is TVMCModel
    # Attributes are fetched one at a time, only after the previous check passed.
    for attr_name, expected_type in (("mod", IRModule), ("params", dict)):
        assert type(getattr(result, attr_name)) is expected_type
    return result
# A ".onnx" file extension alone should make TVMC pick the ONNX frontend.
sut = tvmc.frontends.guess_frontend("a_model.onnx")
assert type(sut) is tvmc.frontends.OnnxFrontend

# Loading a path that does not exist must raise FileNotFoundError (raised from
# onnx.load). Assert it explicitly rather than letting the exception escape,
# which would abort every cell/statement that follows.
with pytest.raises(FileNotFoundError):
    tvmc.load("not/a/file.txt", model_format="onnx")
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
Cell In[4], line 1
----> 1 tvmc.load("not/a/file.txt", model_format="onnx")

File /media/pc/data/lxw/ai/tvm/python/tvm/driver/tvmc/frontends.py:476, in load_model(path, model_format, shape_dict, **kwargs)
    473 else:
    474     frontend = guess_frontend(path)
--> 476 mod, params = frontend.load(path, shape_dict, **kwargs)
    478 return TVMCModel(mod, params)

File /media/pc/data/lxw/ai/tvm/python/tvm/driver/tvmc/frontends.py:167, in OnnxFrontend.load(self, path, shape_dict, **kwargs)
    164 onnx = lazy_import("onnx")
    166 # pylint: disable=E1101
--> 167 model = onnx.load(path)
    169 return relay.frontend.from_onnx(model, shape=shape_dict, **kwargs)

File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/onnx/__init__.py:210, in load_model(f, format, load_external_data)
    189 def load_model(
    190     f: IO[bytes] | str | os.PathLike,
    191     format: _SupportedFormat | None = None,  # noqa: A002
    192     load_external_data: bool = True,
    193 ) -> ModelProto:
    194     """Loads a serialized ModelProto into memory.
    195 
    196     Args:
   (...)
    208         Loaded in-memory ModelProto.
    209     """
--> 210     model = _get_serializer(format, f).deserialize_proto(_load_bytes(f), ModelProto())
    212     if load_external_data:
    213         model_filepath = _get_file_path(f)

File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/onnx/__init__.py:147, in _load_bytes(f)
    145 else:
    146     f = typing.cast(Union[str, os.PathLike], f)
--> 147     with open(f, "rb") as readable:
    148         content = readable.read()
    149 return content

FileNotFoundError: [Errno 2] No such file or directory: 'not/a/file.txt'
url = lfs_url("onnx/models", "vision/classification/resnet/model/resnet50-v2-7.onnx", branch="bd206494e8b6a27b25e5cf7199dbcdbfe9d05d1c")
prefix = url.split("/")[-1].split(".onnx")[0]
client = httpx.AsyncClient()
with httpx.stream("GET", url) as response:
    print('response', response)
    total = int(response.headers["Content-Length"])
    print('total', total)

with tempfile.NamedTemporaryFile(delete=False, prefix=f"{prefix}-", suffix=".onnx", dir=root_dir) as download_file:
    async with client.stream('GET', url) as response:
        async for chunk in tqdm(response.aiter_bytes()):
            download_file.write(chunk)
response <Response [200 OK]>
total 102442450
6264it [10:44,  9.72it/s]
# Path of the downloaded ResNet-50 v2 ONNX model used by the checks below.
onnx_resnet50 = download_file.name

# freeze_params=False: weights stay in the params dict, so one known
# parameter name must be present.
tvmc_model = verify_load_model__onnx(onnx_resnet50, freeze_params=False)
assert "resnetv24_batchnorm0_gamma" in tvmc_model.params

# freeze_params=True: the params dict must be empty, implying the weights
# were folded into the graph as constants.
tvmc_model = verify_load_model__onnx(onnx_resnet50, freeze_params=True)
assert tvmc_model.params == {}
# Converting the model's NCHW graph to NHWC should insert at least one
# NCHW -> NHWC layout_transform call.
tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
before = tvmc_model.mod

expected_layout = "NHWC"
with tvm.transform.PassContext(opt_level=3):
    after = tvmc.transform.convert_graph_layout(before, expected_layout)

# (src_layout, dst_layout) of every layout_transform call found in the graph.
layout_transform_calls = []


def _is_layout_transform(node):
    """Record ``(src_layout, dst_layout)`` for each ``layout_transform`` call."""
    # Only primitive operators are checked by name; other callees (e.g. relay
    # Functions) are skipped instead of being probed for a ``name`` attribute.
    if (
        isinstance(node, tvm.relay.expr.Call)
        and isinstance(node.op, tvm.ir.Op)
        and node.op.name == "layout_transform"
    ):
        layout_transform_calls.append(
            (str(node.attrs.src_layout), str(node.attrs.dst_layout))
        )


tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert ("NCHW", "NHWC") in layout_transform_calls, (
    "Expected 'layout_transform NCHW->NHWC' not found"
)
[00:37:12] /media/pc/data/lxw/ai/tvm/src/relay/transforms/convert_layout.cc:99: Warning: Desired layout(s) not specified for op: nn.max_pool2d
[00:37:13] /media/pc/data/lxw/ai/tvm/src/relay/transforms/convert_layout.cc:99: Warning: Desired layout(s) not specified for op: nn.global_avg_pool2d
# Converting the NCHW graph to NCHW should be a no-op: no layout_transform
# call of any kind should be inserted.
tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
before = tvmc_model.mod

expected_layout = "NCHW"

with tvm.transform.PassContext(opt_level=3):
    after = tvmc.transform.convert_graph_layout(before, expected_layout)

# (src_layout, dst_layout) of every layout_transform call found in the graph.
layout_transform_calls = []


def _is_layout_transform(node):
    """Record ``(src_layout, dst_layout)`` for each ``layout_transform`` call."""
    # Only primitive operators are checked by name; other callees (e.g. relay
    # Functions) are skipped instead of being probed for a ``name`` attribute.
    if (
        isinstance(node, tvm.relay.expr.Call)
        and isinstance(node.op, tvm.ir.Op)
        and node.op.name == "layout_transform"
    ):
        layout_transform_calls.append(
            (str(node.attrs.src_layout), str(node.attrs.dst_layout))
        )


tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

# The previous predicate matched only NCHW->NCHW transforms, which the pass
# never emits, so that assertion could not fail. Reject any layout_transform.
assert not layout_transform_calls, (
    f"Unexpected 'layout_transform' call(s): {layout_transform_calls}"
)
[00:45:55] /media/pc/data/lxw/ai/tvm/src/relay/transforms/convert_layout.cc:99: Warning: Desired layout(s) not specified for op: nn.max_pool2d
[00:45:55] /media/pc/data/lxw/ai/tvm/src/relay/transforms/convert_layout.cc:99: Warning: Desired layout(s) not specified for op: nn.global_avg_pool2d