# ONNX 前端测试
import set_env
from d2py.utils.file import mkdir
# Scratch directory for files produced by this notebook; the downloaded ONNX
# model is written into it via `dir=root_dir` further down.
root_dir = ".temp"
# NOTE(review): `mkdir` is the d2py helper imported above; presumably it creates
# the directory when it does not exist yet -- confirm against d2py.
mkdir(root_dir)
import platform
import pytest
import builtins
import importlib
import tvm
from unittest import mock
from tvm.ir.module import IRModule
from tvm.driver import tvmc
from tvm.driver.tvmc import TVMCException, TVMCImportError
from tvm.driver.tvmc.model import TVMCModel
from d2py.tools.sanstyle.github.file import lfs_url
import httpx
import tempfile
from tqdm.asyncio import tqdm
# Handle to the real `importlib.import_module`. It is not referenced in the
# visible part of this file; presumably kept so the function can be restored
# after being mocked -- TODO confirm or remove.
orig_import = importlib.import_module
def verify_load_model__onnx(model, **kwargs):
    """Load ``model`` through the TVMC frontends and check the result types.

    Args:
        model: Path of the model file, passed to ``tvmc.frontends.load_model``.
        **kwargs: Extra keyword arguments forwarded unchanged to ``load_model``.

    Returns:
        The loaded model object, after asserting that it is exactly a
        ``TVMCModel`` whose ``mod`` is exactly an ``IRModule`` and whose
        ``params`` is exactly a ``dict``.
    """
    loaded = tvmc.frontends.load_model(model, **kwargs)

    # ``None`` stands for the loaded object itself. Attributes are read lazily,
    # one at a time, so the checks run in the same order as before.
    expectations = ((None, TVMCModel), ("mod", IRModule), ("params", dict))
    for attribute, expected_type in expectations:
        value = loaded if attribute is None else getattr(loaded, attribute)
        # Exact type identity on purpose: subclasses must not pass.
        assert type(value) is expected_type

    return loaded
# A file name ending in ".onnx" must be resolved to the ONNX frontend.
sut = tvmc.frontends.guess_frontend("a_model.onnx")
frontend_type = type(sut)
assert frontend_type is tvmc.frontends.OnnxFrontend
# Loading a path that does not exist is expected to fail: the recorded output
# right after this call shows FileNotFoundError raised from inside onnx.load.
tvmc.load("not/a/file.txt", model_format="onnx")
---------------------------------------------------------------------------
FileNotFoundError Traceback (most recent call last)
Cell In[4], line 1
----> 1 tvmc.load("not/a/file.txt", model_format="onnx")
File /media/pc/data/lxw/ai/tvm/python/tvm/driver/tvmc/frontends.py:476, in load_model(path, model_format, shape_dict, **kwargs)
473 else:
474 frontend = guess_frontend(path)
--> 476 mod, params = frontend.load(path, shape_dict, **kwargs)
478 return TVMCModel(mod, params)
File /media/pc/data/lxw/ai/tvm/python/tvm/driver/tvmc/frontends.py:167, in OnnxFrontend.load(self, path, shape_dict, **kwargs)
164 onnx = lazy_import("onnx")
166 # pylint: disable=E1101
--> 167 model = onnx.load(path)
169 return relay.frontend.from_onnx(model, shape=shape_dict, **kwargs)
File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/onnx/__init__.py:210, in load_model(f, format, load_external_data)
189 def load_model(
190 f: IO[bytes] | str | os.PathLike,
191 format: _SupportedFormat | None = None, # noqa: A002
192 load_external_data: bool = True,
193 ) -> ModelProto:
194 """Loads a serialized ModelProto into memory.
195
196 Args:
(...)
208 Loaded in-memory ModelProto.
209 """
--> 210 model = _get_serializer(format, f).deserialize_proto(_load_bytes(f), ModelProto())
212 if load_external_data:
213 model_filepath = _get_file_path(f)
File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/onnx/__init__.py:147, in _load_bytes(f)
145 else:
146 f = typing.cast(Union[str, os.PathLike], f)
--> 147 with open(f, "rb") as readable:
148 content = readable.read()
149 return content
FileNotFoundError: [Errno 2] No such file or directory: 'not/a/file.txt'
url = lfs_url("onnx/models", "vision/classification/resnet/model/resnet50-v2-7.onnx", branch="bd206494e8b6a27b25e5cf7199dbcdbfe9d05d1c")
prefix = url.split("/")[-1].split(".onnx")[0]
client = httpx.AsyncClient()
with httpx.stream("GET", url) as response:
print('response', response)
total = int(response.headers["Content-Length"])
print('total', total)
with tempfile.NamedTemporaryFile(delete=False, prefix=f"{prefix}-", suffix=".onnx", dir=root_dir) as download_file:
async with client.stream('GET', url) as response:
async for chunk in tqdm(response.aiter_bytes()):
download_file.write(chunk)
response <Response [200 OK]>
total 102442450
6264it [10:44, 9.72it/s]
# Path of the ResNet-50 ONNX file written by the download step.
onnx_resnet50 = download_file.name

# freeze_params=False: one known weight name must be present in the params dict.
tvmc_model = verify_load_model__onnx(onnx_resnet50, freeze_params=False)
assert "resnetv24_batchnorm0_gamma" in tvmc_model.params

# freeze_params=True: the params dict must be empty, implying that the
# parameters have been folded into constants.
tvmc_model = verify_load_model__onnx(onnx_resnet50, freeze_params=True)
assert not tvmc_model.params
# Convert the imported ResNet-50 graph to NHWC and check that the conversion
# introduced at least one NCHW -> NHWC `layout_transform` call.
tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
before = tvmc_model.mod

expected_layout = "NHWC"
with tvm.transform.PassContext(opt_level=3):
    after = tvmc.transform.convert_graph_layout(before, expected_layout)

# One boolean per visited Call node, filled in by the visitor below.
layout_transform_calls = []


def _is_layout_transform(node):
    """Record whether a visited Call node is an NCHW -> NHWC layout_transform.

    Nodes that are not ``tvm.relay.expr.Call`` instances are ignored.
    """
    if isinstance(node, tvm.relay.expr.Call):
        layout_transform_calls.append(
            node.op.name == "layout_transform"
            and node.attrs.src_layout == "NCHW"
            and node.attrs.dst_layout == "NHWC"
        )


tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

# The message previously said "NCWH", which did not match the layout checked.
assert any(layout_transform_calls), "Expected 'layout_transform NCHW->NHWC' not found"
[00:37:12] /media/pc/data/lxw/ai/tvm/src/relay/transforms/convert_layout.cc:99: Warning: Desired layout(s) not specified for op: nn.max_pool2d
[00:37:13] /media/pc/data/lxw/ai/tvm/src/relay/transforms/convert_layout.cc:99: Warning: Desired layout(s) not specified for op: nn.global_avg_pool2d
# Request the NCHW layout and check that the conversion did not introduce any
# NCHW -> NCHW `layout_transform` call.
tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
before = tvmc_model.mod
expected_layout = "NCHW"

with tvm.transform.PassContext(opt_level=3):
    after = tvmc.transform.convert_graph_layout(before, expected_layout)

# One boolean per visited Call node, filled in by the visitor below.
layout_transform_calls = []


def _is_layout_transform(node):
    """Record whether a visited Call node is an NCHW -> NCHW layout_transform.

    Nodes that are not ``tvm.relay.expr.Call`` instances are ignored.
    """
    if not isinstance(node, tvm.relay.expr.Call):
        return
    is_match = (
        node.op.name == "layout_transform"
        and node.attrs.src_layout == "NCHW"
        and node.attrs.dst_layout == "NCHW"
    )
    layout_transform_calls.append(is_match)


tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

assert not any(layout_transform_calls), "Unexpected 'layout_transform' call"
[00:45:55] /media/pc/data/lxw/ai/tvm/src/relay/transforms/convert_layout.cc:99: Warning: Desired layout(s) not specified for op: nn.max_pool2d
[00:45:55] /media/pc/data/lxw/ai/tvm/src/relay/transforms/convert_layout.cc:99: Warning: Desired layout(s) not specified for op: nn.global_avg_pool2d