# PyTorch Relay 前端测试


In [1]:
import set_env

In [3]:
import torch
import torch.nn.functional as F
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
from tvm.relay.dataflow_pattern import is_constant as is_const
from tvm.relay.testing import run_opt_pass

class RewriteMulConv2d(DFPatternCallback):
    """Fold a constant ``multiply`` feeding ``nn.conv2d`` into the conv weight.

    The pattern matched is ``nn.conv2d(multiply(x, scale), weight)`` where
    ``scale`` and ``weight`` are both constants and ``x`` is any expression.
    The match is replaced by ``nn.conv2d(x, weight * scale)`` built with the
    attributes of the original convolution.
    """

    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.scale = is_const()
        self.multiply = is_op("multiply")(self.x, self.scale)
        self.weight = is_const()
        self.conv = is_op("nn.conv2d")(self.multiply, self.weight)
        self.pattern = self.conv

    def callback(self, pre, post, matches):
        """Build the replacement convolution for one match.

        Parameters
        ----------
        pre, post
            Expressions supplied by the rewriter; not used here.
        matches
            Mapping indexed by the sub-patterns stored on ``self``; the first
            entry of each value is taken as the matched expression.

        Returns
        -------
        A new ``nn.conv2d`` call on the unscaled input with the scaled weight.
        """
        x_ = matches[self.x][0]
        # NOTE(review): the product relies on Relay broadcasting; it is only
        # equivalent to scaling the data when ``scale`` broadcasts compatibly
        # with the weight (e.g. a scalar) - confirm for other scale shapes.
        scaled_weight = matches[self.weight][0] * matches[self.scale][0]
        conv = matches[self.conv][0]
        # conv2d is exposed under the ``relay.nn`` namespace (the same
        # namespace the "nn.conv2d" op name in the pattern refers to).
        return relay.nn.conv2d(x_, scaled_weight, **conv.attrs)

@tvm.ir.transform.module_pass(opt_level=3, name="MulConv2dRewriter")
class MulConv2dRewriterPipeline:
    """Module pass: rewrite ``main`` with RewriteMulConv2d, then fold constants."""

    def transform_module(self, mod, ctx):
        """Rewrite ``mod["main"]`` in place and return the constant-folded module.

        ``ctx`` is accepted for the pass interface but is not used.
        """
        rewritten_main = rewrite(RewriteMulConv2d(), mod["main"])
        mod["main"] = rewritten_main
        fold_constants = relay.transform.FoldConstant()
        return fold_constants(mod)

class M(torch.nn.Module):
    """Toy network: conv 3x3 -> nearest 0.5x resize -> constant scale -> conv 1x1."""

    def __init__(self):
        super().__init__()
        # Creation order is kept (conv, then conv2) so that random weight
        # initialisation consumes the RNG in the same sequence.
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, bias=False
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=1, bias=False
        )
        # Plain tensor attribute (not a parameter or buffer).
        self.scale = torch.tensor(4.019027233123779, dtype=torch.float32)

    def forward(self, x):
        """Apply conv, halve the spatial size, scale, then apply conv2."""
        features = self.conv(x)
        downsampled = F.interpolate(
            features, scale_factor=(0.5, 0.5), mode="nearest"
        )
        scaled = self.scale * downsampled
        return self.conv2(scaled)


# Instantiate the toy model and describe its single input for the Relay frontend.
torch_model = M()
shape = (1, 3, 10, 10)
input_shapes = [("x", shape)]  # (input name, shape) pairs handed to from_pytorch
with torch.no_grad():
    # Trace the model on a random example input, then import the trace into Relay.
    trace = torch.jit.trace(torch_model, [torch.randn(*shape)])
    mod, params = relay.frontend.from_pytorch(trace, input_shapes)
print(mod["main"])  # IR as imported, before any pass below is applied
with tvm.transform.PassContext(opt_level=3):
    # NOTE(review): presumably this binds `params` into the module as constants
    # and runs TVM's pre-quantization simplifications - confirm against TVM docs.
    mod = relay.quantize.prerequisite_optimize(mod, params)
    # Apply the custom multiply-into-conv2d rewrite pass defined above.
    mod = MulConv2dRewriterPipeline()(mod)

fn (%x: Tensor[(1, 3, 10, 10), float32] /* span=aten::_convolution_0.x:0:0 */, %aten::_convolution_0.weight: Tensor[(16, 3, 3, 3), float32] /* span=aten::_convolution_0.weight:0:0 */, %aten::_convolution_1.weight: Tensor[(32, 16, 1, 1), float32] /* span=aten::_convolution_1.weight:0:0 */) {
  %0 = nn.conv2d(%x, %aten::_convolution_0.weight, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3]) /* span=aten::_convolution_0:0:0 */;
  %1 = image.resize2d(%0, size=[4, 4], roi=[0f, 0f, 0f, 0f], method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="", cubic_alpha=-0.75f) /* span=aten::upsample_nearest2d_0:0:0 */;
  %2 = multiply(4.01903f /* span=aten::mul_0:0:0 */, %1) /* span=aten::mul_0:0:0 */;
  nn.conv2d(%2, %aten::_convolution_1.weight, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* span=aten::_convolution_1:0:0 */
}


In [4]:
# Print the "main" function of the module produced by the passes in the previous cell.
print(mod["main"])

fn (%x: Tensor[(1, 3, 10, 10), float32] /* ty=Tensor[(1, 3, 10, 10), float32] span=aten::_convolution_0.x:0:0 */) -> Tensor[(1, 32, 4, 4), float32] {
  %0 = nn.conv2d(%x, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 8, 8), float32] span=aten::_convolution_0:0:0 */;
  %1 = image.resize2d(%0, size=[4, 4], roi=[0f, 0f, 0f, 0f], method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="", cubic_alpha=-0.75f) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten::upsample_nearest2d_0:0:0 */;
  nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(32, 16, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 4, 4), float32] */
} /* ty=fn (Tensor[(1, 3, 10, 10), float32]) -> Tensor[(1, 32, 4, 4), float32] */



In [5]:
import numpy as np
# Random float32 input with the same shape the model was traced with.
inputs_np = [np.random.rand(1, 3, 10, 10).astype("float32")]
# Build a "debug" executor for `mod` on the CPU with the llvm target and get its callable.
evaluate = relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm").evaluate()
# Run the Relay module on the input.
tvm_output = evaluate(inputs_np[0])

One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.


In [6]:
# Run the original PyTorch model on the same input, without gradient tracking.
with torch.no_grad():
    torch_output = torch_model(torch.from_numpy(inputs_np[0]))
# Fail if the TVM result deviates from the PyTorch result beyond the given tolerances.
np.testing.assert_allclose(tvm_output.numpy(), torch_output.numpy(), rtol=1e-7, atol=1e-5)

In [19]:
from tvm.relay.analysis.analysis import extract_intermdeiate_expr
# NOTE(review): presumably truncates the module at the expression with index 35
# (function name spelled as imported) - confirm against TVM's documentation.
# The IR printed after this cell has a 128x128 input, so `mod` here appears to
# come from a different model than the toy network above - verify the cell order.
mod = extract_intermdeiate_expr(mod, 35)
# Run type inference on the extracted module.
mod = relay.transform.InferType()(mod)

In [20]:
# Print the "main" function of the module after extraction and type inference.
print(mod["main"])

fn (%input.1: Tensor[(1, 3, 128, 128), float32] /* ty=Tensor[(1, 3, 128, 128), float32] span=Conv_9.input.1:0:0 */) -> Tensor[(1, 48, 64, 64), float32] {
  %0 = nn.conv2d(%input.1, meta[relay.Constant][0] /* ty=Tensor[(32, 3, 3, 3), float32] span=Conv_9.conv_in.weight:0:0 */, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_9:0:0 */;
  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(32, 1, 1, 3), float32] span=Conv_10.body.0.body.0.dau_top.body.0.weight:0:0 */, padding=[0, 1, 0, 1], groups=32, channels=32, kernel_size=[1, 3]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_10:0:0 */;
  %2 = nn.conv2d(%1, meta[relay.Constant][2] /* ty=Tensor[(32, 32, 1, 1), float32] span=Conv_11.body.0.body.0.dau_top.body.1.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_11:0:0 */;
  %3 = broadcast_to_like(meta[relay.Constant][3] /* ty=Tensor[(32, 1, 1), float32] 