TVMRunner
#
%cd ..
from pathlib import Path
# Scratch directory for runner workspaces created by the tests below;
# exist_ok=True makes repeated notebook runs idempotent.
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc
构建前端模型:
import numpy as np
import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx
def _get_torch_model(name, training=False):
"""Get model from torch vision"""
# pylint: disable=import-outside-toplevel
try:
import torchvision
model = getattr(torchvision.models, name)()
if training:
model = model.train()
else:
model = model.eval()
return model
except: # pylint: disable=bare-except
print("please install torchvision package")
return None
构建并运行 TVMRunner
#
from tvm.contrib.msc.framework.tvm.runtime import TVMRunner
import tvm
from tvm.contrib.msc.core import utils as msc_utils
def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1):
    """Trace a torchvision resnet50, build it with `runner_cls`, and compare
    the runner's outputs against the torch model within `atol`/`rtol`.

    Silently does nothing when torchvision is unavailable.
    """
    torch_model = _get_torch_model("resnet50", training)
    if not torch_model:
        # torchvision missing — nothing to verify.
        return
    path = f"{temp_dir}/test_runner_torch_{runner_cls.__name__}_{device}"
    workspace = msc_utils.set_workspace(msc_utils.msc_dir(path, keep_history=False))
    log_path = workspace.relpath("MSC_LOG", keep_history=False)
    msc_utils.set_global_logger("critical", log_path)
    input_info = [([1, 3, 224, 224], "float32")]
    # Random inputs shared by both the torch reference and the runner.
    datas = [np.random.rand(*shape).astype(dtype) for shape, dtype in input_info]
    torch_datas = [torch.from_numpy(arr) for arr in datas]
    graph_model = fx.symbolic_trace(torch_model)
    if training:
        # Symbolic batch dimension when importing the training graph.
        input_info = [([tvm.tir.Var("bz", "int64"), 3, 224, 224], "float32")]
    with torch.no_grad():
        golden = torch_model(*torch_datas)
    mod = from_fx(graph_model, input_info)
    runner = runner_cls(mod, device=device, training=training)
    runner.build()
    outputs = runner.run(datas, ret_type="list")
    golden = [msc_utils.cast_array(golden)]
    # NOTE(review): "destory" is the MSC utils API's own spelling — do not
    # rename it to "destroy" here; confirm against tvm.contrib.msc.core.utils.
    workspace.destory()
    for expected, actual in zip(golden, outputs):
        np.testing.assert_allclose(
            expected, msc_utils.cast_array(actual), atol=atol, rtol=rtol
        )
# Exercise the TVM runner on both devices, in training and inference mode.
# Order matches the original: cpu/True, cpu/False, cuda/True, cuda/False.
for device in ("cpu", "cuda"):
    for training in (True, False):
        _test_from_torch(TVMRunner, device, training=training)