# PyTorch ONNX Relax test
import operator
import torch
import torch.nn.functional as F
from torch import fx
import torchvision
import tvm
from tvm import relax
import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend import detach_params
from tvm.relax.frontend.torch import from_fx
def verify_model(torch_model, input_info, binding, expected):
    """Trace *torch_model* with FX, import it into Relax, and compare.

    The *binding* dict of numpy parameters is bound into *expected* via
    BindParams("main", ...) before asserting structural equality with the
    imported module.
    """
    traced = fx.symbolic_trace(torch_model)
    with torch.no_grad():
        imported = from_fx(traced, input_info)
    bound = {name: tvm.nd.array(value) for name, value in binding.items()}
    reference = relax.transform.BindParams("main", bound)(expected)
    tvm.ir.assert_structural_equal(imported, reference)
class M(torch.nn.Module):
    """3x3 conv -> nearest-neighbour half-size resize -> scalar scale -> 1x1 conv.

    Both convolutions are bias-free, which is what makes the scalar multiply
    foldable into the second convolution's weights downstream.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, bias=False)
        self.conv2 = torch.nn.Conv2d(16, 32, 1, bias=False)

    def forward(self, x):
        features = self.conv(x)
        # Downsample spatial dims by 2 with nearest-neighbour interpolation.
        downsampled = F.interpolate(
            features,
            size=None,
            scale_factor=(0.5, 0.5),
            mode="nearest",
            # align_corners is not applicable to "nearest" mode.
        )
        scaled = downsampled * 4.019027233123779
        return self.conv2(scaled)
# Build the demo model, trace it with torch.fx, and import the traced graph
# into a Relax IRModule (`mod`, used by the rewrite pass below).
torch_model = M()
graph_model = fx.symbolic_trace(torch_model)
# One input: NCHW float32 tensor of shape (1, 3, 10, 10).
input_info = [([1, 3, 10, 10], "float32")]
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
from tvm.relax.dpl import *
@tvm.ir.transform.module_pass(opt_level=0, name="MulConv2dRewriter")
class MulConv2dRewriterPipeline:
    """Fold `conv2d(x * c, w)` into `conv2d(x, w * c)` and constant-fold.

    Sound because conv2d is linear and bias-free here: a constant scale on
    the activation can be moved onto the constant weights, and the resulting
    weight * scale product is folded away by FoldConstant.
    """

    def transform_module(self, mod, ctx):
        # Pattern: conv2d(multiply(<anything>, <const>), <const weight>).
        data = wildcard()
        scale_const = is_const()
        weight_const = is_const()
        scaled = is_op("relax.multiply")(data, scale_const)
        conv_pattern = is_op("relax.nn.conv2d")(scaled, weight_const)

        def rebuild(_, captured):
            # Move the constant scale from the activation onto the weights.
            new_weight = captured[weight_const] * captured[scale_const]
            return R.nn.conv2d(captured[data], new_weight)

        mod["main"] = rewrite_call(conv_pattern, rebuild, mod["main"])
        # weight * scale is a constant expression — fold it now.
        return relax.transform.FoldConstant()(mod)
from copy import deepcopy

# Keep an untouched copy of the imported module for reference, then apply
# the multiply-into-conv rewrite pass and print the rewritten module.
ori_mod = deepcopy(mod)
mod = MulConv2dRewriterPipeline()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
# NOTE(review): this IRModule is a pasted `mod.show()` dump of the expected
# rewritten module. It references `metadata[...]` constants whose definition
# was omitted from the dump (see the trailing comment), so evaluating this
# class as-is raises an undefined-name error for `metadata` — it appears to
# be kept for reference only; confirm before relying on it at runtime.
@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor((1, 32, 4, 4), dtype="float32"):
        with R.dataflow():
            # First conv: weights come from the omitted metadata constants.
            lv: R.Tensor((1, 16, 8, 8), dtype="float32") = R.nn.conv2d(inp_0, metadata["relax.expr.Constant"][0], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            # Nearest-neighbour resize 8x8 -> 4x4 (torch's interpolate lowering).
            lv1: R.Tensor((1, 16, 4, 4), dtype="float32") = R.image.resize2d(lv, R.shape([4, 4]), roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], layout="NCHW", method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="round", cubic_alpha=-0.5, cubic_exclude=0, extrapolation_value=0.0, out_dtype="void")
            # Second conv with the scale already folded into its weights —
            # the standalone multiply from the original model is gone.
            gv: R.Tensor((1, 32, 4, 4), dtype="float32") = R.nn.conv2d(lv1, metadata["relax.expr.Constant"][1], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            R.output(gv)
        return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
# Lower the rewritten module: decompose inference-only ops, then legalize
# every Relax op into TensorIR.
tvm_model = relax.transform.DecomposeOpsForInference()(mod)
# Legalize any relax ops into tensorir.
tvm_model = relax.transform.LegalizeOps()(tvm_model)
# Separate model from parameters.
tvm_model, params = relax.frontend.detach_params(tvm_model)
# Compile the relax graph into a VM then run.
with tvm.transform.PassContext(opt_level=3):
    ex = tvm.compile(tvm_model, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
import numpy as np

dev = tvm.cpu()
# Random NCHW input matching input_info above.
inputs_np = [np.random.rand(1, 3, 10, 10).astype("float32")]
inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
# Run model and check outputs.
vm.set_input("main", *inputs)
vm.invoke_stateful("main")
tvm_output = vm.get_outputs("main")
with torch.no_grad():
    torch_output = torch_model(torch.from_numpy(inputs_np[0]))
# NOTE(review): rtol=1e-7 is tight for float32 after refolding the scale
# into the weights; atol=1e-5 absorbs most of the slack, but loosen rtol
# if this comparison flakes.
np.testing.assert_allclose(tvm_output.numpy(), torch_output.numpy(), rtol=1e-7, atol=1e-5)