Frontend model conversion#

import set_env

Define a simple model:

import torch

class Conv2D1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

    def forward(self, data):
        return self.conv(data)
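
A quick sanity check (illustrative only, not part of the conversion flow): with no padding and stride 1, the 7×7 kernel trims 6 pixels from each spatial dimension, so a (1, 3, 224, 224) input produces a (1, 6, 218, 218) output:

with torch.no_grad():
    out = Conv2D1()(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 6, 218, 218])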

Converting between MSCGraph and PyTorch models#

import numpy as np

import torch
from torch.nn import Module

import tvm.testing
from tvm.contrib.msc.framework.torch.frontend import translate
from tvm.contrib.msc.framework.torch import codegen

shape = (1, 3, 224, 224)
input_info = [(shape, "float32")]
torch_model = Conv2D1()

Convert the torch model to an MSCGraph:

graph, weights = translate.from_torch(torch_model, input_info, via_relax=False)
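
The returned graph describes the network topology, while weights holds the parameter tensors. A quick way to inspect both (a sketch, assuming weights behaves like a dict mapping parameter names to tvm.nd.NDArray):

for inp in graph.get_inputs():
    print(inp)  # graph-level input tensors
for name, data in weights.items():
    print(name, data.shape)  # parameter name and shape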

Convert the MSCGraph back to a torch model:

model = codegen.to_torch(graph, weights)
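
The regenerated model is an ordinary torch.nn.Module, so its parameters can be listed the usual way (note that parameter names are produced by the codegen and may differ from the original module):

for name, param in model.state_dict().items():
    print(name, tuple(param.shape))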

Verify consistency:

# Random inputs matching the recorded input specification
torch_datas = [torch.from_numpy(np.random.rand(*i[0]).astype(i[1])) for i in input_info]
with torch.no_grad():
    golden = torch_model(*torch_datas)
with torch.no_grad():
    # A graph with no declared inputs is called without arguments
    if not graph.get_inputs():
        result = model()
    else:
        result = model(*torch_datas)
if not isinstance(golden, (list, tuple)):
    golden = [golden]
if not isinstance(result, (list, tuple)):
    result = [result]
assert len(golden) == len(result), f"golden {len(golden)} mismatch with result {len(result)}"
for gol_r, new_r in zip(golden, result):
    if isinstance(gol_r, torch.Tensor):
        tvm.testing.assert_allclose(
            gol_r.detach().numpy(), new_r.detach().numpy(), atol=1e-5, rtol=1e-5
        )
    else:
        assert gol_r == new_r
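
For tensor outputs, the same comparison can also be written with PyTorch's own testing helper (an equivalent sketch of the check above):

for gol_r, new_r in zip(golden, result):
    if isinstance(gol_r, torch.Tensor):
        torch.testing.assert_close(new_r, gol_r, rtol=1e-5, atol=1e-5)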

Convert to relay#

def _valid_target(target):
    if not target:
        return target
    if target == "ignore":
        return None
    if target == "cuda" and not tvm.cuda().exist:
        return None
    if isinstance(target, str):
        target = tvm.target.Target(target)
    return target

def _run_relax(relax_mod, target, datas):
    relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)
    with tvm.transform.PassContext(opt_level=3):
        relax_exec = tvm.relax.build(relax_mod, target)
        runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu())
    res = runnable["main"](*datas)
    if isinstance(res, tvm.runtime.NDArray):
        return [res.asnumpy()]
    return [e.asnumpy() for e in res]

from tvm.relax.frontend.torch import from_fx
from tvm.relay.frontend import from_pytorch
from torch import fx
from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen

opt_config = None
codegen_config = None
build_target = None

graph_model = fx.symbolic_trace(torch_model)
with torch.no_grad():
    expected = from_fx(graph_model, input_info)
expected = tvm.relax.transform.CanonicalizeBindings()(expected)

# graph from relay
datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
torch_datas = [torch.from_numpy(i) for i in datas]
with torch.no_grad():
    scripted_model = torch.jit.trace(torch_model, tuple(torch_datas)).eval()  # type: ignore
shape_list = [("input" + str(idx), i) for idx, i in enumerate(input_info)]
relay_mod, params = from_pytorch(scripted_model, shape_list)
graph, weights = translate.from_relay(relay_mod, params, opt_config=opt_config)
# to relax
codegen_config = codegen_config or {}
codegen_config.update({"explicit_name": False, "from_relay": True})
mod = tvm_codegen.to_relax(graph, weights, codegen_config)
if build_target:
    build_target = _valid_target(build_target)
    if not build_target:
        exit()
    tvm_datas = [tvm.nd.array(i) for i in datas]
    expected_res = _run_relax(expected, build_target, tvm_datas)
    if not graph.get_inputs():
        tvm_datas = []
    res = _run_relax(mod, build_target, tvm_datas)
    for exp_r, new_r in zip(expected_res, res):
        tvm.testing.assert_allclose(exp_r, new_r, atol=1e-5, rtol=1e-5)
else:
    tvm.ir.assert_structural_equal(mod, expected)
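
Since build_target is None here, the comparison above falls back to structural equality. To also compare numerical results, pass a concrete target such as "llvm" and reuse the helpers defined above (a sketch of the numerical branch):

build_target = _valid_target("llvm")
tvm_datas = [tvm.nd.array(i) for i in datas]
expected_res = _run_relax(expected, build_target, tvm_datas)
res = _run_relax(mod, build_target, tvm_datas if graph.get_inputs() else [])
for exp_r, new_r in zip(expected_res, res):
    tvm.testing.assert_allclose(exp_r, new_r, atol=1e-5, rtol=1e-5)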

Convert to relax#

import tvm.testing
from torch import fx
from tvm.relax.frontend.torch import from_fx
from tvm.contrib.msc.core.frontend import translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen


def _verify_model(torch_model, input_info, opt_config=None):
    graph_model = fx.symbolic_trace(torch_model)
    with torch.no_grad():
        orig_mod = from_fx(graph_model, input_info)

    target = "llvm"
    dev = tvm.cpu()
    args = [tvm.nd.array(np.random.random(size=shape).astype(dtype)) for shape, dtype in input_info]

    def _tvm_runtime_to_np(obj):
        if isinstance(obj, tvm.runtime.NDArray):
            return obj.numpy()
        elif isinstance(obj, tvm.runtime.ShapeTuple):
            return np.array(obj, dtype="int64")
        elif isinstance(obj, (list, tvm.ir.container.Array)):
            return [_tvm_runtime_to_np(item) for item in obj]
        elif isinstance(obj, tuple):
            return tuple(_tvm_runtime_to_np(item) for item in obj)
        else:
            return obj

    def _run_relax(relax_mod):
        relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)
        relax_exec = tvm.relax.build(relax_mod, target)
        vm_runner = tvm.relax.VirtualMachine(relax_exec, dev)
        res = vm_runner["main"](*args)

        return _tvm_runtime_to_np(res)

    rt_mod = tvm_codegen.to_relax(
        *translate.from_relax(orig_mod, opt_config=opt_config),
        codegen_config={"explicit_name": False},
    )

    orig_output = _run_relax(orig_mod)
    rt_output = _run_relax(rt_mod)
    tvm.testing.assert_allclose(orig_output, rt_output)

_verify_model(torch_model, input_info, opt_config=None)