MulAndDiv

MulAndDiv#

参考:tvm/src/relay/quantize/realize.cc

/* calculate `data * s1 / s2`, use shift if possible */
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
                      const Array<IndexExpr>& data_shape) {
  const QConfig& cfg = QConfig::Current();
  // here we assume the dtype of data is dtype activation
  if (s1 == s2) return data;

  float factor = s1 / s2;
  float shift_factor = std::log2(factor);
  ICHECK_GT(shift_factor, 0);
  if (static_cast<int>(shift_factor) == shift_factor) {
    return LeftShift(data, MakeConstantScalar(dtype, static_cast<int>(shift_factor)));
  } else if (static_cast<int>(factor) == factor) {
    return Multiply(data, MakeConstantScalar(dtype, factor));
  } else {
    if (cfg->rounding == "UPWARD") {
      auto [fixed_point_multiplier, shift] = qnn::GetFixedPointMultiplierShift(factor);
      data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
    } else {
      data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);
    }

    return Cast(data, dtype);
  }
}

这段代码定义了一个名为MulAndDiv的内联函数,用于计算 data * s1 / s2。如果可能的话,它会使用位移运算来优化计算过程。

函数接收5个参数:

  • data:需要进行计算的数据;

  • s1s2:两个浮点数,用于计算 data * s1 / s2

  • dtype:数据类型;

  • data_shape:数据的形状。

函数首先获取当前的量化配置(QConfig::Current()),然后判断 s1s2 是否相等,如果相等则直接返回 data

接下来,计算 factor = s1 / s2,并计算 shift_factor = std::log2(factor)。如果 shift_factor 大于 0 且为整数,则对 data 进行左移运算。如果 factor 为整数,则对 data 进行乘法运算。否则,根据量化配置中的舍入方式(cfg->rounding)进行定点乘法运算,并将结果转换为指定的数据类型。