Rewriting Relay quantization operators (Python)#

from testing import viz_expr  # visualize Relay expressions
import tvm
@tvm.instrument.pass_instrument
class PrintIR:
    """Print the pass name and the IR only before each pass runs."""
    def run_before_pass(self, mod, info):
        print(f"Running pass: {info}->{mod['main']}")
from tvm.relay.testing import run_infer_type
from tvm.relay.dataflow_pattern import (
    wildcard, is_op, is_constant,
    # FunctionPattern,
    DFPatternCallback,
    rewrite
)
import tvm
from tvm.ir.attrs import DictAttrs
from tvm import relay, te, topi
from tvm.relay.op import op as _op
from tvm.target import generic_func

@generic_func
def schedule_special_op(attrs, outs, target):
    """Fallback TE schedule for custom operators."""
    with target:
        outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
        output = outs[0]
        sch = te.create_schedule(output.op)
        return sch
from d2py.utils.file import mkdir
root_dir = ".temp"
mkdir(f"{root_dir}/logs")
import numpy as np
import torch
from torch.nn import functional as F
from torch import nn


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 1, 1, 0, bias=False, groups=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)
        b, c, h, w = x.shape
        x = x.view((b, h, w, c))
        x = F.softmax(x, dim=3)
        x = x.view((b, h * w * c))
        return x

model = M()
model.eval()

shape = 1, 3, 8, 8
input_name = "data"
dtype = "float32"
data_np = np.random.rand(*shape).astype(dtype)
with torch.no_grad():
    traced_model = torch.jit.trace(model, torch.from_numpy(data_np)).eval()
mod, params = relay.frontend.from_pytorch(traced_model, [("data", shape)], 
                                          use_parser_friendly_name=True)
# with tvm.transform.PassContext(opt_level=3, instruments=[PrintIR()]):
#     mod = relay.quantize.prerequisite_optimize(mod, params)
# print(mod['main'])
with open(f".temp/origin_mod.json", "w") as fp:
    fp.write(tvm.ir.save_json(mod))
with open(f".temp/origin_mod.json") as fp:
    _mod = tvm.ir.load_json(fp.read())
    tvm.ir.structural_equal(mod, _mod)
np.savez(".temp/origin_params.npz", **{k:v.numpy() for k, v in params.items()})
_params = np.load(".temp/origin_params.npz")
for n1, n2 in zip(params, _params):
    assert n1 == n2
    np.testing.assert_equal(params[n1].numpy(), _params[n2])

Frontend import:

import numpy as np
import tvm
from tvm import relay

class Reshape4dSoftmaxReshape2dRewrite(DFPatternCallback):
    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.reshape4d = is_op("reshape")(self.x)  # reshape NCHW to NHWC, where H = W = 1
        self.softmax = is_op("nn.softmax")(self.reshape4d)
        self.softmax_axis = self.softmax.has_attr({"axis": 3})
        self.reshape2d = is_op("reshape")(self.softmax_axis)
        self.pattern = self.reshape2d

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        relay.transform.InferTypeLocal(x)
        # Apply softmax on the channel axis of the NCHW input, then transpose to NHWC
        # and flatten; this is equivalent to the matched reshape -> softmax -> reshape.
        x = relay.nn.softmax(x, axis=1)
        relay.transform.InferTypeLocal(x)
        x = relay.transpose(x, (0, 2, 3, 1))
        relay.transform.InferTypeLocal(x)
        x = relay.reshape(x, (1, -1))
        relay.transform.InferTypeLocal(x)
        return x

mod = relay.transform.InferType()(mod)
print(mod['main'])
mod["main"] = rewrite(Reshape4dSoftmaxReshape2dRewrite(), mod["main"])
print(mod['main'])
fn (%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* ty=Tensor[(1, 8, 8, 8), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.adaptive_avg_pool2d(%0, output_size=[1, 1]) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten__adaptive_avg_pool2d_0:0:0 */;
  %2 = reshape(%1, newshape=[1, 1, 1, 8]) /* ty=Tensor[(1, 1, 1, 8), float32] span=aten__view_0:0:0 */;
  %3 = nn.softmax(%2, axis=3) /* ty=Tensor[(1, 1, 1, 8), float32] span=aten__softmax_0:0:0 */;
  reshape(%3, newshape=[1, 8]) /* ty=Tensor[(1, 8), float32] span=aten__view_1:0:0 */
} /* ty=fn (Tensor[(1, 3, 8, 8), float32]) -> Tensor[(1, 8), float32] */

fn (%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), float32] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* ty=Tensor[(1, 8, 8, 8), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.adaptive_avg_pool2d(%0, output_size=[1, 1]) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten__adaptive_avg_pool2d_0:0:0 */;
  %2 = nn.softmax(%1, axis=1) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %3 = transpose(%2, axes=[0, 2, 3, 1]) /* ty=Tensor[(1, 1, 1, 8), float32] */;
  reshape(%3, newshape=[1, -1]) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 8, 8), float32]) -> Tensor[(1, 8), float32] */
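
As a quick sanity check of the rewrite's equivalence (a minimal sketch using the already-imported torch; the tensor t below is made up for illustration): softmax over the channel axis of an NCHW tensor with H = W = 1, followed by a transpose to NHWC and a flatten, matches the original reshape → softmax(axis=3) → reshape path.

t = torch.randn(1, 8, 1, 1)
ref = F.softmax(t.view(1, 1, 1, 8), dim=3).reshape(1, 8)     # original: reshape -> softmax(axis=3) -> reshape
new = F.softmax(t, dim=1).permute(0, 2, 3, 1).reshape(1, 8)  # rewritten: softmax(axis=1) -> transpose -> reshape
torch.testing.assert_close(ref, new)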

The softmax_transpose_reshape2d structure is not successfully quantized#

from dataclasses import dataclass

@dataclass
class Dataset:
    """Calibration dataset: yields dicts mapping the input name to random samples."""
    input_name: str
    shape: tuple

    def __iter__(self):
        for _ in range(2):
            yield {self.input_name: np.random.normal(0, 1, size=self.shape).astype("float32")}

dataset = Dataset(input_name, shape)
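
Each iteration over the dataset yields a dict mapping the input name to a random calibration batch; a minimal sketch to peek at one batch:

batch = next(iter(dataset))
print(batch.keys(), batch[input_name].shape, batch[input_name].dtype)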

with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        skip_conv_layers=[],
        calibrate_mode="kl_divergence", 
        weight_scale="max",
        # round_for_shift=True,
        # rounding="TONEAREST", # "UPWARD" or "TONEAREST"
        skip_dense_layer=False,
    ):
        qmod = relay.quantize.quantize(mod, params, dataset)
qmod.show()
WARNING:autotvm:One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = multiply(%data, 46.0739f /* ty=float32 */) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 8, 8), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 8, 8, 8), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1126179840, shift=-6) /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 8, 8, 8), int32] */;
  %9 = cast(%8, dtype="int8") /* ty=Tensor[(1, 8, 8, 8), int8] */;
  %10 = annotation.stop_fusion(%9) /* ty=Tensor[(1, 8, 8, 8), int8] */;
  %11 = cast(%10, dtype="float32") /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %12 = multiply(%11, 0.0116321f /* ty=float32 */) /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %13 = nn.adaptive_avg_pool2d(%12, output_size=[1, 1]) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten__adaptive_avg_pool2d_0:0:0 */;
  %14 = nn.softmax(%13, axis=1) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %15 = transpose(%14, axes=[0, 2, 3, 1]) /* ty=Tensor[(1, 1, 1, 8), float32] */;
  reshape(%15, newshape=[1, -1]) /* ty=Tensor[(1, 8), float32] */
}

Adding a partition rule for nn.softmax#

The original nn.global_avg_pool2d rule treats every operator after it as unquantized:

quant_passes = tvm.transform.Sequential([relay.quantize.partition(), relay.quantize.annotate()])
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        skip_conv_layers=[],
        calibrate_mode="kl_divergence", 
        weight_scale="max",
        # round_for_shift=True,
        # rounding="TONEAREST", # "UPWARD" or "TONEAREST"
        skip_dense_layer=False,
    ): 
        # run_mod = relay.quantize.prerequisite_optimize(mod, params)
        annotate_mod = quant_passes(mod)
print(annotate_mod["main"])
fn (%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=aten___convolution_0_data:0:0 */, %dom_scale: float32 /* ty=float32 */, %clip_min: float32 /* ty=float32 */, %clip_max: float32 /* ty=float32 */, %dom_scale1: float32 /* ty=float32 */, %clip_min1: float32 /* ty=float32 */, %clip_max1: float32 /* ty=float32 */, %dom_scale2: float32 /* ty=float32 */, %clip_min2: float32 /* ty=float32 */, %clip_max2: float32 /* ty=float32 */) -> Tensor[(1, 8), float32] {
  %0 = relay.op.annotation.simulated_quantize(%data, %dom_scale, %clip_min, %clip_max, kind=1) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %1 = relay.op.annotation.simulated_quantize(meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), float32] */, %dom_scale1, %clip_min1, %clip_max1, kind=2) /* ty=Tensor[(8, 3, 1, 1), float32] */;
  %2 = nn.conv2d(%0, %1, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* ty=Tensor[(1, 8, 8, 8), float32] span=aten___convolution_0:0:0 */;
  %3 = relay.op.annotation.simulated_quantize(%2, %dom_scale2, %clip_min2, %clip_max2, kind=1) /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %4 = annotation.cast_hint(%3, dtype="int8") /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %5 = annotation.stop_fusion(%4) /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %6 = nn.adaptive_avg_pool2d(%5, output_size=[1, 1]) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten__adaptive_avg_pool2d_0:0:0 */;
  %7 = nn.softmax(%6, axis=1) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %8 = transpose(%7, axes=[0, 2, 3, 1]) /* ty=Tensor[(1, 1, 1, 8), float32] */;
  reshape(%8, newshape=[1, -1]) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 8, 8), float32], float32, float32, float32, float32, float32, float32, float32, float32, float32) -> Tensor[(1, 8), float32] */
from tvm.relay.quantize._partition import (
    register_partition_function, 
    partition_expr_check,
    QPartitionExpr
)
from tvm.relay.quantize.quantize import _forward_op
def avg_pool2d_partition_function(ref_call, new_args, ctx):
    """Rewrite function for adaptive avg_pool2d for partition."""
    cond, expr = partition_expr_check(new_args[0])
    if cond:
        expr = new_args[0].realize()
        return QPartitionExpr(_forward_op(ref_call, [expr]))
    return None

# register_partition_function("nn.avg_pool2d", avg_pool2d_partition_function)
_op.get("nn.adaptive_avg_pool2d").reset_attr("FQPartitionRewrite")
register_partition_function("nn.adaptive_avg_pool2d", avg_pool2d_partition_function)
<function __main__.avg_pool2d_partition_function(ref_call, new_args, ctx)>
@register_partition_function("nn.softmax")
def softmax_partition_function(ref_call, new_args, ctx):
    """Rewrite function for softmax for partition"""
    data_cond, data = partition_expr_check(new_args[0])

    if data_cond:
        data = new_args[0].realize()
    ret = _forward_op(ref_call, [data])
    return QPartitionExpr(ret)
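
A quick way to confirm that the partition rules are now attached to both ops (a hedged sketch; it only inspects the op attributes via tvm.ir.Op.get_attr):

print(_op.get("nn.adaptive_avg_pool2d").get_attr("FQPartitionRewrite"))
print(_op.get("nn.softmax").get_attr("FQPartitionRewrite"))
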
quant_passes = tvm.transform.Sequential([relay.quantize.partition(), relay.quantize.annotate()])
with tvm.transform.PassContext(
    opt_level=3, 
    required_pass=["QuantizeAnnotate", "QuantizeCalibrate", "QuantizeRealize"]):
    with relay.quantize.qconfig(
        skip_conv_layers=[],
        calibrate_mode="kl_divergence", 
        weight_scale="max",
        # round_for_shift=True,
        # rounding="TONEAREST", # "UPWARD" or "TONEAREST"
        skip_dense_layer=False,
    ): 
        annotate_mod = quant_passes(mod)
print(annotate_mod['main'])
fn (%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=aten___convolution_0_data:0:0 */, %dom_scale: float32 /* ty=float32 */, %clip_min: float32 /* ty=float32 */, %clip_max: float32 /* ty=float32 */, %dom_scale1: float32 /* ty=float32 */, %clip_min1: float32 /* ty=float32 */, %clip_max1: float32 /* ty=float32 */, %dom_scale2: float32 /* ty=float32 */, %clip_min2: float32 /* ty=float32 */, %clip_max2: float32 /* ty=float32 */) -> Tensor[(1, 8), float32] {
  %0 = relay.op.annotation.simulated_quantize(%data, %dom_scale, %clip_min, %clip_max, kind=1) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %1 = relay.op.annotation.simulated_quantize(meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), float32] */, %dom_scale1, %clip_min1, %clip_max1, kind=2) /* ty=Tensor[(8, 3, 1, 1), float32] */;
  %2 = nn.conv2d(%0, %1, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* ty=Tensor[(1, 8, 8, 8), float32] span=aten___convolution_0:0:0 */;
  %3 = relay.op.annotation.simulated_quantize(%2, %dom_scale2, %clip_min2, %clip_max2, kind=1) /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %4 = annotation.cast_hint(%3, dtype="int8") /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %5 = annotation.stop_fusion(%4) /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %6 = nn.adaptive_avg_pool2d(%5, output_size=[1, 1]) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten__adaptive_avg_pool2d_0:0:0 */;
  %7 = annotation.stop_fusion(%6) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %8 = nn.softmax(%7, axis=1) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %9 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %10 = transpose(%9, axes=[0, 2, 3, 1]) /* ty=Tensor[(1, 1, 1, 8), float32] */;
  reshape(%10, newshape=[1, -1]) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 8, 8), float32], float32, float32, float32, float32, float32, float32, float32, float32, float32) -> Tensor[(1, 8), float32] */

Adding an annotation rule for nn.softmax#

from tvm.relay.quantize._annotate import (
    attach_simulated_quantize, register_annotate_function, 
    QAnnotateKind, _get_expr_kind, QAnnotateExpr
)
from tvm.relay.quantize.quantize import quantize_context

def avg_pool2d_rewrite(ref_call, new_args, ctx):
    """Rewrite function for max pool2d"""
    if quantize_context().check_to_skip(ref_call):
        return None

    expr, x_kind = _get_expr_kind(new_args[0])
    if x_kind is None:
        return None
    expr = _forward_op(ref_call, [expr])
    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

# register_annotate_function("nn.avg_pool2d", avg_pool2d_rewrite)
_op.get("nn.adaptive_avg_pool2d").reset_attr("FQAnnotateRewrite")
register_annotate_function("nn.adaptive_avg_pool2d", avg_pool2d_rewrite)

@register_annotate_function("nn.softmax")
def softmax_rewrite(ref_call, new_args, ctx):
    """Rewrite function for softmax. Lhs of nn.softmax will be quantized to
    input field.
    Output would be in activation field"""
    if quantize_context().check_to_skip(ref_call):
        return None

    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])

    if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
        lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)

    expr = _forward_op(ref_call, [lhs_expr])

    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
# from tvm.relay.quantize._calibrate import calibrate

# dataset = Dataset(input_name, shape)

with tvm.transform.PassContext(
    opt_level=3, 
    required_pass=["QuantizeAnnotate"]):
   
    quant_passes = tvm.transform.Sequential([
        relay.quantize.partition(), relay.quantize.annotate(), 
    ])
    with relay.quantize.qconfig(
        skip_conv_layers=[],
        calibrate_mode="kl_divergence", 
        weight_scale="max",
        # round_for_shift=True,
        # rounding="TONEAREST", # "UPWARD" or "TONEAREST"
        skip_dense_layer=False,
    ): 
        annotate_mod = quant_passes(mod)
print(annotate_mod['main'])
fn (%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=aten___convolution_0_data:0:0 */, %dom_scale: float32 /* ty=float32 */, %clip_min: float32 /* ty=float32 */, %clip_max: float32 /* ty=float32 */, %dom_scale1: float32 /* ty=float32 */, %clip_min1: float32 /* ty=float32 */, %clip_max1: float32 /* ty=float32 */, %dom_scale2: float32 /* ty=float32 */, %clip_min2: float32 /* ty=float32 */, %clip_max2: float32 /* ty=float32 */, %dom_scale3: float32 /* ty=float32 */, %clip_min3: float32 /* ty=float32 */, %clip_max3: float32 /* ty=float32 */, %dom_scale4: float32 /* ty=float32 */, %clip_min4: float32 /* ty=float32 */, %clip_max4: float32 /* ty=float32 */) -> Tensor[(1, 8), float32] {
  %0 = relay.op.annotation.simulated_quantize(%data, %dom_scale, %clip_min, %clip_max, kind=1) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %1 = relay.op.annotation.simulated_quantize(meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), float32] */, %dom_scale1, %clip_min1, %clip_max1, kind=2) /* ty=Tensor[(8, 3, 1, 1), float32] */;
  %2 = nn.conv2d(%0, %1, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1]) /* ty=Tensor[(1, 8, 8, 8), float32] span=aten___convolution_0:0:0 */;
  %3 = relay.op.annotation.simulated_quantize(%2, %dom_scale2, %clip_min2, %clip_max2, kind=1) /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %4 = annotation.cast_hint(%3, dtype="int8") /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %5 = annotation.stop_fusion(%4) /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %6 = nn.adaptive_avg_pool2d(%5, output_size=[1, 1]) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten__adaptive_avg_pool2d_0:0:0 */;
  %7 = relay.op.annotation.simulated_quantize(%6, %dom_scale3, %clip_min3, %clip_max3, kind=1) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %8 = annotation.cast_hint(%7, dtype="int8") /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %9 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %10 = nn.softmax(%9, axis=1) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %11 = relay.op.annotation.simulated_quantize(%10, %dom_scale4, %clip_min4, %clip_max4, kind=1) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %12 = annotation.cast_hint(%11, dtype="int8") /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %13 = annotation.stop_fusion(%12) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %14 = transpose(%13, axes=[0, 2, 3, 1]) /* ty=Tensor[(1, 1, 1, 8), float32] */;
  reshape(%14, newshape=[1, -1]) /* ty=Tensor[(1, 8), float32] */
} /* ty=fn (Tensor[(1, 3, 8, 8), float32], float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32) -> Tensor[(1, 8), float32] */

Verifying the quantized softmax results#

# from tvm.relay.quantize._calibrate import calibrate

dataset = Dataset(input_name, shape)

with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        skip_conv_layers=[],
        calibrate_mode="kl_divergence", 
        weight_scale="max",
        # round_for_shift=True,
        # rounding="TONEAREST", # "UPWARD" or "TONEAREST"
        skip_dense_layer=False,
    ): 
        qmod = relay.quantize.quantize(mod, params=params, dataset=dataset)
qmod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = multiply(%data, 50.2388f /* ty=float32 */) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 8, 8), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 8, 8, 8), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1945843968, shift=-7) /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 8, 8, 8), int32] */;
  %9 = cast(%8, dtype="int8") /* ty=Tensor[(1, 8, 8, 8), int8] */;
  %10 = annotation.stop_fusion(%9) /* ty=Tensor[(1, 8, 8, 8), int8] */;
  %11 = cast(%10, dtype="float32") /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %12 = multiply(%11, 0.0123482f /* ty=float32 */) /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %13 = nn.adaptive_avg_pool2d(%12, output_size=[1, 1]) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten__adaptive_avg_pool2d_0:0:0 */;
  %14 = multiply(%13, 1441.75f /* ty=float32 */) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %15 = round(%14) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %16 = clip(%15, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %17 = cast(%16, dtype="int8") /* ty=Tensor[(1, 8, 1, 1), int8] */;
  %18 = annotation.stop_fusion(%17) /* ty=Tensor[(1, 8, 1, 1), int8] */;
  %19 = cast(%18, dtype="float32") /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %20 = multiply(%19, 0.0006936f /* ty=float32 */) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %21 = nn.softmax(%20, axis=1) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %22 = multiply(%21, 968.021f /* ty=float32 */) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %23 = round(%22) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %24 = clip(%23, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %25 = cast(%24, dtype="int8") /* ty=Tensor[(1, 8, 1, 1), int8] */;
  %26 = annotation.stop_fusion(%25) /* ty=Tensor[(1, 8, 1, 1), int8] */;
  %27 = transpose(%26, axes=[0, 2, 3, 1]) /* ty=Tensor[(1, 1, 1, 8), int8] */;
  %28 = reshape(%27, newshape=[1, -1]) /* ty=Tensor[(1, 8), int8] */;
  %29 = cast(%28, dtype="float32") /* ty=Tensor[(1, 8), float32] */;
  multiply(%29, 0.00103304f /* ty=float32 */) /* ty=Tensor[(1, 8), float32] */
}
data = np.random.normal(0, 1, size=shape).astype("float32")
torch_out = model(torch.from_numpy(data)).detach().numpy()


target = 'llvm'
dev = tvm.device(target, 0)

# Original (float) model
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target, params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
float_output = module.get_output(0).numpy()

# Quantized model
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(qmod, target, params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
quant_output = module.get_output(0).numpy()
np.concatenate([float_output, quant_output, torch_out]).T
array([[0.13273914, 0.13119549, 0.11308717],
       [0.1083948 , 0.11569996, 0.10494332],
       [0.1343954 , 0.13119549, 0.146622  ],
       [0.13157682, 0.12912942, 0.12230157],
       [0.13567673, 0.13119549, 0.11914812],
       [0.11301415, 0.11569996, 0.12496211],
       [0.12319913, 0.12086514, 0.12797777],
       [0.12100384, 0.11879906, 0.14095795]], dtype=float32)
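
To put a number on the difference (a minimal sketch; the exact values depend on the random weights, inputs, and calibration data):

print("max |float - quant|:", np.abs(float_output - quant_output).max())
print("max |float - torch|:", np.abs(float_output - torch_out).max())
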
class QSoftmaxRewrite(DFPatternCallback):
    """变换 dequantize+softmax+quantize` 为 `qnn.softmax`"""
    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.cast = is_op("cast")(self.x).has_attr({"dtype": "float32"})
        self.scale = is_constant()
        self.multiply = is_op("multiply")(self.cast, self.scale)
        self.softmax = is_op("nn.softmax")(self.multiply)
        self.output_scale = is_constant()
        self.multiply_out = is_op("multiply")(self.softmax, self.output_scale)
        self.round = is_op("round")(self.multiply_out)
        self.clip = is_op("clip")(self.round)
        self.cast_out = is_op("cast")(self.clip).has_attr({"dtype": "int8"})
        self.stop_fusion = is_op("annotation.stop_fusion")(self.cast_out)
        self.pattern = self.stop_fusion

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        softmax = node_map[self.softmax][0]
        scale = node_map[self.scale][0]
        output_scale = node_map[self.output_scale][0]
        output_scale = output_scale.data
        zero_point = relay.const(0, dtype="int32")
        # The matched multiply constant is the quantize multiplier (1/scale),
        # so invert it to recover the output scale.
        output_scale = relay.const(1.0 / output_scale.numpy(), dtype=output_scale.dtype)
        # output_scale = relay.const(1.0/128, dtype=output_scale.dtype)
        output_zero_point = relay.const(0, dtype="int32")
        out = relay.qnn.softmax(x, scale, zero_point, output_scale, output_zero_point, axis=softmax.attrs.axis)
        return out
qmod["main"] = rewrite(QSoftmaxRewrite(), qmod["main"])
qmod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = multiply(%data, 50.2388f /* ty=float32 */) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 8, 8), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 8, 8, 8), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1945843968, shift=-7) /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8, 8, 8), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 8, 8, 8), int32] */;
  %9 = cast(%8, dtype="int8") /* ty=Tensor[(1, 8, 8, 8), int8] */;
  %10 = annotation.stop_fusion(%9) /* ty=Tensor[(1, 8, 8, 8), int8] */;
  %11 = cast(%10, dtype="float32") /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %12 = multiply(%11, 0.0123482f /* ty=float32 */) /* ty=Tensor[(1, 8, 8, 8), float32] */;
  %13 = nn.adaptive_avg_pool2d(%12, output_size=[1, 1]) /* ty=Tensor[(1, 8, 1, 1), float32] span=aten__adaptive_avg_pool2d_0:0:0 */;
  %14 = multiply(%13, 1441.75f /* ty=float32 */) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %15 = round(%14) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %16 = clip(%15, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %17 = cast(%16, dtype="int8") /* ty=Tensor[(1, 8, 1, 1), int8] */;
  %18 = annotation.stop_fusion(%17) /* ty=Tensor[(1, 8, 1, 1), int8] */;
  %19 = qnn.softmax(%18, 0.0006936f /* ty=float32 */, 0, 0.00103304f, 0, axis=1);
  %20 = transpose(%19, axes=[0, 2, 3, 1]) /* ty=Tensor[(1, 1, 1, 8), int8] */;
  %21 = reshape(%20, newshape=[1, -1]) /* ty=Tensor[(1, 8), int8] */;
  %22 = cast(%21, dtype="float32") /* ty=Tensor[(1, 8), float32] */;
  multiply(%22, 0.00103304f /* ty=float32 */) /* ty=Tensor[(1, 8), float32] */
}
# QNN-quantized model
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(qmod, target, params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
qnn_output = module.get_output(0).numpy()
np.concatenate([float_output, quant_output, qnn_output]).T
array([[0.13273914, 0.13119549, 0.12499727],
       [0.1083948 , 0.11569996, 0.11673299],
       [0.1343954 , 0.13119549, 0.12499727],
       [0.13157682, 0.12912942, 0.12499727],
       [0.13567673, 0.13119549, 0.12499727],
       [0.11301415, 0.11569996, 0.11673299],
       [0.12319913, 0.12086514, 0.11673299],
       [0.12100384, 0.11879906, 0.11673299]], dtype=float32)
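
A final hedged check of how much error qnn.softmax introduces relative to the float and simulated-quantization results (values vary from run to run):

print("max |float - quant|:", np.abs(float_output - quant_output).max())
print("max |float - qnn|  :", np.abs(float_output - qnn_output).max())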