量化分区#

构建 PyTorch 前端模型:

from torch import nn
import torch


class Model(nn.Module):
    """Tiny CNN used as the front-end model for the partition experiments.

    Layout: conv -> relu -> conv(+bias) -> relu -> 2x2 max-pool -> flatten
    -> linear -> relu -> linear(+bias) -> relu -> linear(+bias).

    Both convolutions are 3x3 with stride 1 and padding 1, so they keep the
    spatial size; the pool halves it.  ``dense1`` takes 64 input features,
    i.e. 16 channels * 2 * 2, which matches a 4x4 spatial input.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Attribute names and creation order are kept as-is so state_dict
        # keys and parameter initialisation order do not change.
        self.conv = nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.dense1 = nn.Linear(in_features=64, out_features=32, bias=False)
        self.dense2 = nn.Linear(in_features=32, out_features=16)
        self.dense3 = nn.Linear(in_features=16, out_features=8)

    def forward(self, x):
        """Run the network on ``x`` and return the ``dense3`` output."""
        # Convolutional feature extractor.
        feature_map = self.relu(self.conv(x))
        feature_map = self.relu(self.conv2(feature_map))
        # Down-sample, then collapse everything except the batch dimension.
        flattened = self.pool(feature_map).flatten(1)
        # Fully connected head.
        hidden = self.relu(self.dense1(flattened))
        hidden = self.relu(self.dense2(hidden))
        return self.dense3(hidden)

翻译 PyTorch 前端模型为 Relay 模型:

%cd ..
import testing
/media/pc/data/lxw/ai/tvm-book/doc/read/relay
import numpy as np
import set_env
import tvm
from tvm import relay

# Input data: one random float32 sample of shape (1, 3, 4, 4), used only to trace the model.
input_shape = (1, 3, 4, 4)
input_dtype = "float32"
data_np = np.random.rand(*input_shape).astype(input_dtype)
# Instantiate the model in eval mode and trace it with the sample input;
# no_grad() because no gradients are needed for tracing.
with torch.no_grad():
    pt_model = Model().eval().float()
    traced_model = torch.jit.trace(pt_model, torch.from_numpy(data_np)).eval()
# Translate the traced model to Relay; the single graph input is named "data".
mod, params = relay.frontend.from_pytorch(traced_model, [("data", input_shape)], 
                                          use_parser_friendly_name=True)
# Pre-process the module for quantization. In the IR printed below the weights
# appear as meta[relay.Constant] entries and the biases as separate `add` ops.
with tvm.transform.PassContext(opt_level=3):
    mod = relay.quantize.prerequisite_optimize(mod, params)
print(mod['main'])
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %3 = add(%2, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %5 = nn.max_pool2d(%4, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
  %6 = reshape(%5, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
  %7 = squeeze(%6, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
  %8 = nn.dense(%7, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
  %9 = nn.relu(%8) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
  %10 = nn.dense(%9, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
  %11 = add(%10, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
  %12 = nn.relu(%11) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
  %13 = nn.dense(%12, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
  add(%13, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */

调用分区 pass:

relay.quantize.partition()(mod)["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %7 = nn.max_pool2d(%6, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
  %8 = annotation.cast_hint(%7, dtype="int8") /* ty=Tensor[(1, 16, 2, 2), float32] */;
  %9 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 16, 2, 2), float32] */;
  %10 = reshape(%9, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
  %11 = squeeze(%10, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
  %12 = nn.dense(%11, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
  %13 = nn.relu(%12) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
  %14 = nn.dense(%13, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
  %15 = add(%14, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
  %16 = nn.relu(%15) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
  %17 = nn.dense(%16, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
  add(%17, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */

为了展示分区的效果,将 TVM 默认注册的属性函数清除:

from tvm.relay.op import op as _op
def reset_partition(op_names=None):
    """Clear the ``"FQPartitionRewrite"`` attribute of the given operators.

    Parameters
    ----------
    op_names : str or iterable of str, optional
        Operator name(s) whose partition rewrite should be removed.  When
        omitted, the same fixed list of operators as before is used, so
        ``reset_partition()`` keeps its original behaviour.  Passing names
        explicitly lets callers also reset operators that are not in the
        default list (for example ``"nn.dense"``).

    Returns
    -------
    None
    """
    if op_names is None:
        op_names = (
            "nn.conv2d",
            "nn.relu",
            "nn.max_pool2d",
            "add",
            "multiply",
            "clip",
            "nn.global_avg_pool2d",
        )
    elif isinstance(op_names, str):
        # A bare string would otherwise be iterated character by character.
        op_names = (op_names,)
    for op_name in op_names:
        _op.get(op_name).reset_attr("FQPartitionRewrite")

此时调用分区 pass,则有:

reset_partition()
relay.quantize.partition()(mod)["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %3 = add(%2, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %5 = nn.max_pool2d(%4, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
  %6 = reshape(%5, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
  %7 = squeeze(%6, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
  %8 = nn.dense(%7, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
  %9 = nn.relu(%8) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
  %10 = nn.dense(%9, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
  %11 = add(%10, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
  %12 = nn.relu(%11) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
  %13 = nn.dense(%12, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
  add(%13, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */

下面看看如何从算子层面考虑对模型进行分区。

from tvm.relay.quantize._partition import (
    partition_expr_check,
    QPartitionExpr,
    register_partition_function
)
from tvm.relay.quantize.quantize import _forward_op
register_partition_function??
Signature: register_partition_function(op_name, frewrite=None, level=10)
Docstring: <no docstring>
Source:   
def register_partition_function(op_name, frewrite=None, level=10):
    return tvm.ir.register_op_attr(op_name, "FQPartitionRewrite", frewrite, level)
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/quantize/_partition.py
Type:      function

register_partition_function() 为算子注册 "FQPartitionRewrite" 属性。

partition_expr_check??
Signature: partition_expr_check(expr)
Docstring: <no docstring>
Source:   
def partition_expr_check(expr):
    if isinstance(expr, QPartitionExpr):
        return True, expr.expr
    return False, expr
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/quantize/_partition.py
Type:      function

partition_expr_check() 用于检查 expr 是否为 QPartitionExpr:若是,返回 True 及其内部的 expr;否则返回 False 及原表达式。

_forward_op??
Signature: _forward_op(ref_call, args)
Source:   
def _forward_op(ref_call, args):
    """forward the operator of ref_call with provided arguments"""
    return _expr.Call(ref_call.op, args, ref_call.attrs, ref_call.type_args, ref_call.span)
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/quantize/quantize.py
Type:      function

_forward_op() 用于在重写回调函数中,以给定的新参数重新构建 ref_call 对应的算子调用。

nn.relu 与 nn.max_pool2d 分区#

对 nn.relu 进行分区,可以这样做:

from tvm.relay.op import op as _op
_op.get("nn.relu").reset_attr("FQPartitionRewrite")

@register_partition_function("nn.relu")
def _relu_partition(ref_call, new_args, ctx):
    """Partition rewrite for ``nn.relu``.

    Rebuilds the call of ``ref_call`` with ``new_args`` and wraps the result
    in a ``QPartitionExpr``.  ``ctx`` is not used.
    """
    forwarded_call = _forward_op(ref_call, new_args)
    return QPartitionExpr(forwarded_call)
run_mod = relay.quantize.partition()(mod)
run_mod["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %7 = annotation.cast_hint(%6, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %8 = annotation.stop_fusion(%7) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %9 = nn.max_pool2d(%8, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
  %10 = reshape(%9, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
  %11 = squeeze(%10, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
  %12 = nn.dense(%11, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
  %13 = nn.relu(%12) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
  %14 = annotation.cast_hint(%13, dtype="int8") /* ty=Tensor[(1, 32), float32] */;
  %15 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 32), float32] */;
  %16 = nn.dense(%15, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
  %17 = add(%16, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
  %18 = nn.relu(%17) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
  %19 = annotation.cast_hint(%18, dtype="int8") /* ty=Tensor[(1, 16), float32] */;
  %20 = annotation.stop_fusion(%19) /* ty=Tensor[(1, 16), float32] */;
  %21 = nn.dense(%20, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
  add(%21, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */

也可以对 nn.max_pool2d 执行同样的操作:

from tvm.relay.op import op as _op
_op.get("nn.relu").reset_attr("FQPartitionRewrite")
_op.get("nn.max_pool2d").reset_attr("FQPartitionRewrite")
@register_partition_function("nn.max_pool2d")
def _max_pool2d_partition(ref_call, new_args, ctx):
    """Partition rewrite for ``nn.max_pool2d``.

    Rebuilds the call of ``ref_call`` with ``new_args`` and wraps the result
    in a ``QPartitionExpr``.  ``ctx`` is not used.
    """
    forwarded_call = _forward_op(ref_call, new_args)
    return QPartitionExpr(forwarded_call)
run_mod = relay.quantize.partition()(mod)
run_mod["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %3 = add(%2, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %5 = nn.max_pool2d(%4, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
  %6 = annotation.cast_hint(%5, dtype="int8") /* ty=Tensor[(1, 16, 2, 2), float32] */;
  %7 = annotation.stop_fusion(%6) /* ty=Tensor[(1, 16, 2, 2), float32] */;
  %8 = reshape(%7, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
  %9 = squeeze(%8, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
  %10 = nn.dense(%9, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
  %11 = nn.relu(%10) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
  %12 = nn.dense(%11, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
  %13 = add(%12, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
  %14 = nn.relu(%13) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
  %15 = nn.dense(%14, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
  add(%15, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */

_max_pool2d_partition() 与 _relu_partition() 本质上是同样的功能实现,可以将它们合并为一个函数:

def identity_partition_function(ref_call, new_args, ctx):
    """Shared partition rewrite for single-input operators.

    Unwraps the first argument if it is a ``QPartitionExpr``, rebuilds the
    call of ``ref_call`` on that expression, and always returns the result
    wrapped in a ``QPartitionExpr``.  ``ctx`` is not used.
    """
    # Only the unwrapped expression is needed; the "was partitioned" flag
    # returned alongside it is not used here.
    _, inner_expr = partition_expr_check(new_args[0])
    rebuilt_call = _forward_op(ref_call, [inner_expr])
    return QPartitionExpr(rebuilt_call)
reset_partition()
register_partition_function("nn.relu", identity_partition_function)
register_partition_function("nn.max_pool2d", identity_partition_function)
<function __main__.identity_partition_function(ref_call, new_args, ctx)>

可以看出 nn.max_pool2d + nn.relu 实现了融合:

run_mod = relay.quantize.partition()(mod)
run_mod["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %7 = nn.max_pool2d(%6, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
  %8 = annotation.cast_hint(%7, dtype="int8") /* ty=Tensor[(1, 16, 2, 2), float32] */;
  %9 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 16, 2, 2), float32] */;
  %10 = reshape(%9, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
  %11 = squeeze(%10, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
  %12 = nn.dense(%11, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
  %13 = nn.relu(%12) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
  %14 = annotation.cast_hint(%13, dtype="int8") /* ty=Tensor[(1, 32), float32] */;
  %15 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 32), float32] */;
  %16 = nn.dense(%15, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
  %17 = add(%16, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
  %18 = nn.relu(%17) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
  %19 = annotation.cast_hint(%18, dtype="int8") /* ty=Tensor[(1, 16), float32] */;
  %20 = annotation.stop_fusion(%19) /* ty=Tensor[(1, 16), float32] */;
  %21 = nn.dense(%20, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
  add(%21, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */

dense+add+relu 分区#

from tvm.relay import analysis as _analysis

reset_partition()
_op.get("nn.dense").reset_attr("FQPartitionRewrite")
def add_partition_generic(ref_call, new_args, ctx):
    """Rewrite function for elementwise add for partition on generic devices.

    The behaviour depends on which of the two operands arrive wrapped in a
    ``QPartitionExpr`` (as reported by ``partition_expr_check``):

    * both wrapped: realize both operands and return the add wrapped in a
      ``QPartitionExpr``;
    * only rhs wrapped: realize rhs and return the plain add call;
    * only lhs wrapped: if rhs is a constant, return the add wrapped in a
      ``QPartitionExpr`` without realizing lhs; otherwise realize lhs and
      return the plain add call;
    * neither wrapped: return ``None``.

    ``ctx`` is not used.  According to the example comments below,
    ``realize()`` is where the annotations are inserted for an operand.
    """
    lhs_cond, lhs = partition_expr_check(new_args[0])
    rhs_cond, rhs = partition_expr_check(new_args[1])
    if lhs_cond and rhs_cond:
        # - introduced by ResNet, when for the first residual connection
        #     ...
        #     %0 = nn.conv2d(%data, %meta[relay.Constant])
        #     %1 = add(%0, %meta[relay.Constant])
        #     %2 = nn.relu(%1)
        #     %3 = nn.max_pool2d(%2)
        #     ...
        #     %9 = nn.conv2d(%8, %meta[relay.Constant])
        #     %10 = add(%9, %meta[relay.Constant])
        #     %11 = add(%3, %10)  <- need to insert annotations for %3, %10
        #     ...
        lhs = new_args[0].realize()
        rhs = new_args[1].realize()
        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
    if not lhs_cond and rhs_cond:
        # - introduced by residual connection in ResNet
        #     ...
        #     %13 = nn.conv2d(%12, %meta[relay.Constant])
        #     %14 = add(%13, %meta[relay.Constant])
        #     %15 = annotation.cast_hint(%14, 'int8')
        #     %16 = annotation.stop_fusion(%15)
        #     %17 = add(%5, %16)
        #     %18 = nn.relu(%17)
        #     ...
        #     %24 = nn.conv2d(%23, %meta[relay.Constant])
        #     %25 = add(%24, %meta[relay.Constant])
        #     %26 = add(%18, %25)  <- need to insert annotations for %25
        #     ...
        rhs = new_args[1].realize()
        return _forward_op(ref_call, [lhs, rhs])
    if lhs_cond and not rhs_cond:
        if _analysis.check_constant(rhs):
            # - introduced by batch_norm: add(out, bias)
            # lhs here is the unwrapped expression, so the add stays in the
            # same partition as its producer.
            return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
        # - introduced by residual connection in MobileNetV2
        #     ...
        #     %81 = add(%80, meta[relay.Constant])
        #     %82 = annotation.cast_hint(%81, 'int8')
        #     %83 = annotation.stop_fusion(%82)
        #     %84 = add(%79, %83)
        #     ...
        #     %95 = nn.conv2d(%94, %meta[relay.Constant])
        #     %96 = add(%95, %meta[relay.Constant])
        #     %97 = add(%96, %84)  <- need to insert annotations for %96
        #     ...
        lhs = new_args[0].realize()
        return _forward_op(ref_call, [lhs, rhs])
    if not lhs_cond and not rhs_cond:
        # trivial case
        return None

    # Defensive guard: the four branches above cover every combination of two
    # boolean flags, so this is not expected to be reached.
    raise ValueError
@register_partition_function("nn.dense")
def dense_partition_function(ref_call, new_args, ctx):
    """Partition rewrite for ``nn.dense``.

    The kernel operand must not be a ``QPartitionExpr``.  If the data operand
    is one, its type is printed (kept for demonstration) and it is realized
    before the dense call is rebuilt.  The rebuilt call is always returned
    wrapped in a ``QPartitionExpr``.  ``ctx`` is not used.
    """
    data_arg, kernel_arg = new_args[0], new_args[1]
    data_is_partitioned, data_expr = partition_expr_check(data_arg)
    kernel_is_partitioned, kernel_expr = partition_expr_check(kernel_arg)
    assert not kernel_is_partitioned
    if data_is_partitioned:
        print(type(data_arg))
        # Close the upstream partition before feeding it into dense.
        data_expr = data_arg.realize()
    return QPartitionExpr(_forward_op(ref_call, [data_expr, kernel_expr]))
# Register the rewrites for this experiment: nn.relu and nn.max_pool2d share the
# identity rewrite, add uses the generic add rewrite. nn.dense was already
# registered through the decorator on dense_partition_function.
register_partition_function("nn.relu", identity_partition_function)
register_partition_function("nn.max_pool2d", identity_partition_function)
register_partition_function("add", add_partition_generic)
# Re-run the partition pass. The lines printed before the IR come from the
# print() call inside dense_partition_function.
run_mod = relay.quantize.partition()(mod)
run_mod["main"]
<class 'tvm.relay.quantize._partition.QPartitionExpr'>
<class 'tvm.relay.quantize._partition.QPartitionExpr'>
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %7 = nn.max_pool2d(%6, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
  %8 = annotation.cast_hint(%7, dtype="int8") /* ty=Tensor[(1, 16, 2, 2), float32] */;
  %9 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 16, 2, 2), float32] */;
  %10 = reshape(%9, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
  %11 = squeeze(%10, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
  %12 = nn.dense(%11, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
  %13 = nn.relu(%12) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
  %14 = annotation.cast_hint(%13, dtype="int8") /* ty=Tensor[(1, 32), float32] */;
  %15 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 32), float32] */;
  %16 = nn.dense(%15, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
  %17 = add(%16, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
  %18 = nn.relu(%17) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
  %19 = annotation.cast_hint(%18, dtype="int8") /* ty=Tensor[(1, 16), float32] */;
  %20 = annotation.stop_fusion(%19) /* ty=Tensor[(1, 16), float32] */;
  %21 = nn.dense(%20, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
  %22 = add(%21, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */;
  %23 = annotation.cast_hint(%22, dtype="int8") /* ty=Tensor[(1, 8), float32] */;
  annotation.stop_fusion(%23) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */
def dataset(num_batches=1, shape=(1, 3, 4, 4), dtype="float32"):
    """Yield random calibration batches keyed by the input name ``"data"``.

    Parameters
    ----------
    num_batches : int, optional
        Number of batches to yield.  Defaults to 1.
    shape : tuple of int, optional
        Shape of each generated array.  Defaults to ``(1, 3, 4, 4)``.
    dtype : str or numpy dtype, optional
        Element type of each generated array.  Defaults to ``"float32"``.

    Yields
    ------
    dict
        ``{"data": ndarray}`` where the array is drawn from a standard normal
        distribution and cast to ``dtype``.

    Calling ``dataset()`` with no arguments behaves exactly as before: a
    single float32 batch of shape ``(1, 3, 4, 4)``.
    """
    for _ in range(num_batches):
        yield {"data": np.random.normal(size=shape).astype(dtype)}
# Quantize the module inside a PassContext and a qconfig scope; the samples
# yielded by dataset() are handed to quantize() as its third argument.
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        calibrate_mode="kl_divergence",
        weight_scale="max",
        # Empty list: no conv layer is listed for skipping
        # (the fused IR printed below shows both nn.conv2d ops on int8 inputs).
        skip_conv_layers=[],
        # Dense layers are not skipped (the fused IR printed below shows
        # nn.dense on int8 inputs).
        skip_dense_layer=False,
    ):
        qmod = relay.quantize.quantize(mod, params, dataset())
<class 'tvm.relay.quantize._partition.QPartitionExpr'>
<class 'tvm.relay.quantize._partition.QPartitionExpr'>
WARNING:autotvm:One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
print(relay.transform.FuseOps()(qmod))
def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %44 = fn (%p08: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] */, Primitive=1) -> Tensor[(1, 3, 4, 4), int8] {
    %41 = multiply(%p08, 97.7937f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4), float32] */;
    %42 = round(%41) /* ty=Tensor[(1, 3, 4, 4), float32] */;
    %43 = clip(%42, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4), float32] */;
    cast(%43, dtype="int8") /* ty=Tensor[(1, 3, 4, 4), int8] */
  } /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 3, 4, 4), int8] */;
  %45 = %44(%data) /* ty=Tensor[(1, 3, 4, 4), int8] */;
  %46 = fn (%p07: Tensor[(1, 3, 4, 4), int8] /* ty=Tensor[(1, 3, 4, 4), int8] */, %p14: Tensor[(16, 3, 3, 3), int8] /* ty=Tensor[(16, 3, 3, 3), int8] */, Primitive=1) -> Tensor[(1, 16, 4, 4), int8] {
    %35 = nn.conv2d(%p07, %p14, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 16, 4, 4), int32] */;
    %36 = nn.relu(%35) /* ty=Tensor[(1, 16, 4, 4), int32] */;
    %37 = cast(%36, dtype="int64") /* ty=Tensor[(1, 16, 4, 4), int64] */;
    %38 = fixed_point_multiply(%37, multiplier=1144341760, shift=-8) /* ty=Tensor[(1, 16, 4, 4), int64] */;
    %39 = clip(%38, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 4, 4), int64] */;
    %40 = cast(%39, dtype="int32") /* ty=Tensor[(1, 16, 4, 4), int32] */;
    cast(%40, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), int8] */
  } /* ty=fn (Tensor[(1, 3, 4, 4), int8], Tensor[(16, 3, 3, 3), int8]) -> Tensor[(1, 16, 4, 4), int8] */;
  %47 = %46(%45, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), int8] */) /* ty=Tensor[(1, 16, 4, 4), int8] */;
  %48 = fn (%p06: Tensor[(1, 16, 4, 4), int8] /* ty=Tensor[(1, 16, 4, 4), int8] */, %p13: Tensor[(16, 16, 3, 3), int8] /* ty=Tensor[(16, 16, 3, 3), int8] */, %p22: Tensor[(16, 1, 1), int32] /* ty=Tensor[(16, 1, 1), int32] */, Primitive=1) -> Tensor[(1, 16, 4, 4), int8] {
    %28 = nn.conv2d(%p06, %p13, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 16, 4, 4), int32] */;
    %29 = add(%28, %p22) /* ty=Tensor[(1, 16, 4, 4), int32] */;
    %30 = nn.relu(%29) /* ty=Tensor[(1, 16, 4, 4), int32] */;
    %31 = cast(%30, dtype="int64") /* ty=Tensor[(1, 16, 4, 4), int64] */;
    %32 = fixed_point_multiply(%31, multiplier=1562188928, shift=-9) /* ty=Tensor[(1, 16, 4, 4), int64] */;
    %33 = clip(%32, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 4, 4), int64] */;
    %34 = cast(%33, dtype="int32") /* ty=Tensor[(1, 16, 4, 4), int32] */;
    cast(%34, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), int8] */
  } /* ty=fn (Tensor[(1, 16, 4, 4), int8], Tensor[(16, 16, 3, 3), int8], Tensor[(16, 1, 1), int32]) -> Tensor[(1, 16, 4, 4), int8] */;
  %49 = %48(%47, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), int8] */, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), int32] */) /* ty=Tensor[(1, 16, 4, 4), int8] */;
  %50 = fn (%p05: Tensor[(1, 16, 4, 4), int8] /* ty=Tensor[(1, 16, 4, 4), int8] */, Primitive=1) -> Tensor[(1, 16, 2, 2), int8] {
    %27 = nn.max_pool2d(%p05, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), int8] */;
    cast(%27, dtype="int8") /* ty=Tensor[(1, 16, 2, 2), int8] */
  } /* ty=fn (Tensor[(1, 16, 4, 4), int8]) -> Tensor[(1, 16, 2, 2), int8] */;
  %51 = %50(%49) /* ty=Tensor[(1, 16, 2, 2), int8] */;
  %52 = fn (%p04: Tensor[(1, 16, 2, 2), int8] /* ty=Tensor[(1, 16, 2, 2), int8] */, Primitive=1) -> Tensor[(1, 64), int8] {
    %20 = reshape(%p04, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), int8] */;
    %21 = cast(%20, dtype="float32") /* ty=Tensor[(1, 64, 1, 1), float32] */;
    %22 = multiply(%21, 0.00336449f /* ty=float32 */) /* ty=Tensor[(1, 64, 1, 1), float32] */;
    %23 = squeeze(%22, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
    %24 = multiply(%23, 298.605f /* ty=float32 */) /* ty=Tensor[(1, 64), float32] */;
    %25 = round(%24) /* ty=Tensor[(1, 64), float32] */;
    %26 = clip(%25, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64), float32] */;
    cast(%26, dtype="int8") /* ty=Tensor[(1, 64), int8] */
  } /* ty=fn (Tensor[(1, 16, 2, 2), int8]) -> Tensor[(1, 64), int8] */;
  %53 = %52(%51) /* ty=Tensor[(1, 64), int8] */;
  %54 = fn (%p03: Tensor[(1, 64), int8] /* ty=Tensor[(1, 64), int8] */, %p12: Tensor[(32, 64), int8] /* ty=Tensor[(32, 64), int8] */, Primitive=1) -> Tensor[(1, 32), int8] {
    %14 = nn.dense(%p03, %p12, units=None, out_dtype="int32") /* ty=Tensor[(1, 32), int32] */;
    %15 = nn.relu(%14) /* ty=Tensor[(1, 32), int32] */;
    %16 = cast(%15, dtype="int64") /* ty=Tensor[(1, 32), int64] */;
    %17 = fixed_point_multiply(%16, multiplier=1383273600, shift=-8) /* ty=Tensor[(1, 32), int64] */;
    %18 = clip(%17, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 32), int64] */;
    %19 = cast(%18, dtype="int32") /* ty=Tensor[(1, 32), int32] */;
    cast(%19, dtype="int8") /* ty=Tensor[(1, 32), int8] */
  } /* ty=fn (Tensor[(1, 64), int8], Tensor[(32, 64), int8]) -> Tensor[(1, 32), int8] */;
  %55 = %54(%53, meta[relay.Constant][3] /* ty=Tensor[(32, 64), int8] */) /* ty=Tensor[(1, 32), int8] */;
  %56 = fn (%p02: Tensor[(1, 32), int8] /* ty=Tensor[(1, 32), int8] */, %p11: Tensor[(16, 32), int8] /* ty=Tensor[(16, 32), int8] */, %p21: Tensor[(16), int32] /* ty=Tensor[(16), int32] */, Primitive=1) -> Tensor[(1, 16), int8] {
    %7 = nn.dense(%p02, %p11, units=None, out_dtype="int32") /* ty=Tensor[(1, 16), int32] */;
    %8 = add(%7, %p21) /* ty=Tensor[(1, 16), int32] */;
    %9 = nn.relu(%8) /* ty=Tensor[(1, 16), int32] */;
    %10 = cast(%9, dtype="int64") /* ty=Tensor[(1, 16), int64] */;
    %11 = fixed_point_multiply(%10, multiplier=2026565248, shift=-9) /* ty=Tensor[(1, 16), int64] */;
    %12 = clip(%11, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16), int64] */;
    %13 = cast(%12, dtype="int32") /* ty=Tensor[(1, 16), int32] */;
    cast(%13, dtype="int8") /* ty=Tensor[(1, 16), int8] */
  } /* ty=fn (Tensor[(1, 32), int8], Tensor[(16, 32), int8], Tensor[(16), int32]) -> Tensor[(1, 16), int8] */;
  %57 = %56(%55, meta[relay.Constant][4] /* ty=Tensor[(16, 32), int8] */, meta[relay.Constant][5] /* ty=Tensor[(16), int32] */) /* ty=Tensor[(1, 16), int8] */;
  %58 = fn (%p01: Tensor[(1, 16), int8] /* ty=Tensor[(1, 16), int8] */, %p1: Tensor[(8, 16), int8] /* ty=Tensor[(8, 16), int8] */, %p2: Tensor[(8), int32] /* ty=Tensor[(8), int32] */, Primitive=1) -> Tensor[(1, 8), int8] {
    %1 = nn.dense(%p01, %p1, units=None, out_dtype="int32") /* ty=Tensor[(1, 8), int32] */;
    %2 = add(%1, %p2) /* ty=Tensor[(1, 8), int32] */;
    %3 = cast(%2, dtype="int64") /* ty=Tensor[(1, 8), int64] */;
    %4 = fixed_point_multiply(%3, multiplier=1088137088, shift=-9) /* ty=Tensor[(1, 8), int64] */;
    %5 = clip(%4, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8), int64] */;
    %6 = cast(%5, dtype="int32") /* ty=Tensor[(1, 8), int32] */;
    cast(%6, dtype="int8") /* ty=Tensor[(1, 8), int8] */
  } /* ty=fn (Tensor[(1, 16), int8], Tensor[(8, 16), int8], Tensor[(8), int32]) -> Tensor[(1, 8), int8] */;
  %59 = %58(%57, meta[relay.Constant][6] /* ty=Tensor[(8, 16), int8] */, meta[relay.Constant][7] /* ty=Tensor[(8), int32] */) /* ty=Tensor[(1, 8), int8] */;
  %60 = fn (%p0: Tensor[(1, 8), int8] /* ty=Tensor[(1, 8), int8] */, Primitive=1) -> Tensor[(1, 8), float32] {
    %0 = cast(%p0, dtype="float32") /* ty=Tensor[(1, 8), float32] */;
    multiply(%0, 0.00189963f /* ty=float32 */) /* ty=Tensor[(1, 8), float32] */
  } /* ty=fn (Tensor[(1, 8), int8]) -> Tensor[(1, 8), float32] */;
  %60(%59) /* ty=Tensor[(1, 8), float32] */
}