tvm.relay.qnn.op.softmax()

tvm.relay.qnn.op.softmax()

源码:tvm/src/relay/qnn/op/softmax.cc

# Project-local helper module; presumably it configures the environment
# (e.g. sys.path) for the local TVM build — TODO confirm.
import testing
from tvm import relay
# IPython `??` introspection: prints the signature and source of qnn.softmax.
relay.qnn.op.softmax??
Signature:
relay.qnn.op.softmax(
    x,
    scale,
    zero_point,
    output_scale,
    output_zero_point,
    axis=-1,
)
Docstring: <no docstring>
Source:   
def softmax(x, scale, zero_point, output_scale, output_zero_point, axis=-1):
    # The FFI constructor takes `axis` second, before the quantization parameters.
    return _make.softmax(x, axis, scale, zero_point, output_scale, output_zero_point)
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/qnn/op/qnn.py
Type:      function
import numpy as np
import tvm
from tvm import relay

def is_sorted(a):
    """Return True if the 1-D array ``a`` is in non-decreasing order."""
    return bool(np.all(a[:-1] <= a[1:]))


def run_vm(module, inputs):
    """Build ``module`` with the Relay VM for LLVM on CPU, run it on ``inputs``, return NumPy."""
    executor = relay.create_executor("vm", mod=module, device=tvm.cpu(), target="llvm")
    return executor.evaluate()(*inputs).numpy()


# Float reference graph: dequantize(int8) -> nn.softmax -> quantize(int8).
shape = [5, 10]
scale = 0.2
x_ = relay.var("x", shape=shape, dtype="int8")
x = relay.qnn.op.dequantize(x_, relay.const(scale), relay.const(0))
op = relay.op.nn.softmax(x, axis=1)
# Output scale 1/256 with zero point -128 maps softmax's [0, 1) onto int8's [-128, 127].
op = relay.qnn.op.quantize(
    op, relay.const(1.0 / 256.0), relay.const(-128), out_dtype="int8"
)

# Seeded for reproducibility. `high` is exclusive, so 128 is required to include 127.
# Each row is sorted so that the (monotonic) softmax output rows must be sorted as well.
rng = np.random.default_rng(0)
x_np = np.sort(rng.integers(-128, 128, size=shape, dtype=np.int8))
args = [x_np]

mod = tvm.IRModule.from_expr(op)
mod = tvm.relay.transform.InferType()(mod)
# Rewrite the fake-quantized pattern into integer QNN ops. nn.softmax is passed via
# optional_qnn_ops (presumably it is not converted by default).
mod_int = tvm.relay.transform.FakeQuantizationToInteger(
    hard_fail=True, optional_qnn_ops=["nn.softmax"]
)(mod)
# The pass must actually have changed the module.
assert not tvm.ir.structural_equal(mod, mod_int)
result = run_vm(mod, args)
result_int = run_vm(mod_int, args)

# Check at least the softmax output is in ascending order,
# since it is difficult to use allclose due to not-so-good accuracy.
for qdq, qop in zip(result, result_int):
    assert is_sorted(qdq)
    assert is_sorted(qop)

try:
    np.testing.assert_allclose(result_int, result, atol=1)
except AssertionError as e:
    # Deliberately not fatal: print the mismatch report to see the difference.
    print(e)
Not equal to tolerance rtol=1e-07, atol=1

Mismatched elements: 5 / 50 (10%)
Max absolute difference: 5
Max relative difference: 0.33333333
 x: array([[-128, -128, -128, -128, -128, -128, -128, -128, -128,  126],
       [-128, -128, -128, -128, -128, -128, -128, -124,  -96,   88],
       [-128, -128, -128, -128, -128, -128, -128, -128, -120,  118],...
 y: array([[-128, -128, -128, -128, -128, -128, -128, -128, -128,  127],
       [-128, -128, -128, -128, -128, -128, -128, -123,  -98,   93],
       [-128, -128, -128, -128, -128, -128, -128, -128, -120,  120],...
# Print the rewritten module: the dequantize -> nn.softmax -> quantize chain
# is now a single qnn.softmax call.
mod_int.show()
def @main(%x: Tensor[(5, 10), int8] /* ty=Tensor[(5, 10), int8] */) -> Tensor[(5, 10), int8] {
  qnn.softmax(%x, 0.2f /* ty=float32 */, 0 /* ty=int32 */, 0.00390625f /* ty=float32 */, -128 /* ty=int32 */, axis=1) /* ty=Tensor[(5, 10), int8] */
}
# Lower qnn.softmax into plain Relay ops: legalize first, then canonicalize,
# with the LLVM target and opt_level=3 in effect for both passes.
with tvm.target.Target("llvm"), tvm.transform.PassContext(opt_level=3):
    legalized = relay.qnn.transform.Legalize()(mod_int)
    run_mod = relay.qnn.transform.CanonicalizeOps()(legalized)
run_mod.show()
def @main(%x: Tensor[(5, 10), int8] /* ty=Tensor[(5, 10), int8] */) -> Tensor[(5, 10), int8] {
  %0 = cast(%x, dtype="int32") /* ty=Tensor[(5, 10), int32] */;
  %1 = subtract(%0, 0 /* ty=int32 */) /* ty=Tensor[(5, 10), int32] */;
  %2 = max(%1, axis=[1], keepdims=True) /* ty=Tensor[(5, 1), int32] */;
  %3 = subtract(%1, %2) /* ty=Tensor[(5, 10), int32] */;
  %4 = right_shift(%3, 1 /* ty=int32 */) /* ty=Tensor[(5, 10), int32] */;
  %5 = add(%3, %4) /* ty=Tensor[(5, 10), int32] */;
  %6 = right_shift(%3, 4 /* ty=int32 */) /* ty=Tensor[(5, 10), int32] */;
  %7 = clip(5f /* ty=float32 */, a_min=-2.14748e+09f, a_max=2.14748e+09f) /* ty=float32 */;
  %8 = cast(%7, dtype="int32") /* ty=int32 */;
  %9 = subtract(%5, %6) /* ty=Tensor[(5, 10), int32] */;
  %10 = negative(%8) /* ty=int32 */;
  %11 = divide(%9, %10) /* ty=Tensor[(5, 10), int32] */;
  %12 = clip(%11, a_min=0f, a_max=20f) /* ty=Tensor[(5, 10), int32] */;
  %13 = negative(%8) /* ty=int32 */;
  %14 = multiply(%12, %13) /* ty=Tensor[(5, 10), int32] */;
  %15 = subtract(%9, %14) /* ty=Tensor[(5, 10), int32] */;
  %16 = right_shift(%15, 1 /* ty=int32 */) /* ty=Tensor[(5, 10), int32] */;
  %17 = max(%12, axis=[1], keepdims=True) /* ty=Tensor[(5, 1), int32] */;
  %18 = add(%16, %8) /* ty=Tensor[(5, 10), int32] */;
  %19 = subtract(%17, %12) /* ty=Tensor[(5, 10), int32] */;
  %20 = left_shift(%18, %19) /* ty=Tensor[(5, 10), int32] */;
  %21 = sum(%20, axis=[1], keepdims=True) /* ty=Tensor[(5, 1), int32] */;
  %22 = divide(1073741824 /* ty=int32 */, %21) /* ty=Tensor[(5, 1), int32] */;
  %23 = multiply(%22, %20) /* ty=Tensor[(5, 10), int32] */;
  %24 = right_shift(%23, 23 /* ty=int32 */) /* ty=Tensor[(5, 10), int32] */;
  %25 = cast(%24, dtype="int32") /* ty=Tensor[(5, 10), int32] */;
  %26 = cast(-128 /* ty=int32 */, dtype="int32") /* ty=int32 */;
  %27 = fixed_point_multiply(%25, multiplier=1073741824, shift=2) /* ty=Tensor[(5, 10), int32] */;
  %28 = add(%26, %27) /* ty=Tensor[(5, 10), int32] */;
  %29 = clip(%28, a_min=-128f, a_max=127f) /* ty=Tensor[(5, 10), int32] */;
  cast(%29, dtype="int8") /* ty=Tensor[(5, 10), int8] */
}
/*
 * \brief Canonicalizes the QNN softmax op.
 * \param attrs The Softmax attrs.
 * \param new_args The new mutated args to the call node.
 * \param arg_types The types of input and output.
 * \return The sequence of Relay ops for softmax op.
 * \note This op is highly experimental and sometimes lacks accuracy.
 *       Be aware that the input scale must be in the range (0, 1].
 *
 * The lowering follows Algorithm 1 of https://arxiv.org/pdf/2207.01405.pdf:
 * exp(x) is rewritten as a power of two whose integer part becomes a left
 * shift and whose fractional part is approximated linearly; normalization is
 * an integer reciprocal, giving a result with scale 1 / 2^(bits - 1) that is
 * finally requantized to the requested output parameters.
 */
Expr QnnSoftmaxCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
                            const Array<tvm::relay::Type>& arg_types) {
  // Expected: input, scale, zero_point, output_scale, output_zero_point
  ICHECK_EQ(new_args.size(), 5);

  const auto const_i32 = [&](int32_t val) { return MakeConstantScalar(DataType::Int(32), val); };
  const auto const_f32 = [&](float val) { return MakeConstantScalar(DataType::Float(32), val); };

  // The input scale is folded into integer constants at compile time, so it must be constant.
  const auto const_input_scale = new_args[1].as<ConstantNode>();
  ICHECK(const_input_scale) << "Input scale should be constant.";
  ICHECK(const_input_scale->is_scalar()) << "Input scale should be scalar.";
  const float input_scale = static_cast<float*>(const_input_scale->data->data)[0];
  // A zero or negative scale would make 1 / input_scale infinite or negative and break x_0.
  ICHECK(input_scale > 0.f && input_scale <= 1.f)
      << "Input scale should be in the range (0, 1], but got " << input_scale << ".";

  const Expr input_zero_point = new_args[2];
  const Expr output_scale = new_args[3];
  const Expr output_zero_point = new_args[4];
  const int axis = attrs.as<SoftmaxAttrs>()->axis;

  // Refer to the Algorithm 1 in https://arxiv.org/pdf/2207.01405.pdf

  // Integer input with the zero point removed, widened to int32.
  const Expr quantized_data = Subtract(Cast(new_args[0], DataType::Int(32)), input_zero_point);

  // x_0 = round(1 / input_scale): the integer that represents 1.0 in the input domain.
  const Expr x_0 = ConvertDtype(const_f32(std::round(1.f / input_scale)), DataType::Int(32));
  // Subtract the maximum along `axis` so that every x <= 0.
  const Expr max = Max(quantized_data, {axis}, true, false);
  const Expr x = Subtract(quantized_data, max);

  const int m = 30;    // Precision of the integer reciprocal 2^m / sum.
  const int bits = 8;  // Output bit width; `output` below has scale 1 / 2^(bits - 1).
  // x_p = x + (x >> 1) - (x >> 4) ~= 1.4375 * x, a shift-only approximation of x * log2(e).
  const Expr x_p = Subtract(Add(x, RightShift(x, const_i32(1))), RightShift(x, const_i32(4)));
  // Decompose x_p = -q * x_0 + r: integer part q (clipped to [0, 20]) and remainder r.
  const Expr q = Clip(Divide(x_p, Negative(x_0)), 0, 20);
  const Expr max_q = Max(q, {axis}, true, false);
  const Expr r = Subtract(x_p, Multiply(q, Negative(x_0)));
  // x_b ~= x_0 * 2^(r / x_0) via the linear approximation 2^t ~= 1 + t / 2 for t in (-1, 0].
  // When q was clipped at 20, r falls below -x_0 and x_b can turn negative, which would shrink
  // the sum and produce negative probabilities. Such terms are far below the output
  // resolution, so clamp them to zero.
  const Expr x_b = Clip(Add(RightShift(r, const_i32(1)), x_0), 0,
                        static_cast<double>(std::numeric_limits<int32_t>::max()));
  // Shift every term by (max_q - q): the term with the largest q is left unshifted.
  const Expr exps = LeftShift(x_b, Subtract(max_q, q));
  const Expr sums = Sum(exps, {axis}, true, false);
  // NOTE(review): for long reduction axes `sums` may exceed 2^m (or overflow int32), making the
  // integer reciprocal below 0 — confirm the supported axis sizes.
  const Expr output =
      RightShift(Multiply(Divide(const_i32(1 << m), sums), exps), const_i32(m - (bits - 1)));
  // `output` has scale 1 / 2^(bits - 1) and zero point 0; requantize to the output params.
  const Expr requantized = Requantize(output, arg_types[0].as<TensorTypeNode>()->shape,
                                      const_f32(1.f / (1 << (bits - 1))), const_i32(0),
                                      output_scale, output_zero_point, DataType::Int(bits), 0);

  return requantized;
}

这段代码是用于规范化（canonicalize）QNN softmax 算子的函数：它把 `qnn.softmax` 展开为一串以整数运算为主的基础 Relay 算子。它接受三个参数：attrs（Softmax 属性）、new_args（调用节点经变换后的新参数）和 arg_types（输入和输出的类型），返回用于执行 softmax 运算的 Relay 算子序列。

该函数首先检查参数数量是否正确，然后从 new_args 中取出输入、输入缩放因子（必须是常量标量）、输入零点、输出缩放因子和输出零点。接着，它减去输入零点得到量化数据，并按照论文 arXiv:2207.01405 中的算法 1 进行计算。最后，它将结果重新量化到输出缩放因子和输出零点并返回。

需要注意的是，这个函数是高度实验性的，有时可能缺乏准确性。此外，输入缩放因子必须位于 (0, 1] 区间内，即大于 0 且不超过 1。