PyTorch pass

torch._C._jit_pass_lower_all_tuples

Source: [Bug] [Frontend][Pytorch] Relay IR is inconsistent with that of the original model

torch._C._jit_pass_lower_all_tuples is an internal PyTorch function that lowers the tuples appearing in a TorchScript graph. Its main job during TorchScript compilation is to flatten the tuple construction and unpacking that comes from Python tuple code, so that the graph can be executed and transformed without carrying tuple values around.

Concretely, torch._C._jit_pass_lower_all_tuples walks all nodes of the TorchScript IR (Intermediate Representation), finds the nodes that construct, index, or unpack tuples, and rewrites their consumers to use the tuple elements directly. Later optimization and transformation passes, as well as importers such as TVM's PyTorch frontend, can then process the graph without any extra Python-to-TorchScript tuple-handling code.
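As a minimal, self-contained sketch (the toy function f below is illustrative only and is not part of the bug reproduction that follows): scripting a function that builds and indexes a Python tuple yields prim::TupleConstruct / prim::TupleIndex nodes, and running the pass rewrites the graph so the elements are consumed directly.

import torch

def f(x):
    t = (x + 1, x * 2)   # Python tuple inside the scripted function
    return t[0] - t[1]

scripted = torch.jit.script(f)
print(scripted.graph)    # contains prim::TupleConstruct / prim::TupleIndex nodes

# Lower (flatten) all tuple handling in the graph, in place.
torch._C._jit_pass_lower_all_tuples(scripted.graph)
print(scripted.graph)    # tuple handling is lowered; x + 1 and x * 2 are used directly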

import set_env
import torch
from torch import nn
from torch.quantization import QuantStub, DeQuantStub
from torch.quantization import prepare_qat, get_default_qat_qconfig, convert
from tvm import relay
import numpy as np


class ConvBnRelu(nn.Module):
    def __init__(self, inp, oup, kernel_size=3, stride=1, padding=1, bias=True, groups=1):
        super(ConvBnRelu, self).__init__()
        if groups > 1:
            self.conv = nn.Conv2d(inp, inp, kernel_size, stride, padding, bias=bias, groups=groups)
            self.bn = nn.BatchNorm2d(inp)
        else:
            self.conv = nn.Conv2d(inp, oup, kernel_size, stride, padding, bias=bias, groups=groups)
            self.bn = nn.BatchNorm2d(oup)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.bn(x)
        x = self.relu(x)
        return x


def conv_bn(inp, oup, stride=1, width_multiplier=1):
    return ConvBnRelu(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)


def conv_dw(inp, oup, stride, width_multiplier=1, padding=1):
    dw_block = nn.Sequential()
    depth_wise = ConvBnRelu(inp, oup, kernel_size=3, stride=stride, padding=padding, bias=False, groups=inp)
    point_wise = ConvBnRelu(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)

    dw_block.add_module('depth_wise', depth_wise)
    dw_block.add_module('point_wise', point_wise)

    return dw_block

class Backbone(nn.Module):
    def __init__(self, width_multiplier=1):
        super(Backbone, self).__init__()
        self.width_multiplier = width_multiplier
        self.conv1 = conv_bn(3, 16, 2, self.width_multiplier)
        self.conv2 = conv_dw(16, 32, 1, self.width_multiplier)
    
    def forward(self, inputs):
        x1 = self.conv1(inputs)
        x2 = self.conv2(x1)
        return [x1, x2]

class QuantizableBackbone(nn.Module):
    def __init__(self, inputsize=(128, 128)):
        super(QuantizableBackbone, self).__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.backbone = Backbone()

    def fuse_model(self):
        for idx, m in enumerate(self.modules()):
            if type(m) == ConvBnRelu:
                torch.quantization.fuse_modules(m, ['conv', 'bn', 'relu'], inplace=True)

    def forward(self, input):
        input = self.quant(input)
        y0, y1 = self.backbone(input)
        y0 = self.dequant(y0)
        y1 = self.dequant(y1)
        return y0, y1

# Quantization-aware-training (QAT) workflow: fuse Conv+BN+ReLU, attach a QAT
# qconfig, prepare the model, run one forward pass for calibration, convert to
# the quantized model, and finally trace it with TorchScript.
fp32_input = torch.randn(1, 3, 128, 128)
model = QuantizableBackbone()
model.eval()
model.fuse_model()
model.qconfig = get_default_qat_qconfig("qnnpack")
model.train()
prepare_qat(model, inplace=True)
model.eval()
model(fp32_input)  # single forward pass to populate the observers

model_int8 = torch.quantization.convert(model, inplace=True)
script_module = torch.jit.trace(model_int8, fp32_input).eval()

input_infos = [("input", (fp32_input.shape, "float32"))]
mod, _ = relay.frontend.from_pytorch(script_module, input_infos)
print(mod["main"])
output = mod["main"].body
assert isinstance(output, relay.Tuple) and len(output) == 2
dq1, dq2 = output
assert str(dq1.op) == str(dq2.op) == 'Op(qnn.dequantize)'
scale1 = dq1.args[1].data.numpy().item()
scale2 = dq2.args[1].data.numpy().item()
assert scale1 != scale2
fn (%input: Tensor[(1, 3, 128, 128), float32] /* span=aten::quantize_per_tensor_0.input:0:0 */, %backbone.conv1.conv_weight: Tensor[(16, 3, 3, 3), float32] /* span=quantized::conv2d_relu_0:0:0 */, %backbone.conv1.conv_bias: Tensor[(16), float32] /* span=quantized::conv2d_relu_0:0:0 */, %backbone.conv2.depth_wise.conv_weight: Tensor[(16, 1, 3, 3), float32] /* span=quantized::conv2d_relu_1:0:0 */, %backbone.conv2.depth_wise.conv_bias: Tensor[(16), float32] /* span=quantized::conv2d_relu_1:0:0 */, %backbone.conv2.point_wise.conv_weight: Tensor[(32, 16, 1, 1), float32] /* span=quantized::conv2d_relu_2:0:0 */, %backbone.conv2.point_wise.conv_bias: Tensor[(32), float32] /* span=quantized::conv2d_relu_2:0:0 */) {
  %0 = qnn.quantize(%input, 0.0347108f /* span=aten::quantize_per_tensor_0:0:0 */, 125 /* span=aten::quantize_per_tensor_0:0:0 */, out_dtype="uint8", axis=1) /* span=aten::quantize_per_tensor_0:0:0 */;
  %1 = nn.pad(%0, 125f /* span=quantized::conv2d_relu_0:0:0 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* span=quantized::conv2d_relu_0:0:0 */;
  %2 = qnn.quantize(%backbone.conv1.conv_weight, 0.00150606f /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, out_dtype="int8", axis=0) /* span=quantized::conv2d_relu_0:0:0 */;
  %3 = qnn.conv2d(%1, %2, 125 /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, 0.0347108f /* span=quantized::conv2d_relu_0:0:0 */, 0.00150606f /* span=quantized::conv2d_relu_0:0:0 */, strides=[2, 2], padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3], out_dtype="int32") /* span=quantized::conv2d_relu_0:0:0 */;
  %4 = qnn.quantize(%backbone.conv1.conv_bias, 5.22766e-05f /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, out_dtype="int32", axis=0) /* span=quantized::conv2d_relu_0:0:0 */;
  %5 = nn.bias_add(%3, %4) /* span=quantized::conv2d_relu_0:0:0 */;
  %6 = qnn.requantize(%5, 5.22766e-05f /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, 0.0132984f /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, axis=1, out_dtype="int32") /* span=quantized::conv2d_relu_0:0:0 */;
  %7 = clip(%6, a_min=0f, a_max=255f) /* span=quantized::conv2d_relu_0:0:0 */;
  %8 = cast(%7, dtype="uint8") /* span=quantized::conv2d_relu_0:0:0 */;
  %9 = nn.pad(%8, 0f /* span=quantized::conv2d_relu_1:0:0 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* span=quantized::conv2d_relu_1:0:0 */;
  %10 = qnn.quantize(%backbone.conv2.depth_wise.conv_weight, 0.00256311f /* span=quantized::conv2d_relu_1:0:0 */, 0 /* span=quantized::conv2d_relu_1:0:0 */, out_dtype="int8", axis=0) /* span=quantized::conv2d_relu_1:0:0 */;
  %11 = qnn.conv2d(%9, %10, 0 /* span=quantized::conv2d_relu_1:0:0 */, 0 /* span=quantized::conv2d_relu_1:0:0 */, 0.0132984f /* span=quantized::conv2d_relu_1:0:0 */, 0.00256311f /* span=quantized::conv2d_relu_1:0:0 */, padding=[0, 0, 0, 0], groups=16, channels=16, kernel_size=[3, 3], out_dtype="int32") /* span=quantized::conv2d_relu_1:0:0 */;
  %12 = qnn.quantize(%backbone.conv2.depth_wise.conv_bias, 3.40854e-05f /* span=quantized::conv2d_relu_1:0:0 */, 0 /* span=quantized::conv2d_relu_1:0:0 */, out_dtype="int32", axis=0) /* span=quantized::conv2d_relu_1:0:0 */;
  %13 = nn.bias_add(%11, %12) /* span=quantized::conv2d_relu_1:0:0 */;
  %14 = qnn.requantize(%13, 3.40854e-05f /* span=quantized::conv2d_relu_1:0:0 */, 0 /* span=quantized::conv2d_relu_1:0:0 */, 0.00509362f /* span=quantized::conv2d_relu_1:0:0 */, 0 /* span=quantized::conv2d_relu_1:0:0 */, axis=1, out_dtype="int32") /* span=quantized::conv2d_relu_1:0:0 */;
  %15 = clip(%14, a_min=0f, a_max=255f) /* span=quantized::conv2d_relu_1:0:0 */;
  %16 = cast(%15, dtype="uint8") /* span=quantized::conv2d_relu_1:0:0 */;
  %17 = qnn.quantize(%backbone.conv2.point_wise.conv_weight, 0.00195794f /* span=quantized::conv2d_relu_2:0:0 */, 0 /* span=quantized::conv2d_relu_2:0:0 */, out_dtype="int8", axis=0) /* span=quantized::conv2d_relu_2:0:0 */;
  %18 = qnn.conv2d(%16, %17, 0 /* span=quantized::conv2d_relu_2:0:0 */, 0 /* span=quantized::conv2d_relu_2:0:0 */, 0.00509362f /* span=quantized::conv2d_relu_2:0:0 */, 0.00195794f /* span=quantized::conv2d_relu_2:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1], out_dtype="int32") /* span=quantized::conv2d_relu_2:0:0 */;
  %19 = qnn.quantize(%backbone.conv2.point_wise.conv_bias, 9.97299e-06f /* span=quantized::conv2d_relu_2:0:0 */, 0 /* span=quantized::conv2d_relu_2:0:0 */, out_dtype="int32", axis=0) /* span=quantized::conv2d_relu_2:0:0 */;
  %20 = nn.bias_add(%18, %19) /* span=quantized::conv2d_relu_2:0:0 */;
  %21 = qnn.requantize(%20, 9.97299e-06f /* span=quantized::conv2d_relu_2:0:0 */, 0 /* span=quantized::conv2d_relu_2:0:0 */, 0.00208748f /* span=quantized::conv2d_relu_2:0:0 */, 0 /* span=quantized::conv2d_relu_2:0:0 */, axis=1, out_dtype="int32") /* span=quantized::conv2d_relu_2:0:0 */;
  %22 = clip(%21, a_min=0f, a_max=255f) /* span=quantized::conv2d_relu_2:0:0 */;
  %23 = cast(%22, dtype="uint8") /* span=quantized::conv2d_relu_2:0:0 */;
  %24 = qnn.dequantize(%8, 0.0132984f /* span=aten::dequantize_0:0:0 */, 0 /* span=aten::dequantize_0:0:0 */, out_dtype="float32") /* span=aten::dequantize_0:0:0 */;
  %25 = qnn.dequantize(%23, 0.00208748f /* span=aten::dequantize_1:0:0 */, 0 /* span=aten::dequantize_1:0:0 */, out_dtype="float32") /* span=aten::dequantize_1:0:0 */;
  (%24, %25)
}
# Round-trip the quantized model through torch.jit.save / torch.jit.load, then
# script the reloaded module and import it into Relay a second time.
torch.jit.save(torch.jit.trace(model_int8, fp32_input), "quantized_model.pt")
model = torch.jit.load("quantized_model.pt")
# model = torch.jit.trace(model, fp32_input).eval()
model = torch.jit.script(model)
mod, params = relay.frontend.from_pytorch(model, input_infos)
print(mod["main"])
fn (%input: Tensor[(1, 3, 128, 128), float32] /* span=aten::quantize_per_tensor_0.input:0:0 */, %backbone.conv1.conv_weight: Tensor[(16, 3, 3, 3), float32] /* span=quantized::conv2d_relu_0:0:0 */, %backbone.conv1.conv_bias: Tensor[(16), float32] /* span=quantized::conv2d_relu_0:0:0 */, %backbone.conv2.depth_wise.conv_weight: Tensor[(16, 1, 3, 3), float32] /* span=quantized::conv2d_relu_1:0:0 */, %backbone.conv2.depth_wise.conv_bias: Tensor[(16), float32] /* span=quantized::conv2d_relu_1:0:0 */, %backbone.conv2.point_wise.conv_weight: Tensor[(32, 16, 1, 1), float32] /* span=quantized::conv2d_relu_2:0:0 */, %backbone.conv2.point_wise.conv_bias: Tensor[(32), float32] /* span=quantized::conv2d_relu_2:0:0 */) {
  %0 = qnn.quantize(%input, 0.0347108f /* span=aten::quantize_per_tensor_0:0:0 */, 125 /* span=aten::quantize_per_tensor_0:0:0 */, out_dtype="uint8", axis=1) /* span=aten::quantize_per_tensor_0:0:0 */;
  %1 = nn.pad(%0, 125f /* span=quantized::conv2d_relu_0:0:0 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* span=quantized::conv2d_relu_0:0:0 */;
  %2 = qnn.quantize(%backbone.conv1.conv_weight, 0.00150606f /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, out_dtype="int8", axis=0) /* span=quantized::conv2d_relu_0:0:0 */;
  %3 = qnn.conv2d(%1, %2, 125 /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, 0.0347108f /* span=quantized::conv2d_relu_0:0:0 */, 0.00150606f /* span=quantized::conv2d_relu_0:0:0 */, strides=[2, 2], padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3], out_dtype="int32") /* span=quantized::conv2d_relu_0:0:0 */;
  %4 = qnn.quantize(%backbone.conv1.conv_bias, 5.22766e-05f /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, out_dtype="int32", axis=0) /* span=quantized::conv2d_relu_0:0:0 */;
  %5 = nn.bias_add(%3, %4) /* span=quantized::conv2d_relu_0:0:0 */;
  %6 = qnn.requantize(%5, 5.22766e-05f /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, 0.0132984f /* span=quantized::conv2d_relu_0:0:0 */, 0 /* span=quantized::conv2d_relu_0:0:0 */, axis=1, out_dtype="int32") /* span=quantized::conv2d_relu_0:0:0 */;
  %7 = clip(%6, a_min=0f, a_max=255f) /* span=quantized::conv2d_relu_0:0:0 */;
  %8 = cast(%7, dtype="uint8") /* span=quantized::conv2d_relu_0:0:0 */;
  %9 = nn.pad(%8, 0f /* span=quantized::conv2d_relu_1:0:0 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* span=quantized::conv2d_relu_1:0:0 */;
  %10 = qnn.quantize(%backbone.conv2.depth_wise.conv_weight, 0.00256311f /* span=quantized::conv2d_relu_1:0:0 */, 0 /* span=quantized::conv2d_relu_1:0:0 */, out_dtype="int8", axis=0) /* span=quantized::conv2d_relu_1:0:0 */;
  %11 = qnn.conv2d(%9, %10, 0 /* span=quantized::conv2d_relu_1:0:0 */, 0 /* span=quantized::conv2d_relu_1:0:0 */, 0.0132984f /* span=quantized::conv2d_relu_1:0:0 */, 0.00256311f /* span=quantized::conv2d_relu_1:0:0 */, padding=[0, 0, 0, 0], groups=16, channels=16, kernel_size=[3, 3], out_dtype="int32") /* span=quantized::conv2d_relu_1:0:0 */;
  %12 = qnn.quantize(%backbone.conv2.depth_wise.conv_bias, 3.40854e-05f /* span=quantized::conv2d_relu_1:0:0 */, 0 /* span=quantized::conv2d_relu_1:0:0 */, out_dtype="int32", axis=0) /* span=quantized::conv2d_relu_1:0:0 */;
  %13 = nn.bias_add(%11, %12) /* span=quantized::conv2d_relu_1:0:0 */;
  %14 = qnn.requantize(%13, 3.40854e-05f /* span=quantized::conv2d_relu_1:0:0 */, 0 /* span=quantized::conv2d_relu_1:0:0 */, 0.00509362f /* span=quantized::conv2d_relu_1:0:0 */, 0 /* span=quantized::conv2d_relu_1:0:0 */, axis=1, out_dtype="int32") /* span=quantized::conv2d_relu_1:0:0 */;
  %15 = clip(%14, a_min=0f, a_max=255f) /* span=quantized::conv2d_relu_1:0:0 */;
  %16 = cast(%15, dtype="uint8") /* span=quantized::conv2d_relu_1:0:0 */;
  %17 = qnn.quantize(%backbone.conv2.point_wise.conv_weight, 0.00195794f /* span=quantized::conv2d_relu_2:0:0 */, 0 /* span=quantized::conv2d_relu_2:0:0 */, out_dtype="int8", axis=0) /* span=quantized::conv2d_relu_2:0:0 */;
  %18 = qnn.conv2d(%16, %17, 0 /* span=quantized::conv2d_relu_2:0:0 */, 0 /* span=quantized::conv2d_relu_2:0:0 */, 0.00509362f /* span=quantized::conv2d_relu_2:0:0 */, 0.00195794f /* span=quantized::conv2d_relu_2:0:0 */, padding=[0, 0, 0, 0], channels=32, kernel_size=[1, 1], out_dtype="int32") /* span=quantized::conv2d_relu_2:0:0 */;
  %19 = qnn.quantize(%backbone.conv2.point_wise.conv_bias, 9.97299e-06f /* span=quantized::conv2d_relu_2:0:0 */, 0 /* span=quantized::conv2d_relu_2:0:0 */, out_dtype="int32", axis=0) /* span=quantized::conv2d_relu_2:0:0 */;
  %20 = nn.bias_add(%18, %19) /* span=quantized::conv2d_relu_2:0:0 */;
  %21 = qnn.requantize(%20, 9.97299e-06f /* span=quantized::conv2d_relu_2:0:0 */, 0 /* span=quantized::conv2d_relu_2:0:0 */, 0.00208748f /* span=quantized::conv2d_relu_2:0:0 */, 0 /* span=quantized::conv2d_relu_2:0:0 */, axis=1, out_dtype="int32") /* span=quantized::conv2d_relu_2:0:0 */;
  %22 = clip(%21, a_min=0f, a_max=255f) /* span=quantized::conv2d_relu_2:0:0 */;
  %23 = cast(%22, dtype="uint8") /* span=quantized::conv2d_relu_2:0:0 */;
  %24 = qnn.dequantize(%8, 0.0132984f /* span=aten::dequantize_0:0:0 */, 0 /* span=aten::dequantize_0:0:0 */, out_dtype="float32") /* span=aten::dequantize_0:0:0 */;
  %25 = qnn.dequantize(%23, 0.00208748f /* span=aten::dequantize_1:0:0 */, 0 /* span=aten::dequantize_1:0:0 */, out_dtype="float32") /* span=aten::dequantize_1:0:0 */;
  (%24, %25)
}
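The two printouts are identical: importing the directly traced module and importing the saved, reloaded, and scripted module produce the same Relay IR, with the two output dequantize ops keeping their distinct scales. A quick programmatic check is TVM's structural equality; the snippet below is a sketch only, where mod_traced is a hypothetical variable holding the result of the first from_pytorch call (which the code above overwrites).

import tvm

# Hypothetical: keep the first conversion result as `mod_traced` instead of
# reusing the name `mod`, then compare it with the scripted/reloaded result.
assert tvm.ir.structural_equal(mod_traced["main"], mod["main"])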