VTA padded load
import tvm
from tvm import te
import numpy as np
from tvm import topi
from tvm.contrib.utils import tempdir
import vta
import vta.testing
from vta.testing import simulator
# Fix NumPy's global RNG seed so the random test inputs generated with
# np.random are reproducible from run to run.
np.random.seed(0xDEADB)
def _run(env, remote):
    """Build (and, given a remote session, run) VTA padded-load checks.

    Passed to ``vta.testing.run``, which calls it with the VTA environment
    ``env`` and a ``remote`` session. If ``remote`` is falsy, each check only
    builds the module and returns without executing it.
    """

    def check_padded_load(pad_before, pad_after, test_name=None):
        """Pad a small accumulator-typed tensor on VTA and compare with NumPy.

        Parameters
        ----------
        pad_before, pad_after : list of int
            Per-axis padding handed to ``topi.nn.pad``. Only the first two
            axes are reflected in ``out_shape`` below, so the last two entries
            are presumably expected to be 0 (non-zero values would make the
            declared output shape disagree with the padded tensor).
        test_name : str, optional
            Label printed in the simulator statistics banner. When omitted,
            a label derived from the padding is used instead of "None".
        """
        label = test_name if test_name is not None else f"{pad_before}->{pad_after}"

        # declare
        n = 3
        m = 5
        # Padded output shape, shared by both compute stages and the NumPy
        # reference so the three uses cannot drift apart.
        out_shape = (
            n + pad_before[0] + pad_after[0],
            m + pad_before[1] + pad_after[1],
            env.BATCH,
            env.BLOCK_OUT,
        )
        x = te.placeholder((n, m, env.BATCH, env.BLOCK_OUT), name="x", dtype=env.acc_dtype)
        x_buf = topi.nn.pad(x, pad_before, pad_after, name="y")
        # insert no-op (shift by 0) that won't be optimized away; this stage
        # is tagged with the ALU pragma in the schedule below
        y_buf = te.compute(out_shape, lambda *i: x_buf(*i) >> 0, "y_buf")
        # cast the accumulator values down to the input dtype for the store
        y = te.compute(out_shape, lambda *i: y_buf(*i).astype(env.inp_dtype), "y")

        # schedule
        s = te.create_schedule(y.op)
        s[x_buf].set_scope(env.acc_scope)
        s[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
        s[y_buf].set_scope(env.acc_scope)
        s[y_buf].pragma(y_buf.op.axis[0], env.alu)
        s[y].pragma(y.op.axis[0], env.dma_copy)

        # build
        with vta.build_config():
            mod = vta.build(s, [x, y], tvm.target.Target("ext_dev", host=env.target_host))

        # Without a remote session only the build is exercised.
        if not remote:
            return
        temp = tempdir()
        lib_path = temp.relpath("padded_load.o")
        mod.save(lib_path)
        remote.upload(lib_path)
        f = remote.load_module("padded_load.o")

        # verify
        dev = remote.ext_dev(0)
        x_np = np.random.randint(0, 10, size=(n, m, env.BATCH, env.BLOCK_OUT)).astype(x.dtype)
        # Reference result: zeros except for the interior copy of x_np.
        y_np = np.zeros(out_shape).astype(y.dtype)
        y_np[pad_before[0] : pad_before[0] + n, pad_before[1] : pad_before[1] + m, :] = x_np
        x_nd = tvm.nd.array(x_np, dev)
        y_nd = tvm.nd.empty(y_np.shape, device=dev, dtype=y_np.dtype)

        # Statistics are only collected/printed on the simulator targets.
        is_sim = env.TARGET in ["sim", "tsim"]
        if is_sim:
            simulator.clear_stats()
        f(x_nd, y_nd)
        np.testing.assert_equal(y_np, y_nd.numpy())
        if is_sim:
            sim_stats = simulator.stats()
            print(f"Padded {label} load execution statistics:")
            for k, v in sim_stats.items():
                print("\t{:<16}: {:>16}".format(k, v))

    # Pad only the leading axis (Y0/X0) or only the second axis (Y1/X1),
    # before or after, and finally both axes on both sides.
    check_padded_load([2, 0, 0, 0], [0, 0, 0, 0], test_name="Y0")
    check_padded_load([0, 2, 0, 0], [0, 0, 0, 0], test_name="Y1")
    check_padded_load([0, 0, 0, 0], [2, 0, 0, 0], test_name="X0")
    check_padded_load([0, 0, 0, 0], [0, 2, 0, 0], test_name="X1")
    check_padded_load([1, 1, 0, 0], [1, 1, 0, 0], test_name="all")


vta.testing.run(_run)
Padded Y0 load execution statistics:
inp_load_nbytes : 0
wgt_load_nbytes : 0
acc_load_nbytes : 960
uop_load_nbytes : 4
out_store_nbytes: 400
gemm_counter : 0
alu_counter : 25
Padded Y1 load execution statistics:
inp_load_nbytes : 0
wgt_load_nbytes : 0
acc_load_nbytes : 960
uop_load_nbytes : 4
out_store_nbytes: 336
gemm_counter : 0
alu_counter : 21
Padded X0 load execution statistics:
inp_load_nbytes : 0
wgt_load_nbytes : 0
acc_load_nbytes : 960
uop_load_nbytes : 4
out_store_nbytes: 400
gemm_counter : 0
alu_counter : 25
Padded X1 load execution statistics:
inp_load_nbytes : 0
wgt_load_nbytes : 0
acc_load_nbytes : 960
uop_load_nbytes : 4
out_store_nbytes: 336
gemm_counter : 0
alu_counter : 21
Padded all load execution statistics:
inp_load_nbytes : 0
wgt_load_nbytes : 0
acc_load_nbytes : 960
uop_load_nbytes : 4
out_store_nbytes: 560
gemm_counter : 0
alu_counter : 35
2023-09-25 13:10:39.193 INFO load_module /tmp/tmpxhvhkw8k/padded_load.o
2023-09-25 13:10:39.453 INFO load_module /tmp/tmpxhvhkw8k/padded_load.o
2023-09-25 13:10:39.737 INFO load_module /tmp/tmpxhvhkw8k/padded_load.o
2023-09-25 13:10:39.993 INFO load_module /tmp/tmpxhvhkw8k/padded_load.o
2023-09-25 13:10:40.310 INFO load_module /tmp/tmpxhvhkw8k/padded_load.o