# ONNX 前端测试

In [1]:
import set_env
from d2py.utils.file import mkdir
# Scratch directory (relative to the notebook's working directory) that the
# download cell uses as the location for the fetched ONNX file.
root_dir = ".temp"
# NOTE(review): `mkdir` is a d2py helper; presumably it creates the directory
# and tolerates it already existing — confirm against d2py.utils.file.
mkdir(root_dir)

In [2]:
import platform
import pytest
import builtins
import importlib

import tvm
from unittest import mock
from tvm.ir.module import IRModule

from tvm.driver import tvmc
from tvm.driver.tvmc import TVMCException, TVMCImportError
from tvm.driver.tvmc.model import TVMCModel
from d2py.tools.sanstyle.github.file import lfs_url
import httpx
import tempfile
from tqdm.asyncio import tqdm

# Handle to the real importlib.import_module. It is not referenced by any of
# the cells visible here — NOTE(review): presumably kept for import-mocking
# tests; confirm before removing.
orig_import = importlib.import_module

def verify_load_model__onnx(model, **kwargs):
    """Load ``model`` via ``tvmc.frontends.load_model`` and sanity-check the result.

    Args:
        model: Forwarded as the first argument of ``tvmc.frontends.load_model``
            (the notebook passes the path of an ONNX file).
        **kwargs: Forwarded unchanged to ``tvmc.frontends.load_model``.

    Returns:
        The object returned by ``load_model``.

    Raises:
        AssertionError: If the result is not exactly a ``TVMCModel``, or its
            ``mod`` / ``params`` attributes are not exactly an ``IRModule`` /
            ``dict`` respectively.
    """
    loaded = tvmc.frontends.load_model(model, **kwargs)
    # Check the container first so the attribute lookups below only happen on
    # an object that is known to be a TVMCModel.
    assert type(loaded) is TVMCModel
    for attribute, exact_type in (("mod", IRModule), ("params", dict)):
        assert type(getattr(loaded, attribute)) is exact_type
    return loaded

In [3]:
# The file does not need to exist here: only the name "a_model.onnx" is handed
# to guess_frontend, and the result must be exactly an OnnxFrontend instance.
sut = tvmc.frontends.guess_frontend("a_model.onnx")
assert type(sut) is tvmc.frontends.OnnxFrontend

In [4]:
# Loading a path that does not exist must fail with FileNotFoundError.
# Assert on the error instead of letting the traceback escape, so the notebook
# still completes under "Restart Kernel -> Run All".
with pytest.raises(FileNotFoundError):
    tvmc.load("not/a/file.txt", model_format="onnx")

In [6]:
url = lfs_url("onnx/models", "vision/classification/resnet/model/resnet50-v2-7.onnx", branch="bd206494e8b6a27b25e5cf7199dbcdbfe9d05d1c")
prefix = url.split("/")[-1].split(".onnx")[0]
client = httpx.AsyncClient()
with httpx.stream("GET", url) as response:
    print('response', response)
    total = int(response.headers["Content-Length"])
    print('total', total)

with tempfile.NamedTemporaryFile(delete=False, prefix=f"{prefix}-", suffix=".onnx", dir=root_dir) as download_file:
    async with client.stream('GET', url) as response:
        async for chunk in tqdm(response.aiter_bytes()):
            download_file.write(chunk)

response <Response [200 OK]>
total 102442450


6264it [10:44,  9.72it/s]


In [9]:
# Path of the temporary file written by the download cell; the remaining cells
# load the model from this path.
onnx_resnet50 = download_file.name

In [10]:
# freeze_params=False: spot-check that one known parameter name is present in
# the params dict.
tvmc_model = verify_load_model__onnx(onnx_resnet50, freeze_params=False)
known_param = "resnetv24_batchnorm0_gamma"
assert known_param in tvmc_model.params

# freeze_params=True: the params dict must come back empty, implying the
# parameters have been folded into constants.
tvmc_model = verify_load_model__onnx(onnx_resnet50, freeze_params=True)
assert tvmc_model.params == {}

In [11]:
# Convert the loaded graph to NHWC and check that at least one
# NCHW -> NHWC layout_transform call appears in the converted "main" function.
tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
before = tvmc_model.mod

expected_layout = "NHWC"
with tvm.transform.PassContext(opt_level=3):
    after = tvmc.transform.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
    """Visitor callback: record, for every Call node, whether it is an
    NCHW -> NHWC ``layout_transform``.

    One boolean per Call node is appended to ``layout_transform_calls``;
    non-Call nodes are ignored. The attrs are only inspected when the op name
    already matches, thanks to the short-circuiting ``and`` chain.
    """
    if isinstance(node, tvm.relay.expr.Call):
        layout_transform_calls.append(
            node.op.name == "layout_transform"
            and node.attrs.src_layout == "NCHW"
            and node.attrs.dst_layout == "NHWC"
        )

tvm.relay.analysis.post_order_visit(after["main"], _is_layout_transform)

# Message previously read "NCWH->NHWC"; the layout being checked is NCHW.
assert any(layout_transform_calls), "Expected 'layout_transform NCHW->NHWC' not found"



In [12]:


# Request NCHW for the loaded graph and check that the conversion did not
# introduce an NCHW -> NCHW layout_transform call.
tvmc_model = tvmc.frontends.load_model(onnx_resnet50)
before = tvmc_model.mod

expected_layout = "NCHW"

with tvm.transform.PassContext(opt_level=3):
    after = tvmc.transform.convert_graph_layout(before, expected_layout)

layout_transform_calls = []

def _is_layout_transform(node):
    """Visitor callback: for each Call node, append whether it is an
    NCHW -> NCHW ``layout_transform`` to ``layout_transform_calls``.
    """
    # Only Call nodes are of interest; everything else is skipped.
    if not isinstance(node, tvm.relay.expr.Call):
        return
    # NOTE(review): this only matches transforms whose src and dst are both
    # "NCHW", which is narrower than the "Unexpected 'layout_transform' call"
    # message below suggests — confirm whether any layout_transform should fail.
    is_same_layout_transform = (
        node.op.name == "layout_transform"
        and node.attrs.src_layout == "NCHW"
        and node.attrs.dst_layout == "NCHW"
    )
    layout_transform_calls.append(is_same_layout_transform)

main_function = after["main"]
tvm.relay.analysis.post_order_visit(main_function, _is_layout_transform)

assert not any(layout_transform_calls), "Unexpected 'layout_transform' call"


