Test fixed-point multiplication#
import numpy as np
import tvm
from tvm import relay, te
from tvm.relay.backend import Executor
# def get_hexagon_target(cpu_ver: str, **kwargs) -> tvm.target.Target:
# """Creates a Hexagon target"""
# target = tvm.target.hexagon(cpu_ver, **kwargs)
# return tvm.target.Target(target, host=target)
def build_module(relay_mod, target):
    """Build a Relay module with the AOT executor and linked parameters."""
    params = {}
    executor = Executor("aot", {"link-params": True})
    lowered = tvm.relay.build(
        relay_mod,
        tvm.target.Target(target, host=target),
        executor=executor,
        params=params,
    )
    return lowered

def run_module(mod, inputs):
    """Set the inputs, run the module, and return the first output as a NumPy array."""
    mod.set_input(**inputs)
    mod.run()
    output = mod.get_output(0).numpy()
    return output
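For sanity-checking the results below, note that a fixed-point multiply with a Q31 multiplier and a shift corresponds, up to rounding-mode details, to scaling by multiplier * 2**(shift - 31) in floating point. A minimal NumPy reference sketch (the helper name fixed_point_multiply_ref is illustrative only, not a TVM API):

def fixed_point_multiply_ref(x, multiplier, shift):
    """Illustrative floating-point reference: round(x * multiplier * 2**(shift - 31)).

    Exact .5 ties may round differently from the TVM intrinsic, so comparisons
    should allow an absolute tolerance of 1.
    """
    return np.round(x.astype("float64") * multiplier * 2.0 ** (shift - 31)).astype("int32")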
Test relay.fixed_point_multiply#
ishape = (6, 32)
a = relay.var("a", relay.TensorType(ishape, "int32"))
for multiplier, shift in [
    (1288490240, -2),  # 0.15
    (1395864320, 1),   # 1.3
    (1288490188, 0),   # 0.6
]:
    fpm = relay.fixed_point_multiply(a, multiplier, shift)
    relay_mod = tvm.IRModule.from_expr(fpm)
    relay_mod.show()

    with tvm.transform.PassContext(opt_level=3):
        # Compile for LLVM...
        llvm_lowered = build_module(relay_mod, tvm.target.Target("llvm"))

    data_in = np.arange(-96, 96).reshape(ishape)
    inputs = {"a": data_in}

    # Run llvm...
    llvm_mod = tvm.runtime.executor.AotModule(llvm_lowered["default"](tvm.cpu(0)))
    expected_output = run_module(llvm_mod, inputs)
    # print(expected_output)
def @main(%a: Tensor[(6, 32), int32]) {
fixed_point_multiply(%a, multiplier=1288490240, shift=-2)
}
def @main(%a: Tensor[(6, 32), int32]) {
fixed_point_multiply(%a, multiplier=1395864320, shift=1)
}
def @main(%a: Tensor[(6, 32), int32]) {
fixed_point_multiply(%a, multiplier=1288490188, shift=0)
}
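The (multiplier, shift) pairs encode the floating-point scales noted in the comments: np.frexp(scale) yields a mantissa in [0.5, 1) and an exponent, and scaling the mantissa by 2**31 gives the Q31 multiplier. A quick sketch (the results agree with the constants above up to the last few bits):

for scale in (0.15, 1.3, 0.6):
    mantissa, exponent = np.frexp(scale)
    print(f"{scale}: multiplier={int(np.round(mantissa * (1 << 31)))}, shift={exponent}")
# 0.15: multiplier=1288490189, shift=-2
# 1.3: multiplier=1395864371, shift=1
# 0.6: multiplier=1288490189, shift=0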
Per-channel fixed-point multiplication#
scales = (
    (1.3, 30.0),
    (1.37, 1.0),
    (0.6, 1.0),
    ((1.7, 0.6), 1.0),
    ((0.007, 1.9), 1.0),
)
ishape = [1, 128, 56, 56]
axis = 1
a = relay.var("a", shape=ishape, dtype="int32")
for in_scale_const, out_scale_const in scales:
    # Make list of input scales from in_scale_const parameter.
    if isinstance(in_scale_const, tuple):
        in_scale = list(in_scale_const) * (ishape[axis] // len(in_scale_const))
    else:
        in_scale = [in_scale_const] * ishape[axis]
    assert len(in_scale) == ishape[axis]

    # qnn.requantize is lowered to fixed_point_multiply if zp == 0 and in_dtype == out_dtype.
    iscale = relay.const(in_scale)
    izero = relay.const(0)
    oscale = relay.const(out_scale_const)
    ozero = relay.const(0)
    op = relay.qnn.op.requantize(a, iscale, izero, oscale, ozero, axis=axis, out_dtype="int32")
    mod = tvm.IRModule.from_expr(op)
    mod = relay.transform.InferType()(mod)
    mod.show()

    with tvm.transform.PassContext(opt_level=3):
        # Compile for LLVM...
        llvm_lowered = build_module(mod, tvm.target.Target("llvm"))

    a_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape)
    inputs = {"a": a_np}

    # Run llvm...
    llvm_mod = tvm.runtime.executor.AotModule(llvm_lowered["default"](tvm.cpu(0)))
    expected_output = run_module(llvm_mod, inputs)
def @main(%a: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
qnn.requantize(%a, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 30f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
def @main(%a: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
qnn.requantize(%a, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
def @main(%a: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
qnn.requantize(%a, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
def @main(%a: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
qnn.requantize(%a, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
def @main(%a: Tensor[(1, 128, 56, 56), int32] /* ty=Tensor[(1, 128, 56, 56), int32] */) -> Tensor[(1, 128, 56, 56), int32] {
qnn.requantize(%a, meta[relay.Constant][0] /* ty=Tensor[(128), float32] */, 0 /* ty=int32 */, 1f /* ty=float32 */, 0 /* ty=int32 */, axis=1, out_dtype="int32") /* ty=Tensor[(1, 128, 56, 56), int32] */
}
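With both zero points equal to 0, this requantize is just a per-channel rescale along axis=1: out ≈ round(a * in_scale[c] / out_scale). A minimal cross-check sketch against the result of the last loop iteration (atol=1 absorbs rounding differences between the fixed-point path and the NumPy reference):

ref = np.round(a_np * np.array(in_scale).reshape(1, -1, 1, 1) / out_scale_const).astype("int32")
np.testing.assert_allclose(expected_output, ref, atol=1)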
Fixed-point multiply with vectorization#
The vectorization size is larger than the hardware vector length.
ishape = [2, 256, 16]
def q_mul_shift(shape):
    x = te.placeholder(shape, name="X", dtype="int32")
    out = te.compute(
        shape,
        lambda i, j, k: tvm.tir.q_multiply_shift(
            x[i, j, k],
            tvm.tir.const(1395864320, "int32"),
            tvm.tir.const(31, "int32"),
            tvm.tir.const(1, "int32"),
        ),
        name="compute",
    )
    return te.create_prim_func([x, out])

for vector_size in (32, 64, 128, 256):
    mod = q_mul_shift(ishape)

    # Schedule with vectorization
    sch = tvm.tir.Schedule(mod)
    b00 = sch.get_block(name="compute", func_name="main")
    fused = sch.fuse(*sch.get_loops(block=b00))
    _, v = sch.split(loop=fused, factors=[None, vector_size])
    sch.vectorize(v)

    with tvm.transform.PassContext(opt_level=3):
        # The vectorized sch.mod would be built for the Hexagon target (omitted here);
        # host_lib is the scalar LLVM reference built from the unscheduled module.
        host_lib = tvm.build(mod, target=tvm.target.Target("llvm"))

    # Verify accuracy
    a_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32")
    b_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32")
    host_args = [tvm.runtime.ndarray.array(arg) for arg in [a_np, b_np]]
    host_lib(*host_args)
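The constants here (multiplier 1395864320, q=31, s=1) are the same ≈1.3 scale used in the Relay example above, so the host result can be cross-checked against a floating-point reference. A minimal sketch (atol=1 covers rounding differences):

result = host_args[1].numpy()
ref = np.round(a_np.astype("float64") * 1395864320 * 2.0 ** (1 - 31)).astype("int32")
np.testing.assert_allclose(result, ref, atol=1)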
a_shape = [2, 256, 16]
b_shape = [256]
def q_mul_shift(shape):
    shift_shape = [shape[1]]
    x = te.placeholder(shape, name="X", dtype="int32")
    y = te.placeholder(shift_shape, name="Y", dtype="int32")
    l_shift = te.placeholder(shift_shape, name="l_shift", dtype="int32")
    r_shift = te.placeholder(shift_shape, name="r_shift", dtype="int32")
    out = te.compute(
        shape,
        lambda i, j, k: tvm.tir.q_multiply_shift_per_axis(
            x[i, j, k],
            y[j],
            l_shift[j],
            r_shift[j],
            tvm.tir.const(31, "int32"),
            tvm.tir.const(1, "bool"),
            tvm.tir.const(0, "bool"),
        ),
        name="compute",
    )
    return te.create_prim_func([x, y, l_shift, r_shift, out])
for vector_size in (32, 64, 128, 256):
    mod = q_mul_shift(a_shape)

    # Schedule with vectorization
    sch = tvm.tir.Schedule(mod)
    b00 = sch.get_block(name="compute", func_name="main")
    fused = sch.fuse(*sch.get_loops(block=b00))
    _, v = sch.split(loop=fused, factors=[None, vector_size])
    sch.vectorize(v)

    with tvm.transform.PassContext(opt_level=3):
        host_lib = tvm.build(mod, target=tvm.target.Target("llvm"))

    # Verify accuracy
    x_np = (
        np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32")
    )
    y_np = (
        np.random.randint(-1000, 1000, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
    )
    lsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
    rsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
    b_np = (
        np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32")
    )
    np_args = [x_np, y_np, lsh_np, rsh_np, b_np]
    host_args = [tvm.runtime.ndarray.array(arg) for arg in np_args]
    host_lib(*host_args)
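The commented-out get_hexagon_target helper hints that the vectorized sch.mod would normally be built for Hexagon and compared against this host reference. Without Hexagon hardware, a similar consistency check can be done entirely on the host: build the vectorized sch.mod for LLVM as well and confirm it matches the scalar reference. A minimal sketch using the values from the last loop iteration:

with tvm.transform.PassContext(opt_level=3):
    vec_lib = tvm.build(sch.mod, target=tvm.target.Target("llvm"))

# Fresh output buffer; reuse the inputs from the scalar run above.
vec_out = tvm.runtime.ndarray.array(np.zeros(a_shape, dtype="int32"))
vec_lib(host_args[0], host_args[1], host_args[2], host_args[3], vec_out)
np.testing.assert_equal(vec_out.numpy(), host_args[4].numpy())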
relay.fixed_point_multiply
<function tvm.relay.op.tensor.fixed_point_multiply(data, multiplier, shift)>