# Tests for tvm.relay.qnn.op.canonicalizations
#
from typing import Callable
import numpy as np
from tvm import relay
from tvm.relay.qnn.op import canonicalizations
def fake_identity_func_numpy(arr: np.ndarray):
    """Cast ``arr`` to float32 and return the result.

    Apart from the dtype cast the values are passed through untouched, so
    this serves as the floating point "identity" function in the tests.
    """
    as_float32 = arr.astype("float32")
    return as_float32
def dequantize_numpy(np_arr, np_scale=1.0, np_zero_point=0):
    """Dequantize ``np_arr`` as ``(np_arr - np_zero_point) * np_scale``.

    The array is widened to int32 before the zero point is subtracted, so
    the subtraction is carried out in int32 rather than the input dtype.
    """
    centered = np_arr.astype("int32") - np_zero_point
    return centered * np_scale
def fake_identity_func_relay(
    floating_point_func: Callable[[np.ndarray], np.ndarray],
    input_arg=None,
    in_scale=relay.const(1.0, dtype="float32"),
    in_zero_point=relay.const(0, dtype="int32"),
    out_scale=relay.const(1.0, dtype="float32"),
    out_zero_point=relay.const(0, dtype="int32"),
    in_axis=-1,
    out_axis=-1,
    in_dtype="uint8",
    out_dtype="uint8",
):
    """Build an integer lookup op for ``floating_point_func``.

    Parameters
    ----------
    floating_point_func : Callable[[np.ndarray], np.ndarray]
        Floating point function handed to ``create_integer_lookup_op``.
    input_arg : optional
        Input constant for the lookup op.  When ``None``, a constant holding
        the 256 byte values 0..255 reinterpreted as ``in_dtype`` is used.
    in_scale, in_zero_point, out_scale, out_zero_point
        Quantization parameters, forwarded unchanged.
    in_axis, out_axis
        Axis arguments, forwarded unchanged.
    in_dtype, out_dtype : str
        Input and output dtypes, forwarded unchanged; ``in_dtype`` is also
        used to build the default input.

    Returns
    -------
    tuple
        ``(lookup_op, input_values)``: the value returned by
        ``canonicalizations.create_integer_lookup_op`` and the contents of
        ``input_arg`` as a numpy array.
    """
    if input_arg is None:
        # Every possible byte pattern, viewed through the requested dtype.
        every_byte = np.arange(0, 256, dtype="uint8")
        input_arg = relay.const(every_byte.view(in_dtype))

    lookup_op = canonicalizations.create_integer_lookup_op(
        input_arg=input_arg,
        floating_point_func=floating_point_func,
        in_scale=in_scale,
        in_zero_point=in_zero_point,
        out_scale=out_scale,
        out_zero_point=out_zero_point,
        in_axis=in_axis,
        out_axis=out_axis,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
    )
    input_values = input_arg.data.numpy()
    return lookup_op, input_values
def run_function_test(
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
    in_dtype: str,
    out_dtype: str,
    floating_point_func: Callable[[np.ndarray], np.ndarray],
    input_arg: relay.Expr = None,
    rtol=1e-7,
    atol=0,
):
    """Check an integer lookup op against its floating point function.

    The lookup op is built with ``fake_identity_func_relay`` and evaluated
    with ``canonicalizations.run_const_expr``.  ``floating_point_func``
    applied to the dequantized inputs is then compared, via
    ``np.testing.assert_allclose``, with the dequantized lookup output.

    Parameters
    ----------
    in_scale, in_zero_point : float, int
        Input quantization parameters; wrapped as float32 / int32 relay
        constants for the op and used as-is for the numpy dequantization.
    out_scale, out_zero_point : float, int
        Output quantization parameters, handled the same way.
    in_dtype, out_dtype : str
        Dtypes forwarded to ``fake_identity_func_relay``.
    floating_point_func : Callable[[np.ndarray], np.ndarray]
        Floating point function the lookup op is built from.
    input_arg : relay.Expr, optional
        Input constant; ``None`` is forwarded so the helper picks its default.
    rtol, atol
        Tolerances forwarded to ``np.testing.assert_allclose``.
    """
    # Wrap the Python scalars as typed relay constants for the op builder.
    in_scale_const = relay.const(in_scale, "float32")
    in_zero_point_const = relay.const(in_zero_point, "int32")
    out_scale_const = relay.const(out_scale, "float32")
    out_zero_point_const = relay.const(out_zero_point, "int32")

    relay_lookup, input_values = fake_identity_func_relay(
        input_arg=input_arg,
        floating_point_func=floating_point_func,
        in_scale=in_scale_const,
        in_zero_point=in_zero_point_const,
        out_scale=out_scale_const,
        out_zero_point=out_zero_point_const,
        in_dtype=in_dtype,
        out_dtype=out_dtype,
    )
    lookup_output = canonicalizations.run_const_expr(relay_lookup)

    # Floating point path: dequantize the raw inputs, then apply the function.
    float_path = floating_point_func(
        dequantize_numpy(input_values, np_scale=in_scale, np_zero_point=in_zero_point)
    )
    # Integer path: dequantize whatever the lookup op produced.
    integer_path = dequantize_numpy(
        lookup_output, np_scale=out_scale, np_zero_point=out_zero_point
    )
    np.testing.assert_allclose(float_path, integer_path, atol=atol, rtol=rtol)
# Module-level check: identity function, int8 in and out, unit scale and
# zero point 0 on both sides.  This runs when the module is imported.
run_function_test(
    in_dtype="int8",
    in_scale=1.0,
    in_zero_point=0,
    out_dtype="int8",
    out_scale=1.0,
    out_zero_point=0,
    floating_point_func=fake_identity_func_numpy,
)
def test_int8_to_int8(self):
    """Run the identity lookup check with int8 input and int8 output.

    Both sides use a scale of 1.0 and a zero point of 0.
    """
    check = self.run_function_test
    check(
        in_dtype="int8",
        in_scale=1.0,
        in_zero_point=0,
        out_dtype="int8",
        out_scale=1.0,
        out_zero_point=0,
        floating_point_func=self.fake_identity_func_numpy,
    )
def test_uint8_to_uint8(self):
    """Run the identity lookup check with uint8 input and uint8 output.

    Both sides use a scale of 1.0 and a zero point of 128.
    """
    check = self.run_function_test
    check(
        in_dtype="uint8",
        in_scale=1.0,
        in_zero_point=128,
        out_dtype="uint8",
        out_scale=1.0,
        out_zero_point=128,
        floating_point_func=self.fake_identity_func_numpy,
    )
def test_int8_to_uint8(self):
    """Run the identity lookup check with int8 input and uint8 output.

    The input uses zero point 0 and the output zero point 128; both scales
    are 1.0.
    """
    check = self.run_function_test
    check(
        in_dtype="int8",
        in_scale=1.0,
        in_zero_point=0,
        out_dtype="uint8",
        out_scale=1.0,
        out_zero_point=128,
        floating_point_func=self.fake_identity_func_numpy,
    )
def test_uint8_to_int8(self):
    """Run the identity lookup check with uint8 input and int8 output.

    The input uses zero point 128 and the output zero point 0; both scales
    are 1.0.
    """
    check = self.run_function_test
    check(
        in_dtype="uint8",
        in_scale=1.0,
        in_zero_point=128,
        out_dtype="int8",
        out_scale=1.0,
        out_zero_point=0,
        floating_point_func=self.fake_identity_func_numpy,
    )
"""Test different input shapes"""
def test_keep_input_shapes(self):
    """Run the identity lookup check on 4-D, 3-D and 2-D input constants.

    Each run feeds the same 256 int8 values (-128..127); only the shape of
    the input constant differs between runs.
    """
    # With in_scale=0.015 the int8 inputs dequantize to roughly
    # [-1.92, 1.905]; the output side uses a step of 16 / 256 = 0.0625.
    for shape in ([2, 2, 8, 8], [2, 2, 64], [2, 128]):
        self.run_function_test(
            input_arg=relay.const(np.arange(-128, 128).astype("int8").reshape(shape)),
            in_scale=0.015,
            in_zero_point=0,
            out_scale=16 / 256,
            out_zero_point=0,
            in_dtype="int8",
            out_dtype="int8",
            floating_point_func=self.fake_identity_func_numpy,
            atol=0.03,
            rtol=0.01,
        )