# 量化分区

构建 PyTorch 前端模型：

In [1]:
from torch import nn
import torch


class Model(nn.Module):
    """Small CNN used as the PyTorch front-end model for the partition demo.

    Layout: two 3x3 convolutions (each followed by ReLU), one 2x2 max-pool,
    then three fully connected layers with ReLU between them.  ``dense1``
    takes 64 features, i.e. 16 channels x 2 x 2 after pooling, so ``forward``
    expects inputs with a 4x4 spatial size.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Convolutions keep the spatial size (kernel 3, stride 1, padding 1).
        self.conv = nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=True
        )
        # Stateless layers shared by every stage of forward().
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        # Classifier head: 64 -> 32 -> 16 -> 8.
        self.dense1 = nn.Linear(in_features=64, out_features=32, bias=False)
        self.dense2 = nn.Linear(in_features=32, out_features=16)
        self.dense3 = nn.Linear(in_features=16, out_features=8)

    def forward(self, x):
        """Run the network on ``x`` and return the 8-way output."""
        features = self.relu(self.conv(x))
        features = self.relu(self.conv2(features))
        flat = self.pool(features).flatten(1)
        hidden = self.relu(self.dense1(flat))
        hidden = self.relu(self.dense2(hidden))
        return self.dense3(hidden)

翻译 PyTorch 前端模型为 Relay 模型：

In [2]:
import set_env

In [3]:
import numpy as np
import set_env
import tvm
from tvm import relay

# Seed both RNGs so the randomly initialised weights and the random sample
# input are the same on every fresh Restart & Run All.
torch.manual_seed(0)
np.random.seed(0)

# Input data
input_shape = (1, 3, 4, 4)
input_dtype = "float32"
data_np = np.random.rand(*input_shape).astype(input_dtype)

# Trace the model in eval mode; gradients are not needed for the conversion.
with torch.no_grad():
    pt_model = Model().eval().float()
    traced_model = torch.jit.trace(pt_model, torch.from_numpy(data_np)).eval()

# Import the traced module into Relay; "data" is the name given to the graph input.
mod, params = relay.frontend.from_pytorch(
    traced_model,
    [("data", input_shape)],
    use_parser_friendly_name=True,
)

# Run the pre-quantization optimizations; in the printed function the weights
# show up as meta[relay.Constant] entries.
with tvm.transform.PassContext(opt_level=3):
    mod = relay.quantize.prerequisite_optimize(mod, params)
print(mod['main'])

fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %3 = add(%2, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %5 = nn.max_pool2d(%4, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0

调用分区 pass：

In [4]:
# Run the partition pass while TVM's stock "FQPartitionRewrite" registrations are
# still in place, and display the rewritten "main" function.
relay.quantize.partition()(mod)["main"]

fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten

为了展示分区的效果，将 TVM 默认注册的属性函数清除：

In [5]:
from tvm.relay.op import op as _op
def reset_partition():
    """Clear the ``"FQPartitionRewrite"`` attribute of a fixed set of Relay operators.

    The operators are visited in the order listed below.  ``nn.dense`` is not
    part of the list, so callers that register a rewrite for it have to reset
    that operator themselves.
    """
    attr_key = "FQPartitionRewrite"
    targets = (
        "nn.conv2d",
        "nn.relu",
        "nn.max_pool2d",
        "add",
        "multiply",
        "clip",
        "nn.global_avg_pool2d",
    )
    for name in targets:
        _op.get(name).reset_attr(attr_key)

此时调用分区 pass，则有：

In [6]:
# Clear the registered rewrites, then partition again: the visible part of the
# printed function no longer contains cast_hint / stop_fusion annotations.
reset_partition()
relay.quantize.partition()(mod)["main"]

fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %3 = add(%2, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %5 = nn.max_pool2d(%4, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0

下面看看如何从算子层面考虑对模型进行分区。

In [7]:
from tvm.relay.quantize._partition import (
    partition_expr_check,
    QPartitionExpr,
    register_partition_function
)
from tvm.relay.quantize.quantize import _forward_op

In [8]:
# IPython introspection (`??`): show the source of register_partition_function.
register_partition_function??

Signature: register_partition_function(op_name, frewrite=None, level=10)
Docstring: <no docstring>
Source:   
def register_partition_function(op_name, frewrite=None, level=10):
    return tvm.ir.register_op_attr(op_name, "FQPartitionRewrite", frewrite, level)
File:      /media/pc/data/board/arria10/lxw/tasks/tvm-ai/python/tvm/relay/quantize/_partition.py
Type:      function

{func}`register_partition_function` 为算子注册 `"FQPartitionRewrite"` 属性。

In [9]:
# IPython introspection (`??`): show the source of partition_expr_check.
partition_expr_check??

Signature: partition_expr_check(expr)
Docstring: <no docstring>
Source:   
def partition_expr_check(expr):
    if isinstance(expr, QPartitionExpr):
        return True, expr.expr
    return False, expr
File:      /media/pc/data/board/arria10/lxw/tasks/tvm-ai/python/tvm/relay/quantize/_partition.py
Type:      function

{func}`partition_expr_check` 用于检查 `expr` 是否为 `QPartitionExpr`：若是，返回 `(True, expr.expr)`（即被包装的原始表达式）；否则返回 `(False, expr)`。

In [10]:
# IPython introspection (`??`): show the source of _forward_op.
_forward_op??

Signature: _forward_op(ref_call, args)
Source:   
def _forward_op(ref_call, args):
    """forward the operator of ref_call with provided arguments"""
    return _expr.Call(ref_call.op, args, ref_call.attrs, ref_call.type_args, ref_call.span)
File:      /media/pc/data/board/arria10/lxw/tasks/tvm-ai/python/tvm/relay/quantize/quantize.py
Type:      function

{func}`_forward_op` 使用新的参数 `args` 重新构造与 `ref_call` 相同算子的调用（保留其 `attrs`、`type_args` 和 `span`），通常在分区重写回调函数中使用。

## `nn.relu` 和 `nn.max_pool2d` 分区

对 `nn.relu` 进行分区，可以：

In [11]:
from tvm.relay.op import op as _op
# Drop any rewrite already attached to nn.relu so this cell can be re-run.
_op.get("nn.relu").reset_attr("FQPartitionRewrite")


@register_partition_function("nn.relu")
def _relu_partition(ref_call, new_args, ctx):
    """Partition rewrite for ``nn.relu``.

    Rebuilds the relu call from ``new_args`` (passed through unchanged) and
    returns it wrapped in a ``QPartitionExpr``.  ``ctx`` is not used.
    """
    forwarded = _forward_op(ref_call, new_args)
    return QPartitionExpr(forwarded)

In [12]:
# Partition with the custom nn.relu rewrite registered above and display the result.
run_mod = relay.quantize.partition()(mod)
run_mod["main"]

fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten

也可以对 `nn.max_pool2d` 执行同样的操作：

In [13]:
from tvm.relay.op import op as _op
# Clear both ops first, so that only nn.max_pool2d carries a rewrite afterwards.
_op.get("nn.relu").reset_attr("FQPartitionRewrite")
_op.get("nn.max_pool2d").reset_attr("FQPartitionRewrite")


@register_partition_function("nn.max_pool2d")
def _max_pool2d_partition(ref_call, new_args, ctx):
    """Partition rewrite for ``nn.max_pool2d``.

    Rebuilds the pooling call from ``new_args`` (passed through unchanged) and
    returns it wrapped in a ``QPartitionExpr``.  ``ctx`` is not used.
    """
    forwarded = _forward_op(ref_call, new_args)
    return QPartitionExpr(forwarded)

In [14]:
# Partition again: at this point nn.relu has been reset and only the custom
# nn.max_pool2d rewrite from the previous cell is registered.
run_mod = relay.quantize.partition()(mod)
run_mod["main"]

fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %3 = add(%2, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %5 = nn.max_pool2d(%4, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0

{func}`_max_pool2d_partition` 和 {func}`_relu_partition` 本质上是同样的功能实现，可以将它们合并为一个函数：

In [15]:
def identity_partition_function(ref_call, new_args, ctx):
    """Shared partition rewrite for single-input operators.

    The first argument is unwrapped if it is a ``QPartitionExpr`` (it is not
    realized), the call is rebuilt with that single argument, and the result
    is always returned wrapped in a ``QPartitionExpr`` -- whether or not the
    input was wrapped.  ``ctx`` is not used.
    """
    unwrapped = partition_expr_check(new_args[0])[1]
    forwarded = _forward_op(ref_call, [unwrapped])
    return QPartitionExpr(forwarded)
# Clear the existing registrations, then attach the shared rewrite to both ops.
reset_partition()
register_partition_function("nn.relu", identity_partition_function)
# Last expression of the cell: its return value (the registered function) is displayed.
register_partition_function("nn.max_pool2d", identity_partition_function)

<function __main__.identity_partition_function(ref_call, new_args, ctx)>

可以看出 `nn.max_pool2d` + `nn.relu` 实现了融合：

In [16]:
# Partition with identity_partition_function registered for both nn.relu and nn.max_pool2d.
run_mod = relay.quantize.partition()(mod)
run_mod["main"]

fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten

## `dense+add+relu` 分区

In [17]:
from tvm.relay import analysis as _analysis

# reset_partition() does not list "nn.dense", so that operator is cleared
# explicitly before dense_partition_function is registered further down.
reset_partition()
_op.get("nn.dense").reset_attr("FQPartitionRewrite")
def add_partition_generic(ref_call, new_args, ctx):
    """Partition rewrite for element-wise ``add`` on generic devices.

    Each operand is classified by whether it arrives as a ``QPartitionExpr``:

    * neither operand wrapped -- return ``None`` (trivial case);
    * both wrapped -- realize both, return the rebuilt call wrapped in a
      ``QPartitionExpr``;
    * only the right operand wrapped -- realize it, return the plain rebuilt call;
    * only the left operand wrapped -- if the right operand is a constant
      (``_analysis.check_constant``) use the unwrapped left operand and return
      the rebuilt call wrapped; otherwise realize the left operand and return
      the plain rebuilt call.

    ``ctx`` is not used.
    """
    lhs_cond, lhs = partition_expr_check(new_args[0])
    rhs_cond, rhs = partition_expr_check(new_args[1])

    if not (lhs_cond or rhs_cond):
        # Trivial case: neither operand is wrapped.
        return None

    if lhs_cond and rhs_cond:
        # Introduced by ResNet, at the first residual connection:
        #     ...
        #     %0 = nn.conv2d(%data, %meta[relay.Constant])
        #     %1 = add(%0, %meta[relay.Constant])
        #     %2 = nn.relu(%1)
        #     %3 = nn.max_pool2d(%2)
        #     ...
        #     %9 = nn.conv2d(%8, %meta[relay.Constant])
        #     %10 = add(%9, %meta[relay.Constant])
        #     %11 = add(%3, %10)  <- annotations are needed for %3 and %10
        #     ...
        realized = [new_args[0].realize(), new_args[1].realize()]
        return QPartitionExpr(_forward_op(ref_call, realized))

    if rhs_cond:
        # Only the right operand is wrapped; introduced by the residual
        # connection in ResNet:
        #     ...
        #     %13 = nn.conv2d(%12, %meta[relay.Constant])
        #     %14 = add(%13, %meta[relay.Constant])
        #     %15 = annotation.cast_hint(%14, 'int8')
        #     %16 = annotation.stop_fusion(%15)
        #     %17 = add(%5, %16)
        #     %18 = nn.relu(%17)
        #     ...
        #     %24 = nn.conv2d(%23, %meta[relay.Constant])
        #     %25 = add(%24, %meta[relay.Constant])
        #     %26 = add(%18, %25)  <- annotations are needed for %25
        #     ...
        return _forward_op(ref_call, [lhs, new_args[1].realize()])

    # From here on only the left operand is wrapped.
    if _analysis.check_constant(rhs):
        # Introduced by batch_norm: add(out, bias)
        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))

    # Introduced by the residual connection in MobileNetV2:
    #     ...
    #     %81 = add(%80, meta[relay.Constant])
    #     %82 = annotation.cast_hint(%81, 'int8')
    #     %83 = annotation.stop_fusion(%82)
    #     %84 = add(%79, %83)
    #     ...
    #     %95 = nn.conv2d(%94, %meta[relay.Constant])
    #     %96 = add(%95, %meta[relay.Constant])
    #     %97 = add(%96, %84)  <- annotations are needed for %96
    #     ...
    return _forward_op(ref_call, [new_args[0].realize(), rhs])
@register_partition_function("nn.dense")
def dense_partition_function(ref_call, new_args, ctx):
    """Partition rewrite for ``nn.dense``.

    The kernel (second argument) must not be a ``QPartitionExpr``.  When the
    data argument (first) is one, it is realized before the dense call is
    rebuilt; otherwise it is used as is.  The rebuilt call is always returned
    wrapped in a ``QPartitionExpr``.  ``ctx`` is not used.
    """
    data_is_wrapped, data = partition_expr_check(new_args[0])
    kernel_is_wrapped, kernel = partition_expr_check(new_args[1])
    assert not kernel_is_wrapped

    if data_is_wrapped:
        wrapped_data = new_args[0]
        # Show the wrapper type of the incoming data argument in the cell output.
        print(type(wrapped_data))
        data = wrapped_data.realize()

    return QPartitionExpr(_forward_op(ref_call, [data, kernel]))
# Attach the shared identity rewrite to relu / max_pool2d and the generic rewrite
# to add (nn.dense was registered by the decorator above), then partition again.
register_partition_function("nn.relu", identity_partition_function)
register_partition_function("nn.max_pool2d", identity_partition_function)
register_partition_function("add", add_partition_generic)
run_mod = relay.quantize.partition()(mod)
# The "<class ...QPartitionExpr>" lines in the output come from the print inside
# dense_partition_function.
run_mod["main"]

<class 'tvm.relay.quantize._partition.QPartitionExpr'>
<class 'tvm.relay.quantize._partition.QPartitionExpr'>


fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten

In [18]:
def dataset(num_batches=1, shape=(1, 3, 4, 4), dtype="float32"):
    """Yield random calibration batches for the model input named ``"data"``.

    Parameters
    ----------
    num_batches : int
        Number of batches to yield.  Defaults to 1.
    shape : tuple of int
        Shape of every generated array.  Defaults to ``(1, 3, 4, 4)``.
    dtype : str
        Dtype the samples are cast to.  Defaults to ``"float32"``.

    Yields
    ------
    dict
        ``{"data": ndarray}`` with values drawn from a standard normal
        distribution using NumPy's global random state.
    """
    for _ in range(num_batches):
        yield {"data": np.random.normal(size=shape).astype(dtype)}
# Quantize the module under an explicit quantization config.  The
# "<class ...QPartitionExpr>" lines this cell prints come from
# dense_partition_function, so partitioning apparently runs again inside quantize().
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        calibrate_mode="kl_divergence",  # calibration data is the dataset() generator passed below
        weight_scale="max",
        skip_conv_layers=[],
        # NOTE(review): presumably lets the nn.dense layers be quantized as well --
        # confirm against the qconfig documentation of the installed TVM version.
        skip_dense_layer=False,
    ):
        qmod = relay.quantize.quantize(mod, params, dataset())

<class 'tvm.relay.quantize._partition.QPartitionExpr'>
<class 'tvm.relay.quantize._partition.QPartitionExpr'>




In [19]:
# Apply the FuseOps pass to the quantized module and print the resulting IR.
print(relay.transform.FuseOps()(qmod))

def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %44 = fn (%p08: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] */, Primitive=1) -> Tensor[(1, 3, 4, 4), int8] {
    %41 = multiply(%p08, 77.4988f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4), float32] */;
    %42 = round(%41) /* ty=Tensor[(1, 3, 4, 4), float32] */;
    %43 = clip(%42, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4), float32] */;
    cast(%43, dtype="int8") /* ty=Tensor[(1, 3, 4, 4), int8] */
  } /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 3, 4, 4), int8] */;
  %45 = %44(%data) /* ty=Tensor[(1, 3, 4, 4), int8] */;
  %46 = fn (%p07: Tensor[(1, 3, 4, 4), int8] /* ty=Tensor[(1, 3, 4, 4), int8] */, %p14: Tensor[(16, 3, 3, 3), int8] /* ty=Tensor[(16, 3, 3, 3), int8] */, Primitive=1) -> Tensor[(1, 16, 4, 4), int8] {
    %35 = nn.conv2d(%p07, %p14, padding=[1, 1, 1, 1], channels=16, kernel_size=