Quantization Annotation
%cd ..
import testing
The annotate pass is implemented in C++ (src/relay/quantize/annotate.cc in the TVM source tree). Its central data structure is QAnnotateExpr, a TempExpr wrapper:

using namespace relay::transform;

class QAnnotateExpr;
class QAnnotateExprNode : public TempExprNode {
 public:
  Expr expr;
  QAnnotateKind kind;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("expr", &expr);
    v->Visit("kind", &kind);
  }

  Expr Realize() const final;

  static constexpr const char* _type_key = "relay.QAnnotateExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(QAnnotateExprNode, TempExprNode);
};
class QAnnotateExpr : public TempExpr {
 public:
  /*!
   * \brief The constructor
   * \param expr The original relay expression.
   * \param kind The annotation kind.
   */
  TVM_DLL QAnnotateExpr(Expr expr, QAnnotateKind kind);

  TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode);
};
Expr QAnnotateExprNode::Realize() const { return expr; }

QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) {
  auto rnode = make_object<QAnnotateExprNode>();
  rnode->expr = std::move(expr);
  rnode->kind = kind;
  data_ = std::move(rnode);
}

TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr").set_body_typed([](Expr expr, int kind) {
  return QAnnotateExpr(expr, static_cast<QAnnotateKind>(kind));
});
This code defines the QAnnotateExpr class, which inherits from TempExpr and represents an expression that carries a quantization annotation. QAnnotateExprNode is the underlying node class that stores the expression together with its annotation kind; VisitAttrs exposes both fields to TVM's reflection machinery, and Realize simply returns the original expression.

The QAnnotateExpr constructor takes an expression and an annotation kind and stores them in a freshly allocated QAnnotateExprNode; the TVM_DEFINE_OBJECT_REF_METHODS macro generates the standard object-reference methods.

Finally, the TVM_REGISTER_GLOBAL macro registers a global function that takes an expression and an integer annotation kind and returns the corresponding QAnnotateExpr.
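Because make_annotate_expr goes through the global registry, the wrapper can be exercised directly from Python via the FFI. A minimal sketch, assuming a TVM build with the quantization module; the kind values follow QAnnotateKind (1 = kQInput, 2 = kQWeight, 3 = kQActivation):

import tvm
from tvm import relay

# Fetch the packed function registered by TVM_REGISTER_GLOBAL above
make_annotate_expr = tvm.get_global_func("relay._quantize.make_annotate_expr")

x = relay.var("x", shape=(1, 3, 4, 4), dtype="float32")
annotated = make_annotate_expr(x, 1)  # 1 == kQInput
# Attribute access goes through the reflection set up in VisitAttrs
assert annotated.expr.same_as(x)
print(annotated.kind)  # 1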
Pass QuantizeAnnotate() {
  // TODO(tvm-teams): since partition has added cast_hint in different
  // branches, try to remove this in the future.
  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
    if (e->IsInstance<TempExprNode>()) {
      const auto* n = e.as<QAnnotateExprNode>();
      ICHECK(n);
      const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
      Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
      return static_cast<Expr>(QAnnotateExpr(ret, kQInput));
    }
    return e;
  };

  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
        auto new_params = func->params;
        for (const auto& x : FreeVars(func)) {
          new_params.push_back(x);
        }
        return WithFields(func, new_params);
      };
  return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate").set_body_typed(QuantizeAnnotate);
TVM_REGISTER_NODE_TYPE(QAnnotateExprNode);
This code defines the QuantizeAnnotate pass, which annotates the input function for quantization. It first defines a lambda named fmulti_ref: given an expression, if that expression is an instance of TempExprNode, it attaches a simulated quantize of kind kQInput to the wrapped expression (via the attach_simulated_quantize packed function) and re-wraps the result in a QAnnotateExpr; otherwise the expression is returned unchanged.

Next, pass_func takes a Function, an IRModule, and a PassContext. It forward-rewrites the function using the per-operator FQAnnotateRewrite attribute, then appends the function's free variables, i.e. the dom_scale/clip_min/clip_max parameters introduced by the simulated quantize ops, to the parameter list, and rebuilds the function with WithFields.

Finally, CreateFunctionPass wraps pass_func into a function pass, which is registered as the global function relay._quantize.QuantizeAnnotate; the QAnnotateExprNode node type is registered as well.
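On the Python side, the FQAnnotateRewrite rules that ForwardRewrite consults are registered per operator in python/tvm/relay/quantize/_annotate.py. For reference, the built-in nn.conv2d rule looks roughly like this (condensed from that module; the underscored helpers are private to it, and the rule is already registered inside TVM, so this is for reading rather than re-execution):

from tvm.relay.quantize import QAnnotateKind, quantize_context
from tvm.relay.quantize._annotate import (
    QAnnotateExpr, attach_simulated_quantize, register_annotate_function,
    _forward_op, _get_expr_kind,
)

@register_annotate_function("nn.conv2d")
def conv2d_rewrite(ref_call, new_args, ctx):
    # lhs is quantized as input, rhs as weight; the result is an activation
    if quantize_context().check_to_skip(ref_call):
        return None
    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
    if lhs_kind is None or lhs_kind == QAnnotateKind.ACTIVATION:
        lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
    assert rhs_kind is None
    rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
    expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)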
To see these passes in action, build a small PyTorch model: two convolutions and two ReLUs whose outputs are added.

from torch import nn
import torch

class Model(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv = nn.Conv2d(3, 16, 3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(16, 16, 3, 1, 1, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x1 = self.relu(x)
        x = self.conv2(x)
        x2 = self.relu(x)
        x = x1 + x2
        return x
import numpy as np
import tvm
from tvm import relay

# Input data
input_shape = (1, 3, 4, 4)
input_dtype = "float32"
data_np = np.random.rand(*input_shape).astype(input_dtype)

with torch.no_grad():
    pt_model = Model().eval().float()
    traced_model = torch.jit.trace(pt_model, torch.from_numpy(data_np)).eval()

mod, params = relay.frontend.from_pytorch(traced_model, [("data", input_shape)],
                                          use_parser_friendly_name=True)
with tvm.transform.PassContext(opt_level=3):
    mod = relay.quantize.prerequisite_optimize(mod, params)
print(mod['main'])
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
%1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
%2 = add(%1, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%3 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
%4 = nn.relu(%2) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
add(%3, %4) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 16, 4, 4), float32] */
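prerequisite_optimize binds params into the function and runs a small pipeline of standard optimizations before quantization, which is why the printed module already has the weights folded into constants and the bias_add canonicalized into the add in %2. A sketch of what it runs (see python/tvm/relay/quantize/quantize.py; the exact pass list may differ across TVM versions):

optimize = tvm.transform.Sequential([
    relay.transform.SimplifyInference(),  # unrolls ops such as batch_norm
    relay.transform.FoldConstant(),
    relay.transform.FoldScaleAxis(),
    relay.transform.CanonicalizeOps(),    # e.g. bias_add -> broadcast add
    relay.transform.FoldConstant(),
])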
Running the partition pass groups the graph into to-be-quantized regions, marking the region boundaries with annotation.cast_hint and annotation.stop_fusion:

relay.quantize.partition()(mod)["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */) -> Tensor[(1, 16, 4, 4), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
%1 = nn.relu(%0) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
%2 = annotation.cast_hint(%1, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%3 = annotation.cast_hint(%0, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%4 = annotation.stop_fusion(%3) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%5 = nn.conv2d(%4, meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
%6 = add(%5, meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%7 = nn.relu(%6) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
%8 = annotation.cast_hint(%7, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%9 = annotation.stop_fusion(%2) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%10 = annotation.stop_fusion(%8) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%11 = add(%9, %10) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */;
%12 = annotation.cast_hint(%11, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
annotation.stop_fusion(%12) /* ty=Tensor[(1, 16, 4, 4), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32]) -> Tensor[(1, 16, 4, 4), float32] */
Chaining partition with annotate inserts relay.op.annotation.simulated_quantize calls and turns their dom_scale/clip_min/clip_max inputs into new function parameters:

passes = tvm.transform.Sequential([
    relay.quantize.partition(),
    relay.quantize.annotate()
])
passes(mod)["main"]
fn (%data: Tensor[(1, 3, 4, 4), float32] /* ty=Tensor[(1, 3, 4, 4), float32] span=aten___convolution_0_data:0:0 */, %dom_scale: float32 /* ty=float32 */, %clip_min: float32 /* ty=float32 */, %clip_max: float32 /* ty=float32 */, %dom_scale1: float32 /* ty=float32 */, %clip_min1: float32 /* ty=float32 */, %clip_max1: float32 /* ty=float32 */, %dom_scale2: float32 /* ty=float32 */, %clip_min2: float32 /* ty=float32 */, %clip_max2: float32 /* ty=float32 */, %dom_scale3: float32 /* ty=float32 */, %clip_min3: float32 /* ty=float32 */, %clip_max3: float32 /* ty=float32 */, %dom_scale4: float32 /* ty=float32 */, %clip_min4: float32 /* ty=float32 */, %clip_max4: float32 /* ty=float32 */, %dom_scale5: float32 /* ty=float32 */, %clip_min5: float32 /* ty=float32 */, %clip_max5: float32 /* ty=float32 */, %dom_scale6: float32 /* ty=float32 */, %clip_min6: float32 /* ty=float32 */, %clip_max6: float32 /* ty=float32 */) -> Tensor[(1, 16, 4, 4), float32] {
%0 = relay.op.annotation.simulated_quantize(%data, %dom_scale, %clip_min, %clip_max, kind=1) /* ty=Tensor[(1, 3, 4, 4), float32] */;
%1 = relay.op.annotation.simulated_quantize(meta[relay.Constant][0] /* ty=Tensor[(16, 3, 3, 3), float32] */, %dom_scale1, %clip_min1, %clip_max1, kind=2) /* ty=Tensor[(16, 3, 3, 3), float32] */;
%2 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_0:0:0 */;
%3 = relay.op.annotation.simulated_quantize(%2, %dom_scale2, %clip_min2, %clip_max2, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%4 = nn.relu(%3) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_0:0:0 */;
%5 = annotation.cast_hint(%4, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%6 = annotation.cast_hint(%3, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%7 = annotation.stop_fusion(%6) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%8 = relay.op.annotation.simulated_quantize(meta[relay.Constant][1] /* ty=Tensor[(16, 16, 3, 3), float32] */, %dom_scale3, %clip_min3, %clip_max3, kind=2) /* ty=Tensor[(16, 16, 3, 3), float32] */;
%9 = nn.conv2d(%7, %8, padding=[1, 1, 1, 1], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten___convolution_1:0:0 */;
%10 = relay.op.annotation.simulated_quantize(meta[relay.Constant][2] /* ty=Tensor[(16, 1, 1), float32] */, %dom_scale4, %clip_min4, %clip_max4, kind=2) /* ty=Tensor[(16, 1, 1), float32] */;
%11 = add(%9, %10) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%12 = nn.relu(%11) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__relu_1:0:0 */;
%13 = relay.op.annotation.simulated_quantize(%12, %dom_scale5, %clip_min5, %clip_max5, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%14 = annotation.cast_hint(%13, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
%15 = annotation.stop_fusion(%5) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%16 = annotation.stop_fusion(%14) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%17 = add(%15, %16) /* ty=Tensor[(1, 16, 4, 4), float32] span=aten__add_0:0:0 */;
%18 = relay.op.annotation.simulated_quantize(%17, %dom_scale6, %clip_min6, %clip_max6, kind=1) /* ty=Tensor[(1, 16, 4, 4), float32] */;
%19 = annotation.cast_hint(%18, dtype="int8") /* ty=Tensor[(1, 16, 4, 4), float32] */;
annotation.stop_fusion(%19) /* ty=Tensor[(1, 16, 4, 4), float32] */
} /* ty=fn (Tensor[(1, 3, 4, 4), float32], float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32) -> Tensor[(1, 16, 4, 4), float32] */
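Each simulated_quantize contributes one dom_scale/clip_min/clip_max triple, and QuantizeAnnotate appends those free variables as new parameters. A quick sanity check against the dump above (count_op is a hypothetical helper, not a TVM API):

def count_op(func, name):
    """Count calls to operator `name` in a Relay function."""
    hits = [0]
    def fvisit(expr):
        if isinstance(expr, relay.Call) and isinstance(expr.op, tvm.ir.Op) and expr.op.name == name:
            hits[0] += 1
    relay.analysis.post_order_visit(func, fvisit)
    return hits[0]

annotated = passes(mod)["main"]
print(count_op(annotated, "relay.op.annotation.simulated_quantize"))  # 7 in the dump above
print(len(annotated.params) - len(mod["main"].params))                # 21 == 7 * 3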
A related mechanism is the FakeQuantizationToInteger pass, which rewrites a dequantize -> op -> quantize chain into pure integer arithmetic; the optional_qnn_ops argument lets additional ops such as nn.softmax take part in the rewrite:

shape = [5, 10]
scale = 0.1
x_ = relay.var("x", shape=shape, dtype="int8")
x = relay.qnn.op.dequantize(x_, relay.const(scale), relay.const(0))
op = relay.op.nn.softmax(x, axis=1)
op = relay.qnn.op.quantize(
    op, relay.const(1.0 / 256.0), relay.const(-128), out_dtype="int8"
)

x_np = np.random.randint(-128, 127, size=shape, dtype="int8")
x_np = np.sort(x_np)

mod = tvm.IRModule.from_expr(op)
mod = tvm.relay.transform.InferType()(mod)
mod_int = tvm.relay.transform.FakeQuantizationToInteger(
    hard_fail=True, optional_qnn_ops=["nn.softmax"]
)(mod)
mod.show()
mod_int.show()
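Both modules should agree up to quantization error; TVM's own fake-quantization tests compare them with the VM executor. A sketch along those lines, reusing the mod, mod_int, and x_np defined above:

# Evaluate the simulated (float) chain and the rewritten integer chain
dev = tvm.cpu()
out_fq = relay.create_executor("vm", mod=mod, device=dev, target="llvm").evaluate()(x_np).numpy()
out_int = relay.create_executor("vm", mod=mod_int, device=dev, target="llvm").evaluate()(x_np).numpy()
# Cast to int32 before differencing to avoid int8 overflow; allow one
# quantization step of mismatch
np.testing.assert_allclose(out_fq.astype("int32"), out_int.astype("int32"), atol=1)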