TVMRunner

TVMRunner

%cd ..
from pathlib import Path

# Scratch directory, relative to the current working directory, under which
# the test workspaces are created. An already-existing directory is reused.
temp_dir = Path(".temp")
temp_dir.mkdir(parents=False, exist_ok=True)
/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc

构建前端模型:

import numpy as np
import torch
from torch import fx
from tvm.relax.frontend.torch import from_fx

def _get_torch_model(name, training=False):
    """Get model from torch vision.

    Parameters
    ----------
    name: str
        Name of an attribute of ``torchvision.models``; it is called with no
        arguments to construct the model (e.g. ``"resnet50"``).
    training: bool
        When true the model is switched with ``.train()``, otherwise with
        ``.eval()``.

    Returns
    -------
    model:
        The result of ``.train()`` / ``.eval()`` on the constructed model, or
        ``None`` when torchvision cannot be imported.

    Raises
    ------
    AttributeError
        If ``torchvision.models`` has no attribute called ``name``. Errors
        raised while constructing the model also propagate instead of being
        reported as a missing torchvision installation.
    """

    # Only the import is guarded: a missing optional dependency is the one
    # failure that should be reported as "please install" and skipped.
    # pylint: disable=import-outside-toplevel
    try:
        import torchvision
    except ImportError:
        print("please install torchvision package")
        return None

    model = getattr(torchvision.models, name)()
    return model.train() if training else model.eval()

构建并运行 TVMRunner

from tvm.contrib.msc.framework.tvm.runtime import TVMRunner
import tvm
from tvm.contrib.msc.core import utils as msc_utils

def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1):
    """Test runner from torch model.

    Builds a torchvision ``resnet50``, traces it with ``torch.fx``, imports
    the traced graph with ``from_fx``, builds and runs it through
    ``runner_cls`` and compares the outputs with the torch model's own output
    on the same random input.

    Parameters
    ----------
    runner_cls:
        Runner class; instantiated as
        ``runner_cls(mod, device=device, training=training)``.
    device: str
        Device name forwarded to the runner; also used in the workspace path.
    training: bool
        Forwarded to the model getter and the runner. When true the graph is
        imported with a symbolic batch dimension named ``"bz"``.
    atol: float
        Absolute tolerance passed to ``np.testing.assert_allclose``.
    rtol: float
        Relative tolerance passed to ``np.testing.assert_allclose``.

    Raises
    ------
    AssertionError
        If a runner output differs from the torch result beyond the
        tolerances.
    """

    torch_model = _get_torch_model("resnet50", training)
    # The getter signals "torchvision unavailable" with None; test for that
    # sentinel explicitly rather than relying on the truthiness of a module.
    if torch_model is None:
        return

    path = f"{temp_dir}/test_runner_torch_{runner_cls.__name__}_{device}"
    workspace = msc_utils.set_workspace(msc_utils.msc_dir(path, keep_history=False))
    try:
        log_path = workspace.relpath("MSC_LOG", keep_history=False)
        msc_utils.set_global_logger("critical", log_path)
        input_info = [([1, 3, 224, 224], "float32")]
        datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
        torch_datas = [torch.from_numpy(d) for d in datas]
        graph_model = fx.symbolic_trace(torch_model)
        if training:
            # The concrete input above keeps batch size 1; only the imported
            # graph gets the symbolic batch dimension.
            input_info = [([tvm.tir.Var("bz", "int64"), 3, 224, 224], "float32")]
        with torch.no_grad():
            golden = torch_model(*torch_datas)
            mod = from_fx(graph_model, input_info)
        runner = runner_cls(mod, device=device, training=training)
        runner.build()
        outputs = runner.run(datas, ret_type="list")
        golden = [msc_utils.cast_array(golden)]
    finally:
        # Clean up the workspace even when conversion, build or run fails.
        # NOTE(review): ``destory`` is the spelling used by the original call,
        # presumably the MSC API name -- confirm before "correcting" it.
        workspace.destory()
    for gol_r, out_r in zip(golden, outputs):
        np.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=atol, rtol=rtol)
# CPU: check the runner in training mode first, then in inference mode.
for training in (True, False):
    _test_from_torch(TVMRunner, device="cpu", training=training)
# CUDA: same two modes, run after the CPU checks have finished.
for training in (True, False):
    _test_from_torch(TVMRunner, device="cuda", training=training)