SimulatedQuantize


Source: tvm/src/relay/quantize/quantize.cc, tvm/python/tvm/relay/quantize/_annotate.py

```cpp
TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);

// Type relation for simulated_quantize.
// types = [data, dom_scale, clip_min, clip_max, output]: 4 inputs + 1 output.
bool SimulatedQuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                          const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 5);
  const auto param = attrs.as<SimulatedQuantizeAttrs>();
  ICHECK(param != nullptr);

  const auto* data = types[0].as<TensorTypeNode>();

  // The type of `data` is not known yet; returning false lets the solver retry later.
  if (data == nullptr) {
    return false;
  }

  // Scalar (rank-0) inputs are rejected.
  ICHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";

  // The three auxiliary inputs must be float32 scalars; the output has the type of `data`.
  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // dom_scale
  reporter->Assign(types[2], TensorType({}, DataType::Float(32)));  // clip_min
  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // clip_max
  reporter->Assign(types[4], types[0]);                             // output
  return true;
}
```

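This code defines SimulatedQuantizeRel, the type relation of the operator: it checks that the argument types are as expected and fills in the types that can be inferred. It first checks that there are 5 types, namely the 4 input types plus the output type, and fetches the SimulatedQuantizeAttrs parameters from attrs. It then checks whether the first type (data) is a TensorTypeNode; if it is not, the function returns false so that inference can be retried once more is known. It also requires the input to have a non-empty shape, so scalars are rejected. Finally, it constrains dom_scale, clip_min and clip_max to be float32 scalars and gives the output the same type as the input data.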
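
To make the relation's logic concrete outside TVM, here is a minimal standalone sketch. ToyTensorType, ToyTypes and ToySimulatedQuantizeRel are hypothetical names, not TVM APIs. An empty std::optional stands for a type that has not been inferred yet, and the sketch simply overwrites entries, whereas TVM's reporter->Assign unifies types and reports conflicts.

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a Relay tensor type: shape + dtype.
struct ToyTensorType {
  std::vector<long> shape;
  std::string dtype;
};

// Slot layout mirrors SimulatedQuantizeRel:
// [data, dom_scale, clip_min, clip_max, output]. Empty optional = unknown type.
using ToyTypes = std::vector<std::optional<ToyTensorType>>;

// Hypothetical analogue of SimulatedQuantizeRel (overwrites instead of unifying).
bool ToySimulatedQuantizeRel(ToyTypes& types) {
  if (types.size() != 5) {
    throw std::invalid_argument("expected 4 input types + 1 output type");
  }
  if (!types[0].has_value()) {
    return false;  // data type unknown: ask to be called again later
  }
  if (types[0]->shape.empty()) {
    throw std::invalid_argument("Input shape cannot be empty");
  }
  const ToyTensorType scalar_f32{{}, "float32"};  // rank-0 float32 tensor
  types[1] = scalar_f32;  // dom_scale
  types[2] = scalar_f32;  // clip_min
  types[3] = scalar_f32;  // clip_max
  types[4] = types[0];    // output has exactly the type of data
  return true;
}
```

A caller fills only types[0] and leaves the other four slots empty; after a successful call, types[1] to types[3] are float32 scalars and types[4] equals types[0].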

```cpp
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
    .describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE)
    .set_num_inputs(4)
    .add_argument("data", "Tensor", "The input data.")
    .add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar")
    .add_argument("clip_min", "Tensor", "lower bound. It should be a scalar")
    .add_argument("clip_max", "Tensor", "upper bound. It should be a scalar")
    .set_attrs_type<SimulatedQuantizeAttrs>()
    .set_support_level(11)
    .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);

TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize")
    .set_body_typed([](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign,
                       String rounding) {
      auto attrs = make_object<SimulatedQuantizeAttrs>();
      attrs->kind = kind;
      attrs->sign = sign;
      attrs->rounding = rounding;
      static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
      return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
    });
```

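The RELAY_REGISTER_OP macro registers an operator named relay.op.annotation.simulated_quantize with 4 inputs: data, dom_scale, clip_min and clip_max. It also sets the attribute type to SimulatedQuantizeAttrs, sets the support level to 11, and attaches the type relation SimulatedQuantizeRel.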
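
The three scalar inputs only make sense together with the computation the op simulates. The sketch below assumes, based on our reading of the compute function in _annotate.py, that the op divides data by dom_scale, clips the result to [clip_min, clip_max], rounds it, and multiplies by dom_scale again. It is plain C++, not TVM/TOPI code. The function names are hypothetical, and using std::round (halfway cases away from zero) is likewise an assumption about the rounding mode.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical scalar version of the assumed simulated-quantize arithmetic.
float SimulateQuantizeValue(float x, float dom_scale, float clip_min, float clip_max) {
  assert(dom_scale > 0.0f && "dom_scale must be positive");
  float scaled = x / dom_scale;                                    // real -> integer domain
  float clipped = std::max(std::min(scaled, clip_max), clip_min);  // saturate to the range
  float rounded = std::round(clipped);                             // introduce rounding error
  return rounded * dom_scale;                                      // integer -> real domain
}

// Element-wise application: the output keeps the input's length and float type,
// matching the type relation's rule that the output type equals the data type.
std::vector<float> SimulateQuantize(const std::vector<float>& data, float dom_scale,
                                    float clip_min, float clip_max) {
  std::vector<float> out;
  out.reserve(data.size());
  for (float x : data) {
    out.push_back(SimulateQuantizeValue(x, dom_scale, clip_min, clip_max));
  }
  return out;
}
```

For example, with dom_scale = 0.5 and an int8-style range [-127, 127], 1.2 maps to 1.0 and 100 saturates to 63.5. The result stays float32 but only takes values that are whole multiples of dom_scale inside the clipped range.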

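The TVM_REGISTER_GLOBAL macro registers the global function relay._quantize.simulated_quantize, which takes 7 arguments: data, dom_scale, clip_min, clip_max, kind, sign and rounding. The function first creates a SimulatedQuantizeAttrs object and sets its kind, sign and rounding fields. It then returns a Call node that applies the relay.op.annotation.simulated_quantize operator to the four tensor arguments with these attributes. Nothing is computed at this point; the call is only built as part of the expression graph.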
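
The fact that the global function only builds an expression, rather than evaluating anything, can be illustrated with a toy mirror of it. ToySimulatedQuantizeAttrs, ToyCall and MakeSimulatedQuantize are hypothetical stand-ins for SimulatedQuantizeAttrs, Call and the registered lambda, and the arguments are plain strings instead of Relay Expr objects.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical mirror of SimulatedQuantizeAttrs: only the fields the lambda sets.
struct ToySimulatedQuantizeAttrs {
  int kind = 0;
  bool sign = true;
  std::string rounding = "round";
};

// Hypothetical minimal call node: operator name, argument names, shared immutable attrs.
struct ToyCall {
  std::string op;
  std::vector<std::string> args;
  std::shared_ptr<const ToySimulatedQuantizeAttrs> attrs;
};

// Analogue of the relay._quantize.simulated_quantize lambda: pack the options into an
// attrs object and return an unevaluated call node. No arithmetic happens here.
ToyCall MakeSimulatedQuantize(std::string data, std::string dom_scale, std::string clip_min,
                              std::string clip_max, int kind, bool sign, std::string rounding) {
  auto attrs = std::make_shared<ToySimulatedQuantizeAttrs>();
  attrs->kind = kind;
  attrs->sign = sign;
  attrs->rounding = std::move(rounding);
  return ToyCall{"relay.op.annotation.simulated_quantize",
                 {std::move(data), std::move(dom_scale), std::move(clip_min), std::move(clip_max)},
                 attrs};
}
```

Holding the attributes behind a shared pointer to const loosely mirrors how a Call node in TVM carries an immutable Attrs object alongside its arguments.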