Quantization Partitioning#
Build the PyTorch frontend model:
from torch import nn
import torch
class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(16, 16, 3, 1, 1, bias=True)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.dense1 = nn.Linear(64, 32, bias=False)
        self.dense2 = nn.Linear(32, 16)
        self.dense3 = nn.Linear(16, 8)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.flatten(1)
        x = self.dense1(x)
        x = self.relu(x)
        x = self.dense2(x)
        x = self.relu(x)
        x = self.dense3(x)
        return x
Translate the PyTorch frontend model into a Relay module:
%cd ..
import testing
/media/pc/data/lxw/ai/tvm-book/doc/read/relay
import numpy as np
import set_env
import tvm
from tvm import relay
# Input data
input_shape = (1, 3, 4, 4)
input_dtype = "float32"
data_np = np.random.rand(*input_shape).astype(input_dtype)
with torch.no_grad():
    pt_model = Model().eval().float()
    traced_model = torch.jit.trace(pt_model, torch.from_numpy(data_np)).eval()

mod, params = relay.frontend.from_pytorch(traced_model, [("data", input_shape)],
                                          use_parser_friendly_name=True)
with tvm.transform.PassContext(opt_level=3):
    mod = relay.quantize.prerequisite_optimize(mod, params)
print(mod['main'])
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
%1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
%2 = nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
%3 = add(%2, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
%5 = nn.max_pool2d(%4, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
%6 = reshape(%5, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
%7 = squeeze(%6, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
%8 = nn.dense(%7, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
%9 = nn.relu(%8) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
%10 = nn.dense(%9, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
%11 = add(%10, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
%12 = nn.relu(%11) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
%13 = nn.dense(%12, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
add(%13, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */
Invoke the partition pass:
relay.quantize.partition()(mod)["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
%1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
%2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
%5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
%7 = nn.max_pool2d(%6, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
%8 = annotation.cast_hint(%7, dtype="int8") /* ty=Tensor[(1, 16, 2, 2), float32] */;
%9 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 16, 2, 2), float32] */;
%10 = reshape(%9, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
%11 = squeeze(%10, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
%12 = nn.dense(%11, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
%13 = nn.relu(%12) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
%14 = nn.dense(%13, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
%15 = add(%14, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
%16 = nn.relu(%15) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
%17 = nn.dense(%16, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
add(%17, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */
To demonstrate the effect of partitioning, first clear the attribute functions that TVM registers by default:
from tvm.relay.op import op as _op
def reset_partition():
    """Clear the "FQPartitionRewrite" attribute that TVM registers for these ops by default."""
    op_names = ["nn.conv2d", "nn.relu", "nn.max_pool2d", "add", "multiply", "clip", "nn.global_avg_pool2d"]
    for op_name in op_names:
        _op.get(op_name).reset_attr("FQPartitionRewrite")
Calling the partition pass now yields:
reset_partition()
relay.quantize.partition()(mod)["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
%1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
%2 = nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
%3 = add(%2, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
%5 = nn.max_pool2d(%4, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
%6 = reshape(%5, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
%7 = squeeze(%6, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
%8 = nn.dense(%7, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
%9 = nn.relu(%8) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
%10 = nn.dense(%9, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
%11 = add(%10, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
%12 = nn.relu(%11) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
%13 = nn.dense(%12, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
add(%13, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */
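Because no operator now carries the "FQPartitionRewrite" attribute, the pass leaves the module untouched. This can be checked directly; the following is a minimal sketch that assumes tvm.ir.Op.get_attr returns None for attributes that are not set:

for name in ["nn.conv2d", "nn.relu", "nn.max_pool2d", "add"]:
    # None means no partition-rewrite callback is currently registered for this op
    print(name, _op.get(name).get_attr("FQPartitionRewrite"))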
Next, let's look at how to partition the model at the operator level.
from tvm.relay.quantize._partition import (
partition_expr_check,
QPartitionExpr,
register_partition_function
)
from tvm.relay.quantize.quantize import _forward_op
register_partition_function??
Signature: register_partition_function(op_name, frewrite=None, level=10)
Docstring: <no docstring>
Source:
def register_partition_function(op_name, frewrite=None, level=10):
    return tvm.ir.register_op_attr(op_name, "FQPartitionRewrite", frewrite, level)
File: /media/pc/data/lxw/ai/tvm/python/tvm/relay/quantize/_partition.py
Type: function
register_partition_function() registers the "FQPartitionRewrite" attribute for an operator; it can be used either as a decorator or called directly with a rewrite callback, and both forms appear later in this section.
partition_expr_check??
Signature: partition_expr_check(expr)
Docstring: <no docstring>
Source:
def partition_expr_check(expr):
    if isinstance(expr, QPartitionExpr):
        return True, expr.expr
    return False, expr
File: /media/pc/data/lxw/ai/tvm/python/tvm/relay/quantize/_partition.py
Type: function
partition_expr_check() checks whether expr is a QPartitionExpr; if it is, the wrapped expression is returned.
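A minimal sketch of its behavior (x is a throwaway variable introduced only for this illustration):

x = relay.var("x", shape=(1, 16), dtype="float32")
print(partition_expr_check(x))                   # (False, x): a plain expression
print(partition_expr_check(QPartitionExpr(x)))   # (True, x): the wrapper is peeled off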
_forward_op??
Signature: _forward_op(ref_call, args)
Source:
def _forward_op(ref_call, args):
    """forward the operator of ref_call with provided arguments"""
    return _expr.Call(ref_call.op, args, ref_call.attrs, ref_call.type_args, ref_call.span)
File: /media/pc/data/lxw/ai/tvm/python/tvm/relay/quantize/quantize.py
Type: function
_forward_op() is used inside the rewrite callbacks: it rebuilds a call to the same operator as ref_call, with the same attributes, but with the new arguments.
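A minimal sketch of what it does (x and y are throwaway variables introduced only for this illustration):

x = relay.var("x", shape=(1, 16), dtype="float32")
y = relay.var("y", shape=(1, 16), dtype="float32")
call = relay.nn.relu(x)
# Rebuild the same nn.relu call, but with y as its argument
print(_forward_op(call, [y]))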
Partitioning nn.relu and nn.max_pool2d#
To partition nn.relu:
from tvm.relay.op import op as _op
_op.get("nn.relu").reset_attr("FQPartitionRewrite")
@register_partition_function("nn.relu")
def _relu_partition(ref_call, new_args, ctx):
return QPartitionExpr(_forward_op(ref_call, new_args))
run_mod = relay.quantize.partition()(mod)
run_mod["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
%1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
%2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
%5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
%7 = annotation.cast_hint(%6, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%8 = annotation.stop_fusion(%7) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%9 = nn.max_pool2d(%8, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
%10 = reshape(%9, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
%11 = squeeze(%10, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
%12 = nn.dense(%11, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
%13 = nn.relu(%12) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
%14 = annotation.cast_hint(%13, dtype="int8") /* ty=Tensor[(1, 32), float32] */;
%15 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 32), float32] */;
%16 = nn.dense(%15, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
%17 = add(%16, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
%18 = nn.relu(%17) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
%19 = annotation.cast_hint(%18, dtype="int8") /* ty=Tensor[(1, 16), float32] */;
%20 = annotation.stop_fusion(%19) /* ty=Tensor[(1, 16), float32] */;
%21 = nn.dense(%20, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
add(%21, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */
The same can be done for nn.max_pool2d:
from tvm.relay.op import op as _op
_op.get("nn.relu").reset_attr("FQPartitionRewrite")
_op.get("nn.max_pool2d").reset_attr("FQPartitionRewrite")
@register_partition_function("nn.max_pool2d")
def _max_pool2d_partition(ref_call, new_args, ctx):
return QPartitionExpr(_forward_op(ref_call, new_args))
run_mod = relay.quantize.partition()(mod)
run_mod["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
%1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
%2 = nn.conv2d(%1, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
%3 = add(%2, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
%5 = nn.max_pool2d(%4, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
%6 = annotation.cast_hint(%5, dtype="int8") /* ty=Tensor[(1, 16, 2, 2), float32] */;
%7 = annotation.stop_fusion(%6) /* ty=Tensor[(1, 16, 2, 2), float32] */;
%8 = reshape(%7, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
%9 = squeeze(%8, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
%10 = nn.dense(%9, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
%11 = nn.relu(%10) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
%12 = nn.dense(%11, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
%13 = add(%12, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
%14 = nn.relu(%13) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
%15 = nn.dense(%14, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
add(%15, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */
_max_pool2d_partition() and _relu_partition() are essentially the same implementation, so they can be merged into a single function:
def identity_partition_function(ref_call, new_args, ctx):
    # Unwrap a possible QPartitionExpr argument, then wrap the forwarded call again,
    # so consecutive "identity" ops end up in the same partition.
    cond, expr = partition_expr_check(new_args[0])
    return QPartitionExpr(_forward_op(ref_call, [expr]))
reset_partition()
register_partition_function("nn.relu", identity_partition_function)
register_partition_function("nn.max_pool2d", identity_partition_function)
<function __main__.identity_partition_function(ref_call, new_args, ctx)>
It can be seen that nn.max_pool2d and nn.relu now end up in the same partition (the cast_hint/stop_fusion annotations are inserted only after the max_pool2d):
run_mod = relay.quantize.partition()(mod)
run_mod["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
%1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
%2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
%5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
%7 = nn.max_pool2d(%6, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
%8 = annotation.cast_hint(%7, dtype="int8") /* ty=Tensor[(1, 16, 2, 2), float32] */;
%9 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 16, 2, 2), float32] */;
%10 = reshape(%9, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
%11 = squeeze(%10, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
%12 = nn.dense(%11, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
%13 = nn.relu(%12) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
%14 = annotation.cast_hint(%13, dtype="int8") /* ty=Tensor[(1, 32), float32] */;
%15 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 32), float32] */;
%16 = nn.dense(%15, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
%17 = add(%16, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
%18 = nn.relu(%17) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
%19 = annotation.cast_hint(%18, dtype="int8") /* ty=Tensor[(1, 16), float32] */;
%20 = annotation.stop_fusion(%19) /* ty=Tensor[(1, 16), float32] */;
%21 = nn.dense(%20, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
add(%21, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */
Partitioning dense+add+relu#
This time nn.dense and add also get partition functions, so that each dense + bias add + relu chain forms a single partition. add_partition_generic below is the generic rewrite callback for element-wise add from tvm.relay.quantize._partition: depending on whether each operand is already a QPartitionExpr, it realizes the wrapped operands and decides whether the resulting add is itself wrapped in a new QPartitionExpr.
from tvm.relay import analysis as _analysis
reset_partition()
_op.get("nn.dense").reset_attr("FQPartitionRewrite")
def add_partition_generic(ref_call, new_args, ctx):
    """Rewrite function for ewise add for partition for generic devices"""
    lhs_cond, lhs = partition_expr_check(new_args[0])
    rhs_cond, rhs = partition_expr_check(new_args[1])
    if lhs_cond and rhs_cond:
        # - introduced by ResNet, when for the first residual connection
        #     ...
        #     %0 = nn.conv2d(%data, %meta[relay.Constant])
        #     %1 = add(%0, %meta[relay.Constant])
        #     %2 = nn.relu(%1)
        #     %3 = nn.max_pool2d(%2)
        #     ...
        #     %9 = nn.conv2d(%8, %meta[relay.Constant])
        #     %10 = add(%9, %meta[relay.Constant])
        #     %11 = add(%3, %10) <- need to insert annotations for %3, %10
        #     ...
        lhs = new_args[0].realize()
        rhs = new_args[1].realize()
        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
    if not lhs_cond and rhs_cond:
        # - introduced by residual connection in ResNet
        #     ...
        #     %13 = nn.conv2d(%12, %meta[relay.Constant])
        #     %14 = add(%13, %meta[relay.Constant])
        #     %15 = annotation.cast_hint(%15, 'int8')
        #     %16 = annotation.stop_fusion(%16)
        #     %17 = add(%5, %16)
        #     %18 = nn.relu(%17)
        #     ...
        #     %24 = nn.conv2d(%23, %meta[relay.Constant])
        #     %25 = add(%24, %meta[relay.Constant])
        #     %26 = add(%18, %25) <- need to insert annotations for %25
        #     ...
        rhs = new_args[1].realize()
        return _forward_op(ref_call, [lhs, rhs])
    if lhs_cond and not rhs_cond:
        if _analysis.check_constant(rhs):
            # - introduced by batch_norm: add(out, bias)
            return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
        # - introduced by residual connection in MobileNetV2
        #     ...
        #     %81 = add(%80, meta[relay.Constant])
        #     %82 = annotation.cast_hint(%81, 'int8')
        #     %83 = annotation.stop_fusion(%82)
        #     %84 = add(%79, %83)
        #     ...
        #     %96 = nn.conv2d(%94, %meta[relay.Constant])
        #     %96 = add(%95, %meta[relay.Constant])
        #     %97 = add(%96, %84) <- need to insert annotations for %96
        #     ...
        lhs = new_args[0].realize()
        return _forward_op(ref_call, [lhs, rhs])
    if not lhs_cond and not rhs_cond:
        # trivial case
        return None

    raise ValueError
@register_partition_function("nn.dense")
def dense_partition_function(ref_call, new_args, ctx):
"""Rewrite function for dense for partition"""
data_cond, data = partition_expr_check(new_args[0])
kernel_cond, kernel = partition_expr_check(new_args[1])
assert not kernel_cond
if data_cond:
print(type(new_args[0]))
data = new_args[0].realize()
ret = _forward_op(ref_call, [data, kernel])
return QPartitionExpr(ret)
register_partition_function("nn.relu", identity_partition_function)
register_partition_function("nn.max_pool2d", identity_partition_function)
register_partition_function("add", add_partition_generic)
run_mod = relay.quantize.partition()(mod)
run_mod["main"]
<class 'tvm.relay.quantize._partition.QPartitionExpr'>
<class 'tvm.relay.quantize._partition.QPartitionExpr'>
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
%1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
%2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%3 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%4 = nn.conv2d(%3, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
%5 = add(%4, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%6 = nn.relu(%5) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
%7 = nn.max_pool2d(%6, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), float32] span=aten__max_pool2d_0:0:0 */;
%8 = annotation.cast_hint(%7, dtype="int8") /* ty=Tensor[(1, 16, 2, 2), float32] */;
%9 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 16, 2, 2), float32] */;
%10 = reshape(%9, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), float32] span=aten__flatten_0:0:0 */;
%11 = squeeze(%10, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
%12 = nn.dense(%11, meta[relay.Constant][3] /* ty=Tensor[(32, 64), float32] */, units=None) /* ty=Tensor[(1, 32), float32] span=aten__linear_0:0:0 */;
%13 = nn.relu(%12) /* ty=Tensor[(1, 32), float32] span=aten__relu_2:0:0 */;
%14 = annotation.cast_hint(%13, dtype="int8") /* ty=Tensor[(1, 32), float32] */;
%15 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 32), float32] */;
%16 = nn.dense(%15, meta[relay.Constant][4] /* ty=Tensor[(16, 32), float32] */, units=None) /* ty=Tensor[(1, 16), float32] span=aten__linear_1:0:0 */;
%17 = add(%16, meta[relay.Constant][5] /* ty=Tensor[(16), float32] */) /* ty=Tensor[(1, 16), float32] */;
%18 = nn.relu(%17) /* ty=Tensor[(1, 16), float32] span=aten__relu_3:0:0 */;
%19 = annotation.cast_hint(%18, dtype="int8") /* ty=Tensor[(1, 16), float32] */;
%20 = annotation.stop_fusion(%19) /* ty=Tensor[(1, 16), float32] */;
%21 = nn.dense(%20, meta[relay.Constant][6] /* ty=Tensor[(8, 16), float32] */, units=None) /* ty=Tensor[(1, 8), float32] span=aten__linear_2:0:0 */;
%22 = add(%21, meta[relay.Constant][7] /* ty=Tensor[(8), float32] */) /* ty=Tensor[(1, 8), float32] */;
%23 = annotation.cast_hint(%22, dtype="int8") /* ty=Tensor[(1, 8), float32] */;
annotation.stop_fusion(%23) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 8), float32] */
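Finally, run the end-to-end quantization flow with a single random calibration batch and inspect the fused result: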
def dataset():
    for _ in range(1):
        data = np.random.normal(size=(1, 3, 4, 4)).astype("float32")
        yield {"data": data}

with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        calibrate_mode="kl_divergence",
        weight_scale="max",
        skip_conv_layers=[],
        skip_dense_layer=False,
    ):
        qmod = relay.quantize.quantize(mod, params, dataset())
<class 'tvm.relay.quantize._partition.QPartitionExpr'>
<class 'tvm.relay.quantize._partition.QPartitionExpr'>
WARNING:autotvm:One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
print(relay.transform.FuseOps()(qmod))
def @main(%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
%44 = fn (%p08: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] */, Primitive=1) -> Tensor[(1, 3, 4, 4), int8] {
%41 = multiply(%p08, 97.7937f /* ty=float32 */) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%42 = round(%41) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%43 = clip(%42, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 4, 4), float32] */;
cast(%43, dtype="int8") /* ty=Tensor[(1, 3, 4, 4), int8] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 3, 4, 4), int8] */;
%45 = %44(%data) /* ty=Tensor[(1, 3, 4, 4), int8] */;
%46 = fn (%p07: Tensor[(1, 3, 4, 4), int8] /* ty=Tensor[(1, 3, 4, 4), int8] */, %p14: Tensor[(16, 3, 3, 3), int8] /* ty=Tensor[(16, 3, 3, 3), int8] */, Primitive=1) -> Tensor[(1, 16, 4, 4), int8] {
%35 = nn.conv2d(%p07, %p14, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 16, 4, 4), int32] */;
%36 = nn.relu(%35) /* ty=Tensor[(1, 16, 4, 4), int32] */;
%37 = cast(%36, dtype="int64") /* ty=Tensor[(1, 16, 4, 4), int64] */;
%38 = fixed_point_multiply(%37, multiplier=1144341760, shift=-8) /* ty=Tensor[(1, 16, 4, 4), int64] */;
%39 = clip(%38, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 4, 4), int64] */;
%40 = cast(%39, dtype="int32") /* ty=Tensor[(1, 16, 4, 4), int32] */;
cast(%40, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), int8] */
} /* ty=fn (Tensor[(1, 3, 4, 4), int8], Tensor[(16, 3, 3, 3), int8]) -> Tensor[(1, 16, 4, 4), int8] */;
%47 = %46(%45, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), int8] */) /* ty=Tensor[(1, 16, 4, 4), int8] */;
%48 = fn (%p06: Tensor[(1, 16, 4, 4), int8] /* ty=Tensor[(1, 16, 4, 4), int8] */, %p13: Tensor[(16, 16, 3, 3), int8] /* ty=Tensor[(16, 16, 3, 3), int8] */, %p22: Tensor[(16, 1, 1), int32] /* ty=Tensor[(16, 1, 1), int32] */, Primitive=1) -> Tensor[(1, 16, 4, 4), int8] {
%28 = nn.conv2d(%p06, %p13, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3], out_dtype="int32") /* ty=Tensor[(1, 16, 4, 4), int32] */;
%29 = add(%28, %p22) /* ty=Tensor[(1, 16, 4, 4), int32] */;
%30 = nn.relu(%29) /* ty=Tensor[(1, 16, 4, 4), int32] */;
%31 = cast(%30, dtype="int64") /* ty=Tensor[(1, 16, 4, 4), int64] */;
%32 = fixed_point_multiply(%31, multiplier=1562188928, shift=-9) /* ty=Tensor[(1, 16, 4, 4), int64] */;
%33 = clip(%32, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 4, 4), int64] */;
%34 = cast(%33, dtype="int32") /* ty=Tensor[(1, 16, 4, 4), int32] */;
cast(%34, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), int8] */
} /* ty=fn (Tensor[(1, 16, 4, 4), int8], Tensor[(16, 16, 3, 3), int8], Tensor[(16, 1, 1), int32]) -> Tensor[(1, 16, 4, 4), int8] */;
%49 = %48(%47, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), int8] */, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), int32] */) /* ty=Tensor[(1, 16, 4, 4), int8] */;
%50 = fn (%p05: Tensor[(1, 16, 4, 4), int8] /* ty=Tensor[(1, 16, 4, 4), int8] */, Primitive=1) -> Tensor[(1, 16, 2, 2), int8] {
%27 = nn.max_pool2d(%p05, pool_size=[2, 2], strides=[2, 2], padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 16, 2, 2), int8] */;
cast(%27, dtype="int8") /* ty=Tensor[(1, 16, 2, 2), int8] */
} /* ty=fn (Tensor[(1, 16, 4, 4), int8]) -> Tensor[(1, 16, 2, 2), int8] */;
%51 = %50(%49) /* ty=Tensor[(1, 16, 2, 2), int8] */;
%52 = fn (%p04: Tensor[(1, 16, 2, 2), int8] /* ty=Tensor[(1, 16, 2, 2), int8] */, Primitive=1) -> Tensor[(1, 64), int8] {
%20 = reshape(%p04, newshape=[0, -1, 1, 1]) /* ty=Tensor[(1, 64, 1, 1), int8] */;
%21 = cast(%20, dtype="float32") /* ty=Tensor[(1, 64, 1, 1), float32] */;
%22 = multiply(%21, 0.00336449f /* ty=float32 */) /* ty=Tensor[(1, 64, 1, 1), float32] */;
%23 = squeeze(%22, axis=[2, 3]) /* ty=Tensor[(1, 64), float32] span=aten__flatten_0:0:0 */;
%24 = multiply(%23, 298.605f /* ty=float32 */) /* ty=Tensor[(1, 64), float32] */;
%25 = round(%24) /* ty=Tensor[(1, 64), float32] */;
%26 = clip(%25, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64), float32] */;
cast(%26, dtype="int8") /* ty=Tensor[(1, 64), int8] */
} /* ty=fn (Tensor[(1, 16, 2, 2), int8]) -> Tensor[(1, 64), int8] */;
%53 = %52(%51) /* ty=Tensor[(1, 64), int8] */;
%54 = fn (%p03: Tensor[(1, 64), int8] /* ty=Tensor[(1, 64), int8] */, %p12: Tensor[(32, 64), int8] /* ty=Tensor[(32, 64), int8] */, Primitive=1) -> Tensor[(1, 32), int8] {
%14 = nn.dense(%p03, %p12, units=None, out_dtype="int32") /* ty=Tensor[(1, 32), int32] */;
%15 = nn.relu(%14) /* ty=Tensor[(1, 32), int32] */;
%16 = cast(%15, dtype="int64") /* ty=Tensor[(1, 32), int64] */;
%17 = fixed_point_multiply(%16, multiplier=1383273600, shift=-8) /* ty=Tensor[(1, 32), int64] */;
%18 = clip(%17, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 32), int64] */;
%19 = cast(%18, dtype="int32") /* ty=Tensor[(1, 32), int32] */;
cast(%19, dtype="int8") /* ty=Tensor[(1, 32), int8] */
} /* ty=fn (Tensor[(1, 64), int8], Tensor[(32, 64), int8]) -> Tensor[(1, 32), int8] */;
%55 = %54(%53, meta[relay.Constant][3] /* ty=Tensor[(32, 64), int8] */) /* ty=Tensor[(1, 32), int8] */;
%56 = fn (%p02: Tensor[(1, 32), int8] /* ty=Tensor[(1, 32), int8] */, %p11: Tensor[(16, 32), int8] /* ty=Tensor[(16, 32), int8] */, %p21: Tensor[(16), int32] /* ty=Tensor[(16), int32] */, Primitive=1) -> Tensor[(1, 16), int8] {
%7 = nn.dense(%p02, %p11, units=None, out_dtype="int32") /* ty=Tensor[(1, 16), int32] */;
%8 = add(%7, %p21) /* ty=Tensor[(1, 16), int32] */;
%9 = nn.relu(%8) /* ty=Tensor[(1, 16), int32] */;
%10 = cast(%9, dtype="int64") /* ty=Tensor[(1, 16), int64] */;
%11 = fixed_point_multiply(%10, multiplier=2026565248, shift=-9) /* ty=Tensor[(1, 16), int64] */;
%12 = clip(%11, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16), int64] */;
%13 = cast(%12, dtype="int32") /* ty=Tensor[(1, 16), int32] */;
cast(%13, dtype="int8") /* ty=Tensor[(1, 16), int8] */
} /* ty=fn (Tensor[(1, 32), int8], Tensor[(16, 32), int8], Tensor[(16), int32]) -> Tensor[(1, 16), int8] */;
%57 = %56(%55, meta[relay.Constant][4] /* ty=Tensor[(16, 32), int8] */, meta[relay.Constant][5] /* ty=Tensor[(16), int32] */) /* ty=Tensor[(1, 16), int8] */;
%58 = fn (%p01: Tensor[(1, 16), int8] /* ty=Tensor[(1, 16), int8] */, %p1: Tensor[(8, 16), int8] /* ty=Tensor[(8, 16), int8] */, %p2: Tensor[(8), int32] /* ty=Tensor[(8), int32] */, Primitive=1) -> Tensor[(1, 8), int8] {
%1 = nn.dense(%p01, %p1, units=None, out_dtype="int32") /* ty=Tensor[(1, 8), int32] */;
%2 = add(%1, %p2) /* ty=Tensor[(1, 8), int32] */;
%3 = cast(%2, dtype="int64") /* ty=Tensor[(1, 8), int64] */;
%4 = fixed_point_multiply(%3, multiplier=1088137088, shift=-9) /* ty=Tensor[(1, 8), int64] */;
%5 = clip(%4, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8), int64] */;
%6 = cast(%5, dtype="int32") /* ty=Tensor[(1, 8), int32] */;
cast(%6, dtype="int8") /* ty=Tensor[(1, 8), int8] */
} /* ty=fn (Tensor[(1, 16), int8], Tensor[(8, 16), int8], Tensor[(8), int32]) -> Tensor[(1, 8), int8] */;
%59 = %58(%57, meta[relay.Constant][6] /* ty=Tensor[(8, 16), int8] */, meta[relay.Constant][7] /* ty=Tensor[(8), int32] */) /* ty=Tensor[(1, 8), int8] */;
%60 = fn (%p0: Tensor[(1, 8), int8] /* ty=Tensor[(1, 8), int8] */, Primitive=1) -> Tensor[(1, 8), float32] {
%0 = cast(%p0, dtype="float32") /* ty=Tensor[(1, 8), float32] */;
multiply(%0, 0.00189963f /* ty=float32 */) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 8), int8]) -> Tensor[(1, 8), float32] */;
%60(%59) /* ty=Tensor[(1, 8), float32] */
}