Quantization Annotation

The annotation step of Relay's quantization flow is implemented on the C++ side (in TVM's src/relay/quantize/annotate.cc):
using namespace relay::transform;

class QAnnotateExpr;
class QAnnotateExprNode : public TempExprNode {
 public:
  Expr expr;
  QAnnotateKind kind;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("expr", &expr);
    v->Visit("kind", &kind);
  }

  Expr Realize() const final;

  static constexpr const char* _type_key = "relay.QAnnotateExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(QAnnotateExprNode, TempExprNode);
};

class QAnnotateExpr : public TempExpr {
 public:
  /*!
   * \brief The constructor
   * \param expr The original relay expression.
   * \param kind The annotation kind.
   */
  TVM_DLL QAnnotateExpr(Expr expr, QAnnotateKind kind);

  TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode);
};

Expr QAnnotateExprNode::Realize() const { return expr; }

QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) {
  auto rnode = make_object<QAnnotateExprNode>();
  rnode->expr = std::move(expr);
  rnode->kind = kind;
  data_ = std::move(rnode);
}

TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr").set_body_typed([](Expr expr, int kind) {
  return QAnnotateExpr(expr, static_cast<QAnnotateKind>(kind));
});

This code defines the QAnnotateExpr class, which derives from TempExpr and represents an expression carrying a quantization annotation. QAnnotateExprNode is the underlying node class: it stores the original expression and the annotation kind, VisitAttrs exposes both fields to the attribute visitor, and Realize simply returns the wrapped expression, so the temporary annotation can be erased once rewriting is finished.

The QAnnotateExpr constructor takes an expression and an annotation kind and stores them in a freshly created QAnnotateExprNode. The TVM_DEFINE_OBJECT_REF_METHODS macro generates the standard object-reference methods.

Finally, the TVM_REGISTER_GLOBAL macro registers the global function relay._quantize.make_annotate_expr, which takes an expression and an integer annotation kind and returns the corresponding QAnnotateExpr.
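Because the constructor is exposed through the packed-function registry, it can also be reached from Python. A minimal sketch (hypothetical usage, not part of the public quantization API):

import tvm
from tvm import relay

# Fetch the packed function registered above and wrap a Relay expression
# with an input annotation (kind 1 corresponds to kQInput).
make_annotate_expr = tvm.get_global_func("relay._quantize.make_annotate_expr")
x = relay.var("x", shape=(1, 3, 4, 4), dtype="float32")
annotated = make_annotate_expr(x, 1)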

Pass QuantizeAnnotate() {
  // TODO(tvm-teams): since partition has added cast_hint in different
  // branches, try to remove this in the future.
  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
    if (e->IsInstance<TempExprNode>()) {
      const auto* n = e.as<QAnnotateExprNode>();
      ICHECK(n);
      const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
      Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
      return static_cast<Expr>(QAnnotateExpr(ret, kQInput));
    }
    return e;
  };

  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
        auto new_params = func->params;
        for (const auto& x : FreeVars(func)) {
          new_params.push_back(x);
        }
        return WithFields(func, new_params);
      };
  return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}

TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate").set_body_typed(QuantizeAnnotate);

TVM_REGISTER_NODE_TYPE(QAnnotateExprNode);

This code defines the QuantizeAnnotate pass, which annotates the input function for quantization. It first defines the fmulti_ref lambda, which takes an expression: if the expression is a TempExprNode instance (i.e., a QAnnotateExpr), it attaches a simulated quantize op to the underlying expression via the registered relay.quantize.attach_simulated_quantize function and re-wraps the result as a QAnnotateExpr of kind kQInput; otherwise it returns the expression unchanged. fmulti_ref is the callback ForwardRewrite applies to nodes referenced from multiple places.

Next, pass_func takes a function, an IRModule, and a PassContext. It forward-rewrites the function according to the FQAnnotateRewrite op attribute, then appends the function's free variables to the parameter list; these are the dom_scale/clip_min/clip_max variables introduced by the simulated-quantize ops, as the annotated IR below shows. Finally it rebuilds the function with the new parameter list via WithFields.

CreateFunctionPass then wraps pass_func as a function-level pass at opt level 1, and the pass is registered as the global function relay._quantize.QuantizeAnnotate. The QAnnotateExprNode node type is registered as well.
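For reference, the kind values stamped onto the simulated-quantize ops match the QAnnotateKind enum on the Python side. A quick check (assuming QAnnotateKind is exported from tvm.relay.quantize, as in mainline TVM):

from tvm.relay.quantize import QAnnotateKind

# These values match the kind= fields in the annotated IR below:
# kind=1 for inputs/activations entering a region, kind=2 for weights.
print(QAnnotateKind.IDENTITY, QAnnotateKind.INPUT,
      QAnnotateKind.WEIGHT, QAnnotateKind.ACTIVATION)  # 0 1 2 3

To see the partition and annotate passes end to end, build a small PyTorch model with two convolutions whose ReLU outputs are summed, import it into Relay, and run prerequisite_optimize: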

from torch import nn
import torch


class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(16, 16, 3, 1, 1, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x1 = self.relu(x)
        x = self.conv2(x)
        x2 = self.relu(x)
        x = x1 + x2
        return x

import numpy as np
import tvm
from tvm import relay

# 输入数据
input_shape = (1, 3, 4, 4)
input_dtype = "float32"
data_np = np.random.rand(*input_shape).astype(input_dtype)
with torch.no_grad():
    pt_model = Model().eval().float()
    traced_model = torch.jit.trace(pt_model, torch.from_numpy(data_np)).eval()
mod, params = relay.frontend.from_pytorch(traced_model, [("data", input_shape)], 
                                          use_parser_friendly_name=True)
with tvm.transform.PassContext(opt_level=3):
    mod = relay.quantize.prerequisite_optimize(mod, params)
print(mod['main'])
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %4 = nn.relu(%2) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  add(%3, %4) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 16, 4, 4), float32] */
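Applying relay.quantize.partition splits the graph into quantizable regions, inserting annotation.cast_hint and annotation.stop_fusion ops at the boundaries where tensors are expected to become int8: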
relay.quantize.partition()(mod)["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %3 = annotation.cast_hint(%0, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = annotation.stop_fusion(%3) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %5 = nn.conv2d(%4, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %6 = add(%5, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %7 = nn.relu(%6) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %8 = annotation.cast_hint(%7, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %9 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %10 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %11 = add(%9, %10) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */;
  %12 = annotation.cast_hint(%11, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  annotation.stop_fusion(%12) /* ty=Tensor[(1, 16, 4, 4), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 16, 4, 4), float32] */
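Chaining partition with annotate then attaches relay.op.annotation.simulated_quantize ops, with kind=1 for region inputs/activations and kind=2 for weights. Note how the free dom_scale*/clip_min*/clip_max* variables created by the rewrite have been appended to the function's parameter list, exactly as pass_func in QuantizeAnnotate arranges: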
passes = tvm.transform.Sequential([
    relay.quantize.partition(),
    relay.quantize.annotate()
])
passes(mod)["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */, %dom_scale: float32 /* ty=float32 */, %clip_min: float32 /* ty=float32 */, %clip_max: float32 /* ty=float32 */, %dom_scale1: float32 /* ty=float32 */, %clip_min1: float32 /* ty=float32 */, %clip_max1: float32 /* ty=float32 */, %dom_scale2: float32 /* ty=float32 */, %clip_min2: float32 /* ty=float32 */, %clip_max2: float32 /* ty=float32 */, %dom_scale3: float32 /* ty=float32 */, %clip_min3: float32 /* ty=float32 */, %clip_max3: float32 /* ty=float32 */, %dom_scale4: float32 /* ty=float32 */, %clip_min4: float32 /* ty=float32 */, %clip_max4: float32 /* ty=float32 */, %dom_scale5: float32 /* ty=float32 */, %clip_min5: float32 /* ty=float32 */, %clip_max5: float32 /* ty=float32 */, %dom_scale6: float32 /* ty=float32 */, %clip_min6: float32 /* ty=float32 */, %clip_max6: float32 /* ty=float32 */) -> Tensor[(1, 16, 4, 4), float32] {
  %0 = relay.op.annotation.simulated_quantize(%data, %dom_scale, %clip_min, %clip_max, kind=1) /* ty=Tensor[(1, 3, 4, 4), float32] */;
  %1 = relay.op.annotation.simulated_quantize(meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, %dom_scale1, %clip_min1, %clip_max1, kind=2) /* ty=Tensor[(16, 3, 3, 3), float32] */;
  %2 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
  %3 = relay.op.annotation.simulated_quantize(%2, %dom_scale2, %clip_min2, %clip_max2, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
  %5 = annotation.cast_hint(%4, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %6 = annotation.cast_hint(%3, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %7 = annotation.stop_fusion(%6) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %8 = relay.op.annotation.simulated_quantize(meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, %dom_scale3, %clip_min3, %clip_max3, kind=2) /* ty=Tensor[(16, 16, 3, 3), float32] */;
  %9 = nn.conv2d(%7, %8, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
  %10 = relay.op.annotation.simulated_quantize(meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */, %dom_scale4, %clip_min4, %clip_max4, kind=2) /* ty=Tensor[(16, 1, 1), float32] */;
  %11 = add(%9, %10) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %12 = nn.relu(%11) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
  %13 = relay.op.annotation.simulated_quantize(%12, %dom_scale5, %clip_min5, %clip_max5, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %14 = annotation.cast_hint(%13, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %15 = annotation.stop_fusion(%5) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %16 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %17 = add(%15, %16) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */;
  %18 = relay.op.annotation.simulated_quantize(%17, %dom_scale6, %clip_min6, %clip_max6, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;
  %19 = annotation.cast_hint(%18, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
  annotation.stop_fusion(%19) /* ty=Tensor[(1, 16, 4, 4), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32], float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32) -> Tensor[(1, 16, 4, 4), float32] */
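A related annotation-driven flow is FakeQuantizationToInteger, which rewrites dequantize→op→quantize patterns into integer-domain computation. Ops such as nn.softmax take part only when they are explicitly listed in optional_qnn_ops: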
shape = [5, 10]
scale = 0.1
x_ = relay.var("x", shape=shape, dtype="int8")
x = relay.qnn.op.dequantize(x_, relay.const(scale), relay.const(0))
op = relay.op.nn.softmax(x, axis=1)
op = relay.qnn.op.quantize(
    op, relay.const(1.0 / 256.0), relay.const(-128), out_dtype="int8"
)

x_np = np.random.randint(-128, 127, size=shape, dtype="int8")
x_np = np.sort(x_np)
args = [x_np]

mod = tvm.IRModule.from_expr(op)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger(
    hard_fail=True, optional_qnn_ops=["nn.softmax"]
)(mod)
mod.show()
mod_int.show()
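If the rewrite succeeds, mod_int computes the softmax in the integer domain (the surrounding dequantize/quantize pair is folded into an integer-only softmax approximation), while mod retains the simulated float32 round trip; with hard_fail=True, an unconvertible pattern raises an error instead of being silently left in place.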