# 通用矩阵乘法(VTA)
import numpy as np
import tvm
from tvm import te
from tvm import rpc
from tvm.contrib.utils import tempdir
from vta.testing import simulator
import vta.testing
# VTA configuration object; it supplies the tile sizes and dtypes used below.
env = vta.get_env()
# This script requires the "sim" target with the simulator enabled.
assert env.TARGET == "sim"
assert simulator.enabled()
remote = rpc.LocalSession()

# Problem size: a (batch_size x channel) by (channel x channel) product,
# scheduled with tiles of `block` elements.
batch_size = 128
channel = 128
block = 128

# Number of hardware tiles along each logical dimension.
_batch_tiles = batch_size // env.BATCH
_in_tiles = channel // env.BLOCK_IN
_out_tiles = channel // env.BLOCK_OUT

# Tensors are stored in a 4-D tiled layout: (outer, outer, inner, inner).
data_shape = (_batch_tiles, _in_tiles, env.BATCH, env.BLOCK_IN)
weight_shape = (_out_tiles, _in_tiles, env.BLOCK_OUT, env.BLOCK_IN)
res_shape = (_batch_tiles, _out_tiles, env.BATCH, env.BLOCK_OUT)

# Operation count: each fused multiply-add counts as two operations.
num_ops = 2 * channel * channel * batch_size
# Graph inputs in the tiled layout.
data = te.placeholder(data_shape, name="data", dtype=env.inp_dtype)
weight = te.placeholder(weight_shape, name="weight", dtype=env.wgt_dtype)

# Reduction axes: outer over input-channel tiles, inner within one tile.
ko = te.reduce_axis((0, channel // env.BLOCK_IN), name="ko")
ki = te.reduce_axis((0, env.BLOCK_IN), name="ki")

# Element-wise copies of the inputs; these stages are later given buffer scopes.
data_buf = te.compute(data_shape, lambda *idx: data(*idx), name="data_buf")
weight_buf = te.compute(weight_shape, lambda *idx: weight(*idx), name="weight_buf")


def _gemm_body(bo, co, bi, ci):
    """Accumulate data * weight over (ko, ki) in the accumulator dtype."""
    lhs = data_buf[bo, ko, bi, ki].astype(env.acc_dtype)
    rhs = weight_buf[co, ko, ci, ki].astype(env.acc_dtype)
    return te.sum(lhs * rhs, axis=[ko, ki])


res_gem = te.compute(res_shape, _gemm_body, name="res_gem")

# Post-processing: arithmetic right shift by 8, then clamp into
# [0, (1 << (INP_WIDTH - 1)) - 1] before casting to the input dtype.
res_shf = te.compute(res_shape, lambda *idx: res_gem(*idx) >> 8, name="res_shf")
# Lower clamp at zero (relu).
res_max = te.compute(res_shape, lambda *idx: te.max(res_shf(*idx), 0), name="res_max")
# Upper clamp at the largest positive value representable in INP_WIDTH signed bits.
_clip_max = (1 << (env.INP_WIDTH - 1)) - 1
res_min = te.compute(
    res_shape, lambda *idx: te.min(res_max(*idx), _clip_max), name="res_min"
)
res = te.compute(res_shape, lambda *idx: res_min(*idx).astype(env.inp_dtype), name="res")
def verify(s, check_correctness=False):
    """Build schedule ``s`` for VTA, run it on ``remote`` and time it.

    The schedule is compiled into a module named ``gemm``, saved to a
    temporary ``gemm.o``, uploaded and loaded through the RPC session, then
    timed on random inputs with ``time_evaluator(number=20)``.

    Parameters
    ----------
    s : schedule
        Schedule whose arguments are the module-level ``data``, ``weight``
        and ``res`` tensors.
    check_correctness : bool, optional
        When True, compare the device output against a NumPy reference and
        raise ``AssertionError`` on mismatch. Defaults to False so existing
        callers keep their behaviour.

    Returns
    -------
    The object returned by the time evaluator; callers read its ``mean``.
    """
    mod = vta.build(
        s,
        [data, weight, res],
        tvm.target.Target("ext_dev", host=env.target_host),
        name="gemm",
    )
    temp = tempdir()
    obj_path = temp.relpath("gemm.o")
    mod.save(obj_path)
    remote.upload(obj_path)
    f = remote.load_module("gemm.o")
    dev = remote.ext_dev(0)

    # Random inputs in the plain 2-D layout.
    data_orig = np.random.randint(-128, 128, size=(batch_size, channel)).astype(data.dtype)
    weight_orig = np.random.randint(-128, 128, size=(channel, channel)).astype(weight.dtype)

    # Repack into the 4-D tiled layout described by data_shape / weight_shape.
    data_packed = data_orig.reshape(
        batch_size // env.BATCH, env.BATCH, channel // env.BLOCK_IN, env.BLOCK_IN
    ).transpose((0, 2, 1, 3))
    weight_packed = weight_orig.reshape(
        channel // env.BLOCK_OUT, env.BLOCK_OUT, channel // env.BLOCK_IN, env.BLOCK_IN
    ).transpose((0, 2, 1, 3))

    res_np = np.zeros(res_shape).astype(res.dtype)
    data_arr = tvm.nd.array(data_packed, dev)
    weight_arr = tvm.nd.array(weight_packed, dev)
    res_arr = tvm.nd.array(res_np, dev)

    time_f = f.time_evaluator("gemm", dev, number=20)
    on_simulator = env.TARGET in ["sim", "tsim"]
    if on_simulator:
        # Reset the counters so the report covers only this run.
        simulator.clear_stats()
    cost = time_f(data_arr, weight_arr, res_arr)
    if on_simulator:
        stats = simulator.stats()
        print("Execution statistics:")
        for k, v in stats.items():
            print("\t{:<16}: {:>16}".format(k, v))

    if check_correctness:
        # NumPy reference: tile-by-tile GEMM in the accumulator dtype,
        # followed by the same shift / clip / cast as the compute graph.
        res_ref = np.zeros(res_shape).astype(env.acc_dtype)
        for b in range(batch_size // env.BATCH):
            for i in range(channel // env.BLOCK_OUT):
                for j in range(channel // env.BLOCK_IN):
                    res_ref[b, i, :] += np.dot(
                        data_packed[b, j, :].astype(env.acc_dtype),
                        weight_packed[i, j].T.astype(env.acc_dtype),
                    )
        res_ref = np.right_shift(res_ref, 8)
        res_ref = np.clip(res_ref, 0, (1 << (env.INP_WIDTH - 1)) - 1).astype(res.dtype)

        res_unpack = res_arr.numpy().reshape(
            batch_size // env.BATCH, channel // env.BLOCK_OUT, env.BATCH, env.BLOCK_OUT
        )
        np.testing.assert_allclose(res_unpack, res_ref)

    return cost
def run_schedule(load_inp, load_wgt, gemm, alu, store_out, print_ir, block):
    """Schedule the GEMM pipeline with the given intrinsics and benchmark it.

    Parameters
    ----------
    load_inp, load_wgt, store_out
        Pragma names attached to the input-load, weight-load and
        result-store stages.
    gemm
        Tensor intrinsic used to tensorize the ``res_gem`` stage.
    alu
        Pragma name attached to the ``res_shf``, ``res_min`` and ``res_max``
        stages.
    print_ir : bool
        When truthy, print the lowered IR before building.
    block : int
        Tile size in elements; a falsy value disables tiling.

    Returns
    -------
    Whatever ``verify`` returns for the resulting schedule.
    """
    sch = te.create_schedule(res.op)

    # Assign each intermediate stage to its buffer scope.
    sch[data_buf].set_scope(env.inp_scope)
    sch[weight_buf].set_scope(env.wgt_scope)
    for acc_stage in (res_gem, res_shf, res_min, res_max):
        sch[acc_stage].set_scope(env.acc_scope)

    if block:
        batch_tile = block // env.BATCH
        in_tile = block // env.BLOCK_IN
        out_tile = block // env.BLOCK_OUT

        # Tile the output and nest the accumulator stages under the outer
        # output-channel loop.
        res_bo, res_co, _, _ = sch[res].op.axis
        _, co_outer, b_inner, _ = sch[res].tile(res_bo, res_co, batch_tile, out_tile)
        store_axis = b_inner
        for acc_stage in (res_gem, res_shf, res_min, res_max):
            sch[acc_stage].compute_at(sch[res], co_outer)

        gbo, gco, gbi, gci = sch[res_gem].op.axis
        # Compute one line at a time
        ko_outer, ko_inner = sch[res_gem].split(ko, in_tile)
        sch[res_gem].reorder(ko_outer, ko_inner, gbo, gco, gbi, gci, ki)
        sch[data_buf].compute_at(sch[res_gem], ko_outer)
        sch[weight_buf].compute_at(sch[res_gem], ko_outer)
    else:
        gbo, gco, gbi, gci = sch[res_gem].op.axis
        sch[res_gem].reorder(ko, gbo, gco, gbi, gci, ki)
        store_axis = sch[res].op.axis[0]

    # Use VTA instructions (identical for the tiled and untiled variants).
    sch[data_buf].pragma(sch[data_buf].op.axis[0], load_inp)
    sch[weight_buf].pragma(sch[weight_buf].op.axis[0], load_wgt)
    sch[res_gem].tensorize(gbi, gemm)
    for alu_stage in (res_shf, res_min, res_max):
        sch[alu_stage].pragma(sch[alu_stage].op.axis[0], alu)
    sch[res].pragma(store_axis, store_out)

    if print_ir:
        print(tvm.lower(sch, [data, weight, res], simple_mode=True))
    return verify(sch)
GEMM GOPS End-to-End Test:
# End-to-end benchmark: every stage uses the real intrinsics from `env`,
# so no mock environment is needed here.
with vta.build_config():
    cost = run_schedule(
        env.dma_copy,
        env.dma_copy,
        env.gemm,
        env.alu,
        env.dma_copy,
        print_ir=False,
        block=block,
    )
# Throughput in giga-operations per second, from the mean time per run.
gops = (num_ops / cost.mean) / float(10**9)
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
Execution statistics:
inp_load_nbytes : 344064
wgt_load_nbytes : 344064
acc_load_nbytes : 0
uop_load_nbytes : 1008
out_store_nbytes: 344064
gemm_counter : 172032
alu_counter : 64512
Time cost = 0.00169099 sec/op, 2.48038 GOPS
[08:32:41] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=64
2023-09-25 08:32:42.101 INFO load_module /tmp/tmp8u11kql8/gemm.o
GEMM Unit Test:
# GEMM in isolation: only the gemm intrinsic comes from `env`; the loads,
# ALU and store stages use the mock environment's counterparts.
mock = env.mock
with vta.build_config():
    cost = run_schedule(
        load_inp=mock.dma_copy,
        load_wgt=mock.dma_copy,
        gemm=env.gemm,
        alu=mock.alu,
        store_out=mock.dma_copy,
        print_ir=False,
        block=block,
    )
# Throughput in giga-operations per second, from the mean time per run.
gops = num_ops / cost.mean / 1e9
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
Execution statistics:
inp_load_nbytes : 0
wgt_load_nbytes : 0
acc_load_nbytes : 0
uop_load_nbytes : 756
out_store_nbytes: 0
gemm_counter : 172032
alu_counter : 0
Time cost = 0.00688763 sec/op, 0.608962 GOPS
[08:34:29] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=64
2023-09-25 08:34:29.973 INFO load_module /tmp/tmp8u11kql8/gemm.o
ALU Unit Test:
# ALU in isolation: only the alu pragma comes from `env`; the loads, GEMM
# and store stages use the mock environment's counterparts.
mock = env.mock
with vta.build_config():
    cost = run_schedule(
        load_inp=mock.dma_copy,
        load_wgt=mock.dma_copy,
        gemm=mock.gemm,
        alu=env.alu,
        store_out=mock.dma_copy,
        print_ir=False,
        block=block,
    )
# Throughput in giga-operations per second, from the mean time per run.
gops = num_ops / cost.mean / 1e9
print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops))
Execution statistics:
inp_load_nbytes : 0
wgt_load_nbytes : 0
acc_load_nbytes : 0
uop_load_nbytes : 252
out_store_nbytes: 0
gemm_counter : 0
alu_counter : 64512
Time cost = 0.000132332 sec/op, 31.6953 GOPS
[08:33:08] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=64
2023-09-25 08:33:08.365 INFO load_module /tmp/tmp8u11kql8/gemm.o
LoadInp Unit Test:
# Input load in isolation: only the input DMA pragma comes from `env`;
# every other stage uses the mock environment's counterpart.
mock = env.mock
with vta.build_config():
    cost = run_schedule(
        env.dma_copy,
        mock.dma_copy,
        mock.gemm,
        mock.alu,
        mock.dma_copy,
        print_ir=False,
        block=block,
    )
gops = (num_ops / cost.mean) / float(10**9)
# Input traffic: batch_size x channel elements of INP_WIDTH bits each.
bandwidth = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10**9)
print(
    "\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits"
    % (cost.mean, gops, bandwidth)
)
Execution statistics:
inp_load_nbytes : 344064
wgt_load_nbytes : 0
acc_load_nbytes : 0
uop_load_nbytes : 0
out_store_nbytes: 0
gemm_counter : 0
alu_counter : 0
Time cost = 2.45895e-06 sec/op, 1705.73 GOPS, bandwidth=53.3041 Gbits
[08:36:33] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=64
2023-09-25 08:36:33.333 INFO load_module /tmp/tmp8u11kql8/gemm.o
LoadWgt Unit Test:
# Weight load in isolation: only the weight DMA pragma comes from `env`;
# every other stage uses the mock environment's counterpart.
mock = env.mock
with vta.build_config():
    cost = run_schedule(
        mock.dma_copy,
        env.dma_copy,
        mock.gemm,
        mock.alu,
        mock.dma_copy,
        print_ir=False,
        block=block,
    )
gops = (num_ops / cost.mean) / float(10**9)
# Weight traffic: the weight tensor holds channel x channel elements of
# WGT_WIDTH bits each, so size it from the weight, not from the input.
bandwidth = (channel * channel * env.WGT_WIDTH / cost.mean) / float(10**9)
print(
    "\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits"
    % (cost.mean, gops, bandwidth)
)
Execution statistics:
inp_load_nbytes : 0
wgt_load_nbytes : 344064
acc_load_nbytes : 0
uop_load_nbytes : 0
out_store_nbytes: 0
gemm_counter : 0
alu_counter : 0
Time cost = 2.4185e-06 sec/op, 1734.26 GOPS, bandwidth=54.1956 Gbits
[08:37:20] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=64
2023-09-25 08:37:20.333 INFO load_module /tmp/tmp8u11kql8/gemm.o
StoreOut Unit Test:
# Result store in isolation: only the store DMA pragma comes from `env`;
# every other stage uses the mock environment's counterpart.
mock = env.mock
with vta.build_config():
    cost = run_schedule(
        mock.dma_copy,
        mock.dma_copy,
        mock.gemm,
        mock.alu,
        env.dma_copy,
        print_ir=False,
        block=block,
    )
gops = (num_ops / cost.mean) / float(10**9)
# Output traffic: batch_size x channel elements. `res` is cast to
# env.inp_dtype, hence INP_WIDTH bits per element.
bandwidth = (batch_size * channel * env.INP_WIDTH / cost.mean) / float(10**9)
print(
    "\tTime cost = %g sec/op, %g GOPS, bandwidth=%g Gbits"
    % (cost.mean, gops, bandwidth)
)
Execution statistics:
inp_load_nbytes : 0
wgt_load_nbytes : 0
acc_load_nbytes : 0
uop_load_nbytes : 0
out_store_nbytes: 344064
gemm_counter : 0
alu_counter : 0
Time cost = 2.62682e-05 sec/op, 159.672 GOPS, bandwidth=4.98975 Gbits
[08:38:14] /media/pc/data/lxw/ai/tvm/src/tir/transforms/arg_binder.cc:95: Warning: Trying to bind buffer to another one with lower alignment requirement required_alignment=256, provided_alignment=64
2023-09-25 08:38:14.909 INFO load_module /tmp/tmp8u11kql8/gemm.o