MulAndDiv#
参考:tvm/src/relay/quantize/realize.cc
/* calculate `data * s1 / s2`, use shift if possible */
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
const Array<IndexExpr>& data_shape) {
const QConfig& cfg = QConfig::Current();
// here we assume the dtype of data is dtype activation
if (s1 == s2) return data;
float factor = s1 / s2;
float shift_factor = std::log2(factor);
ICHECK_GT(shift_factor, 0);
if (static_cast<int>(shift_factor) == shift_factor) {
return LeftShift(data, MakeConstantScalar(dtype, static_cast<int>(shift_factor)));
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
if (cfg->rounding == "UPWARD") {
auto [fixed_point_multiplier, shift] = qnn::GetFixedPointMultiplierShift(factor);
data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
} else {
data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);
}
return Cast(data, dtype);
}
}
这段代码定义了一个名为MulAndDiv
的内联函数,用于计算 data * s1 / s2
。如果可能的话,它会使用位移运算来优化计算过程。
函数接收5个参数:
data
:需要进行计算的数据;s1
和s2
:两个浮点数,用于计算data * s1 / s2
;dtype
:数据类型;data_shape
:数据的形状。
函数首先获取当前的量化配置(QConfig::Current()
),然后判断 s1
和 s2
是否相等,如果相等则直接返回 data
。
接下来,计算 factor = s1 / s2
,并计算 shift_factor = std::log2(factor)
。如果 shift_factor
大于 0 且为整数,则对 data
进行左移运算。如果 factor
为整数,则对 data
进行乘法运算。否则,根据量化配置中的舍入方式(cfg->rounding
)进行定点乘法运算,并将结果转换为指定的数据类型。