def _run(env, remote):
    """Build a clip-to-``[0, max]`` kernel and, when a remote is given, run and check it.

    The kernel copies an ``env.acc_dtype`` tensor of shape
    ``(8, 10, env.BATCH, env.BLOCK_OUT)`` into ``env.acc_scope``, applies
    ``max(x, 0)`` followed by ``min(x, (1 << (env.INP_WIDTH - 1)) - 1)``,
    casts to ``env.inp_dtype`` and writes the result back out.

    Parameters
    ----------
    env
        Environment object (presumably the VTA environment -- confirm with the
        caller). This function reads its ``BATCH``, ``BLOCK_OUT``,
        ``INP_WIDTH``, ``acc_dtype``, ``inp_dtype``, ``acc_scope``,
        ``dma_copy``, ``alu``, ``target_host`` and ``TARGET`` attributes.
    remote
        Session used to upload, load and execute the built module. When it is
        falsy the function returns right after the build step.

    Raises
    ------
    AssertionError
        Raised by ``np.testing.assert_equal`` when the device output differs
        from the ``np.clip`` reference.
    """
    m = 8
    n = 10
    shape = (m, n, env.BATCH, env.BLOCK_OUT)
    # Upper clip bound: the largest value of a signed INP_WIDTH-bit integer.
    clip_max = (1 << (env.INP_WIDTH - 1)) - 1

    # compute
    # Each op is named after the variable that holds it so that IR dumps and
    # debug output point at the right stage.
    a = te.placeholder(shape, name="a", dtype=env.acc_dtype)
    a_buf = te.compute(shape, lambda *i: a(*i), name="a_buf")  # DRAM->SRAM
    max_buf = te.compute(
        shape, lambda *i: tvm.te.max(a_buf(*i), 0), name="max_buf"
    )  # relu (lower clip at 0)
    min_buf = te.compute(
        shape, lambda *i: tvm.te.min(max_buf(*i), clip_max), name="min_buf"
    )  # upper clip
    res = te.compute(
        shape, lambda *i: min_buf(*i).astype(env.inp_dtype), name="res"
    )  # SRAM->DRAM

    # schedule
    s = te.create_schedule(res.op)
    s[a_buf].set_scope(env.acc_scope)  # SRAM
    s[a_buf].pragma(a_buf.op.axis[0], env.dma_copy)  # DRAM->SRAM
    s[max_buf].set_scope(env.acc_scope)  # SRAM
    s[min_buf].set_scope(env.acc_scope)  # SRAM
    s[max_buf].pragma(max_buf.op.axis[0], env.alu)  # compute
    s[min_buf].pragma(min_buf.op.axis[0], env.alu)  # compute
    s[res].pragma(res.op.axis[0], env.dma_copy)  # SRAM->DRAM

    # build
    with vta.build_config():
        mod = vta.build(s, [a, res], tvm.target.Target("ext_dev", host=env.target_host))

    # Without a remote there is nothing to execute on; building is the whole test.
    if not remote:
        return

    # Save the module locally, upload it and load it on the remote side.
    obj_name = "load_act.o"
    temp = tempdir()
    mod.save(temp.relpath(obj_name))
    remote.upload(temp.relpath(obj_name))
    f = remote.load_module(obj_name)

    # verify
    dev = remote.ext_dev(0)
    a_np = np.random.randint(-256, 256, size=shape).astype(a.dtype)
    res_np = np.clip(a_np, 0, clip_max).astype(res.dtype)
    a_nd = tvm.nd.array(a_np, dev)
    res_nd = tvm.nd.array(np.zeros(shape).astype(res.dtype), dev)

    # Simulator statistics are only reset and reported for the simulator targets.
    is_simulated = env.TARGET in ["sim", "tsim"]
    if is_simulated:
        simulator.clear_stats()

    f(a_nd, res_nd)
    np.testing.assert_equal(res_np, res_nd.numpy())

    if is_simulated:
        sim_stats = simulator.stats()
        print("Relu execution statistics:")
        for k, v in sim_stats.items():
            print("\t{:<16}: {:>16}".format(k, v))
# Hand _run to vta.testing.run; the harness is expected to call it with the
# (env, remote) arguments -- NOTE(review): confirm against vta.testing.run.
vta.testing.run(_run)