Testing the PyTorch Relay Frontend
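
This notebook traces a small PyTorch model, imports it into Relay via from_pytorch, folds a constant scalar multiply into the weight of the following nn.conv2d using a dataflow-pattern rewrite, and verifies the result numerically against PyTorch.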

import set_env  # local notebook helper that sets up the TVM environment (assumed)
import torch
import torch.nn.functional as F
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
from tvm.relay.dataflow_pattern import is_constant as is_const

class RewriteMulConv2d(DFPatternCallback):
    """Fold ``multiply(x, scale)`` + ``nn.conv2d`` into a single ``nn.conv2d``.

    Since convolution is linear, ``conv2d(x * s, w) == conv2d(x, w * s)``,
    so the constant scale can be pushed onto the constant weight.
    """

    def __init__(self):
        super().__init__()
        self.x = wildcard()        # any expression feeding the multiply
        self.scale = is_const()    # constant scalar factor
        self.multiply = is_op("multiply")(self.x, self.scale)
        self.weight = is_const()   # constant convolution weight
        self.conv = is_op("nn.conv2d")(self.multiply, self.weight)
        self.pattern = self.conv

    def callback(self, pre, post, node_map):
        x_ = node_map[self.x][0]
        w_ = node_map[self.weight][0]
        w_ = w_ * node_map[self.scale][0]  # scale the weight instead of the input
        conv = node_map[self.conv][0]
        return relay.nn.conv2d(x_, w_, **conv.attrs)
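
Before running it on the traced model, the callback can be sanity-checked on a hand-built expression. A minimal sketch (the names, shapes, and the numpy import are illustrative):

import numpy as np
xv = relay.var("xv", shape=(1, 3, 8, 8))
wv = relay.const(np.random.randn(8, 3, 3, 3).astype("float32"))
sv = relay.const(2.0, "float32")
expr = relay.nn.conv2d(relay.multiply(xv, sv), wv, channels=8, kernel_size=(3, 3))
print(rewrite(RewriteMulConv2d(), expr))  # the multiply should be absorbed into the weight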

@tvm.ir.transform.module_pass(opt_level=3, name="MulConv2dRewriter")
class MulConv2dRewriterPipeline:
    """Module pass: apply the pattern rewrite, then constant-fold the
    resulting ``weight * scale`` expression into a single constant."""

    def transform_module(self, mod, ctx):
        mod["main"] = rewrite(RewriteMulConv2d(), mod["main"])
        mod = relay.transform.FoldConstant()(mod)
        return mod

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, bias=False)
        self.conv2 = torch.nn.Conv2d(16, 32, 1, bias=False)
        self.scale = torch.tensor(4.019027233123779, dtype=torch.float32)

    def forward(self, x):
        x = self.conv(x)
        # mode="nearest" does not take align_corners
        x = F.interpolate(x, size=None, scale_factor=(0.5, 0.5), mode="nearest")
        x = self.scale * x  # the multiply that the rewrite should eliminate
        x = self.conv2(x)
        return x
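
Tracing the model and importing it with from_pytorch yields the Relay function printed below; F.interpolate lowers to image.resize2d, and the scalar multiply sits between the resize and the second convolution.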


torch_model = M()
shape = (1, 3, 10, 10)
input_shapes = [("x", shape)]
with torch.no_grad():
    trace = torch.jit.trace(torch_model, [torch.randn(*shape)])
    mod, params = relay.frontend.from_pytorch(trace, input_shapes)
print(mod["main"])  # the imported IR, before any rewriting
with tvm.transform.PassContext(opt_level=3):
    # prerequisite_optimize binds `params` into the module and constant-folds,
    # turning the weight vars into relay.Constant nodes that is_const() can match
    mod = relay.quantize.prerequisite_optimize(mod, params)
    mod = MulConv2dRewriterPipeline()(mod)
fn (%x: Tensor[(1, 3, 10, 10), float32] /* span=aten::_convolution_0.x:0:0 */, %aten::_convolution_0.weight: Tensor[(16, 3, 3, 3), float32] /* span=aten::_convolution_0.weight:0:0 */, %aten::_convolution_1.weight: Tensor[(32, 16, 1, 1), float32] /* span=aten::_convolution_1.weight:0:0 */) {
  %0 = nn.conv2d(%x, %aten::_convolution_0.weight, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3]) /* span=aten::_convolution_0:0:0 */;
  %1 = image.resize2d(%0, size=[4, 4], roi=[0f, 0f, 0f, 0f], method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="", cubic_alpha=-0.75f) /* span=aten::upsample_nearest2d_0:0:0 */;
  %2 = multiply(4.01903f /* span=aten::mul_0:0:0 */, %1) /* span=aten::mul_0:0:0 */;
  nn.conv2d(%2, %aten::_convolution_1.weight, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* span=aten::_convolution_1:0:0 */
}
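The dump above is the IR as imported (the print ran before the pipeline); %2 = multiply(4.01903f, %1) feeding nn.conv2d is exactly the pattern the rewriter targets. Printing again after the pipeline shows the rewritten module: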
print(mod["main"])
fn (%x: Tensor[(1, 3, 10, 10), float32] /* ty=Tensor[(1, 3, 10, 10), float32] span=aten::_convolution_0.x:0:0 */) -> Tensor[(1, 32, 4, 4), float32] {
  %0 = nn.conv2d(%x, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 8, 8), float32] span=aten::_convolution_0:0:0 */;
  %1 = image.resize2d(%0, size=[4, 4], roi=[0f, 0f, 0f, 0f], method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="", cubic_alpha=-0.75f) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten::upsample_nearest2d_0:0:0 */;
  nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(32, 16, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 4, 4), float32] */
} /* ty=fn (Tensor[(1, 3, 10, 10), float32]) -> Tensor[(1, 32, 4, 4), float32] */
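The multiply is gone: the 4.01903 factor has been folded into meta[relay.Constant][1], the weight of the second convolution, and the former weight parameters are now embedded constants.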
import numpy as np
inputs_np = [np.random.rand(1, 3, 10, 10).astype("float32")]
evaluate = relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm").evaluate()
tvm_output = evaluate(inputs_np[0])
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
with torch.no_grad():
    torch_output = torch_model(torch.from_numpy(inputs_np[0]))
np.testing.assert_allclose(tvm_output.numpy(), torch_output.numpy(), rtol=1e-7, atol=1e-5)
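The debug executor is handy for verification; for deployment one would normally compile ahead of time. A minimal sketch, assuming the same "llvm" target and the input name "x" declared above:

from tvm.contrib import graph_executor

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")
m = graph_executor.GraphModule(lib["default"](tvm.cpu()))
m.set_input("x", inputs_np[0])
m.run()
np.testing.assert_allclose(m.get_output(0).numpy(), torch_output.numpy(), rtol=1e-7, atol=1e-5)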
from tvm.relay.analysis.analysis import extract_intermdeiate_expr  # the misspelling is TVM's own API name

# NB: `mod` below is a different, larger ONNX-imported module (note the
# Conv_*/PRelu_* spans in the dump), used here to illustrate extraction:
# expression #35 and everything it depends on become the new "main".
mod = extract_intermdeiate_expr(mod, 35)
mod = relay.transform.InferType()(mod)
print(mod["main"])
fn (%input.1: Tensor[(1, 3, 128, 128), float32] /* ty=Tensor[(1, 3, 128, 128), float32] span=Conv_9.input.1:0:0 */) -> Tensor[(1, 48, 64, 64), float32] {
  %0 = nn.conv2d(%input.1, meta[relay.Constant][0] /* ty=Tensor[(32, 3, 3, 3), float32] span=Conv_9.conv_in.weight:0:0 */, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_9:0:0 */;
  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(32, 1, 1, 3), float32] span=Conv_10.body.0.body.0.dau_top.body.0.weight:0:0 */, padding=[0, 1, 0, 1], groups=32, channels=32, kernel_size=[1, 3]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_10:0:0 */;
  %2 = nn.conv2d(%1, meta[relay.Constant][2] /* ty=Tensor[(32, 32, 1, 1), float32] span=Conv_11.body.0.body.0.dau_top.body.1.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_11:0:0 */;
  %3 = broadcast_to_like(meta[relay.Constant][3] /* ty=Tensor[(32, 1, 1), float32] span=PRelu_12.onnx::PRelu_175:0:0 */, %2) /* ty=Tensor[(1, 32, 128, 128), float32] span=PRelu_12:0:0 */;
  %4 = reshape(%2, newshape=[-1]) /* ty=Tensor[(524288), float32] span=PRelu_12:0:0 */;
  %5 = reshape(%3, newshape=[-1]) /* ty=Tensor[(524288), float32] span=PRelu_12:0:0 */;
  %6 = nn.prelu(%4, %5, axis=0) /* ty=Tensor[(524288), float32] span=PRelu_12:0:0 */;
  %7 = reshape(%6, newshape=[1, 32, 128, 128]) /* ty=Tensor[(1, 32, 128, 128), float32] span=PRelu_12:0:0 */;
  %8 = nn.conv2d(%7, meta[relay.Constant][4] /* ty=Tensor[(32, 1, 1, 3), float32] span=Conv_13.body.0.body.0.dau_top.body.3.weight:0:0 */, padding=[0, 1, 0, 1], groups=32, channels=32, kernel_size=[1, 3]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_13:0:0 */;
  %9 = nn.conv2d(%8, meta[relay.Constant][5] /* ty=Tensor[(32, 32, 1, 1), float32] span=Conv_14.body.0.body.0.dau_top.body.4.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 128, 128), float32] span=Conv_14:0:0 */;
  %10 = nn.global_avg_pool2d(%9) /* ty=Tensor[(1, 32, 1, 1), float32] span=GlobalAveragePool_15:0:0 */;
  %11 = nn.conv2d(%10, meta[relay.Constant][6] /* ty=Tensor[(2, 32, 1, 1), float32] span=Conv_16.body.0.body.0.dau_top.gcnet.se.1.weight:0:0 */, padding=[0, 0, 0, 0], channels=2, kernel_size=[1, 1]) /* ty=Tensor[(1, 2, 1, 1), float32] span=Conv_16:0:0 */;
  %12 = broadcast_to_like(meta[relay.Constant][7] /* ty=Tensor[(2, 1, 1), float32] span=PRelu_17.onnx::PRelu_176:0:0 */, %11) /* ty=Tensor[(1, 2, 1, 1), float32] span=PRelu_17:0:0 */;
  %13 = reshape(%11, newshape=[-1]) /* ty=Tensor[(2), float32] span=PRelu_17:0:0 */;
  %14 = reshape(%12, newshape=[-1]) /* ty=Tensor[(2), float32] span=PRelu_17:0:0 */;
  %15 = nn.prelu(%13, %14, axis=0) /* ty=Tensor[(2), float32] span=PRelu_17:0:0 */;
  %16 = reshape(%15, newshape=[1, 2, 1, 1]) /* ty=Tensor[(1, 2, 1, 1), float32] span=PRelu_17:0:0 */;
  %17 = nn.conv2d(%16, meta[relay.Constant][8] /* ty=Tensor[(32, 2, 1, 1), float32] span=Conv_18.body.0.body.0.dau_top.gcnet.se.3.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 1, 1), float32] span=Conv_18:0:0 */;
  %18 = sigmoid(%17) /* ty=Tensor[(1, 32, 1, 1), float32] span=Sigmoid_19:0:0 */;
  %19 = nn.conv2d(%18, meta[relay.Constant][9] /* ty=Tensor[(32, 32, 1, 1), float32] span=Conv_20.body.0.body.0.dau_top.gcnet.channel_add_conv.0.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 1, 1), float32] span=Conv_20:0:0 */;
  %20 = broadcast_to_like(meta[relay.Constant][10] /* ty=Tensor[(32, 1, 1), float32] span=PRelu_21.onnx::PRelu_177:0:0 */, %19) /* ty=Tensor[(1, 32, 1, 1), float32] span=PRelu_21:0:0 */;
  %21 = reshape(%19, newshape=[-1]) /* ty=Tensor[(32), float32] span=PRelu_21:0:0 */;
  %22 = reshape(%20, newshape=[-1]) /* ty=Tensor[(32), float32] span=PRelu_21:0:0 */;
  %23 = nn.prelu(%21, %22, axis=0) /* ty=Tensor[(32), float32] span=PRelu_21:0:0 */;
  %24 = reshape(%23, newshape=[1, 32, 1, 1]) /* ty=Tensor[(1, 32, 1, 1), float32] span=PRelu_21:0:0 */;
  %25 = nn.conv2d(%24, meta[relay.Constant][11] /* ty=Tensor[(32, 32, 1, 1), float32] span=Conv_22.body.0.body.0.dau_top.gcnet.channel_add_conv.2.weight:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1]) /* ty=Tensor[(1, 32, 1, 1), float32] span=Conv_22:0:0 */;
  %26 = add(%9, %25) /* ty=Tensor[(1, 32, 128, 128), float32] span=Add_23:0:0 */;
  %27 = broadcast_to_like(meta[relay.Constant][3] /* ty=Tensor[(32, 1, 1), float32] span=PRelu_12.onnx::PRelu_175:0:0 */, %26) /* ty=Tensor[(1, 32, 128, 128), float32] span=PRelu_24:0:0 */;
  %28 = reshape(%26, newshape=[-1]) /* ty=Tensor[(524288), float32] span=PRelu_24:0:0 */;
  %29 = reshape(%27, newshape=[-1]) /* ty=Tensor[(524288), float32] span=PRelu_24:0:0 */;
  %30 = nn.prelu(%28, %29, axis=0) /* ty=Tensor[(524288), float32] span=PRelu_24:0:0 */;
  %31 = reshape(%30, newshape=[1, 32, 128, 128]) /* ty=Tensor[(1, 32, 128, 128), float32] span=PRelu_24:0:0 */;
  %32 = add(%31, %0) /* ty=Tensor[(1, 32, 128, 128), float32] span=Add_25:0:0 */;
  %33 = image.resize2d(%32, size=[64, 64], roi=[0f, 0f, 0f, 0f], method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="floor", cubic_alpha=-0.75f) /* ty=Tensor[(1, 32, 64, 64), float32] span=Resize_27:0:0 */;
  %34 = multiply(4.01903f /* ty=float32 span=Mul_28.body.0.body.0.down2.0.alpha:0:0 */, %33) /* ty=Tensor[(1, 32, 64, 64), float32] span=Mul_28:0:0 */;
  nn.conv2d(%34, meta[relay.Constant][12] /* ty=Tensor[(48, 32, 1, 1), float32] span=Conv_29.body.0.body.0.down2.1.weight:0:0 */, padding=[0, 0, 0, 0], channels=48, kernel_size=[1, 1]) /* ty=Tensor[(1, 48, 64, 64), float32] span=Conv_29:0:0 */
} /* ty=fn (Tensor[(1, 3, 128, 128), float32]) -> Tensor[(1, 48, 64, 64), float32] */