SimulatedQuantize#
源码:tvm/src/relay/quantize/quantize.cc
和 tvm/python/tvm/relay/quantize/_annotate.py
%cd ..
import testing
TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);
bool SimulatedQuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
ICHECK_EQ(types.size(), 5);
const auto param = attrs.as<SimulatedQuantizeAttrs>();
ICHECK(param != nullptr);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
ICHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // dom_scale
reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // clip_min
reporter->Assign(types[3], TensorType({}, DataType::Float(32))); // clip_max
reporter->Assign(types[4], types[0]); // output
return true;
}
这段代码定义了 SimulatedQuantizeRel
函数,它的作用是检查输入的类型是否符合预期。具体来说,它首先检查输入的类型数量是否为 5,然后从属性中获取 SimulatedQuantizeAttrs
类型的参数。接着,它检查第一个类型是否为 TensorTypeNode
类型,如果不是则返回 false
。最后,它将输出的类型分别设置为 dom_scale
、clip_min
、clip_max
和输入数据的类型。
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE)
.set_num_inputs(4)
.add_argument("data", "Tensor", "The input data.")
.add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar")
.add_argument("clip_min", "Tensor", "lower bound. It should be a scalar")
.add_argument("clip_max", "Tensor", "upper bound. It should be a scalar")
.set_attrs_type<SimulatedQuantizeAttrs>()
.set_support_level(11)
.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);
TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize")
.set_body_typed([](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign,
String rounding) {
auto attrs = make_object<SimulatedQuantizeAttrs>();
attrs->kind = kind;
attrs->sign = sign;
attrs->rounding = rounding;
static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
});
RELAY_REGISTER_OP
宏注册名为 relay.op.annotation.simulated_quantize
的算子,该算子有 4 个输入参数:data
、dom_scale
、clip_min
和 clip_max
。它还设置了属性类型为 SimulatedQuantizeAttrs
,并添加了类型关系函数 SimulatedQuantizeRel
。
TVM_REGISTER_GLOBAL
宏注册全局函数 relay._quantize.simulated_quantize
,该函数接受 6 个参数:data
、dom_scale
、clip_min
、clip_max
、kind
、sign
和 rounding
。在这个函数中,首先创建 SimulatedQuantizeAttrs
对象,并设置其属性值。然后,调用 relay.op.annotation.simulated_quantize
算子,并将结果返回。