def _run(env, remote):
    """Declare a VTA GEMM + shift/ReLU/clamp pipeline and verify two schedules of it.

    The declared compute stages ``x``/``w`` into ``x_buf``/``w_buf``,
    accumulates a blocked matrix product in ``env.acc_dtype`` (``y_gem``),
    right-shifts it by ``shift`` (``y_shf``), clamps below at 0 (``y_max``)
    and above at ``clip_max`` (``y_min``), then casts to ``env.inp_dtype``
    (``y``).

    Parameters
    ----------
    env :
        VTA environment. It is read for block sizes, dtypes, bit widths, SRAM
        scopes, the DMA/ALU pragmas, the ``gemm`` intrinsic, ``target_host``
        and ``TARGET``.
    remote :
        Session used to upload, load and run the built module. If falsy, the
        function returns right after declaring the compute, and nothing is
        built or run.
    """
    # declare
    o = 4
    n = 1
    m = 4
    # Right-shift applied to the accumulator, shared by the TE compute and
    # the NumPy reference so the two cannot drift apart.
    shift = 8
    # Largest signed INP_WIDTH-bit value; the output is clamped to it before
    # the narrowing cast to env.inp_dtype.
    clip_max = (1 << (env.INP_WIDTH - 1)) - 1

    x = te.placeholder((o, n, env.BATCH, env.BLOCK_IN), name="x", dtype=env.inp_dtype)
    w = te.placeholder((m, n, env.BLOCK_OUT, env.BLOCK_IN), name="w", dtype=env.wgt_dtype)
    x_buf = te.compute((o, n, env.BATCH, env.BLOCK_IN), lambda *i: x(*i), "x_buf")
    w_buf = te.compute((m, n, env.BLOCK_OUT, env.BLOCK_IN), lambda *i: w(*i), "w_buf")
    ko = te.reduce_axis((0, n), name="ko")
    ki = te.reduce_axis((0, env.BLOCK_IN), name="ki")
    y_gem = te.compute(
        (o, m, env.BATCH, env.BLOCK_OUT),
        lambda bo, co, bi, ci: te.sum(
            x_buf[bo, ko, bi, ki].astype(env.acc_dtype)
            * w_buf[co, ko, ci, ki].astype(env.acc_dtype),
            axis=[ko, ki],
        ),
        name="y_gem",
    )
    y_shf = te.compute(
        (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: y_gem(*i) >> shift, name="y_shf"
    )
    y_max = te.compute(
        (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: tvm.te.max(y_shf(*i), 0), "y_max"
    )  # relu
    y_min = te.compute(
        (o, m, env.BATCH, env.BLOCK_OUT),
        lambda *i: tvm.te.min(y_max(*i), clip_max),
        "y_min",
    )  # clamp to the largest signed INP_WIDTH-bit value (not a relu)
    y = te.compute(
        (o, m, env.BATCH, env.BLOCK_OUT), lambda *i: y_min(*i).astype(env.inp_dtype), name="y"
    )

    if not remote:
        return

    def signed_range(width):
        """Return ``(lo, hi)`` so that ``randint(lo, hi)`` covers a signed ``width``-bit int."""
        half = 1 << (width - 1)
        return -half, half

    def reference(x_np, w_np):
        """Compute the expected value of ``y`` for host inputs ``x_np``/``w_np`` with NumPy."""
        acc = np.zeros((o, m, env.BATCH, env.BLOCK_OUT), dtype=env.acc_dtype)
        for b in range(o):
            for i in range(m):
                for j in range(n):
                    acc[b, i, :] += np.dot(
                        x_np[b, j, :].astype(env.acc_dtype), w_np[i, j].T.astype(env.acc_dtype)
                    )
        acc = np.right_shift(acc, shift)
        return np.clip(acc, 0, clip_max).astype(y.dtype)

    def verify(s, name=None):
        """Build schedule ``s``, run it on ``remote`` and compare the result with ``reference``.

        ``name`` only labels the simulator statistics printed on sim/tsim targets.
        """
        # Build with the CSE pass disabled as otherwise it would complicate the test
        with vta.build_config(disabled_pass={"tir.CommonSubexprElimTIR"}):
            mod = vta.build(s, [x, w, y], tvm.target.Target("ext_dev", host=env.target_host))
        temp = tempdir()
        mod.save(temp.relpath("gemm.o"))
        remote.upload(temp.relpath("gemm.o"))
        f = remote.load_module("gemm.o")
        # verify
        dev = remote.ext_dev(0)
        # Draw inputs over the full range of their configured bit widths
        # (this is exactly [-128, 128) for 8-bit inputs/weights).
        x_np = np.random.randint(
            *signed_range(env.INP_WIDTH), size=(o, n, env.BATCH, env.BLOCK_IN)
        ).astype(x.dtype)
        w_np = np.random.randint(
            *signed_range(env.WGT_WIDTH), size=(m, n, env.BLOCK_OUT, env.BLOCK_IN)
        ).astype(w.dtype)
        y_np = np.zeros((o, m, env.BATCH, env.BLOCK_OUT)).astype(y.dtype)
        x_nd = tvm.nd.array(x_np, dev)
        w_nd = tvm.nd.array(w_np, dev)
        y_nd = tvm.nd.array(y_np, dev)
        expected = reference(x_np, w_np)

        if env.TARGET in ["sim", "tsim"]:
            simulator.clear_stats()
        f(x_nd, w_nd, y_nd)
        np.testing.assert_equal(expected, y_nd.numpy())
        if env.TARGET in ["sim", "tsim"]:
            sim_stats = simulator.stats()
            print("GEMM schedule:{} execution statistics:".format(name))
            for k, v in sim_stats.items():
                print("\t{:<16}: {:>16}".format(k, v))

    def set_sram_scopes(s):
        """Place the staged inputs and every intermediate stage in their VTA scopes."""
        s[x_buf].set_scope(env.inp_scope)
        s[w_buf].set_scope(env.wgt_scope)
        for stage in (y_gem, y_shf, y_max, y_min):
            s[stage].set_scope(env.acc_scope)

    def load_inputs_per_ko(s):
        """Compute x_buf/w_buf at y_gem's ``ko`` loop and tag them with the DMA-copy pragma."""
        for buf in (x_buf, w_buf):
            s[buf].compute_at(s[y_gem], ko)
            s[buf].pragma(s[buf].op.axis[0], env.dma_copy)

    def mark_alu_stages(s):
        """Tag the shift / max / min stages with the ``env.alu`` pragma."""
        for stage in (y_shf, y_max, y_min):
            s[stage].pragma(s[stage].op.axis[0], env.alu)

    def tensorize_gemm(s):
        """Move ``ko`` outermost and ``ki`` innermost, then map the inner tile onto ``env.gemm``."""
        axes = s[y_gem].op.axis
        s[y_gem].reorder(ko, *axes, ki)
        s[y_gem].tensorize(axes[2], env.gemm)

    def test_schedule1():
        """Default schedule: no split of ``y`` and no ``cthread`` binding."""
        s = te.create_schedule(y.op)
        set_sram_scopes(s)
        load_inputs_per_ko(s)
        mark_alu_stages(s)
        s[y].pragma(s[y].op.axis[0], env.dma_copy)
        tensorize_gemm(s)
        verify(s, name="default")

    def test_smt():
        """SMT schedule: split y's outermost axis in two and bind the outer half to ``cthread``."""
        s = te.create_schedule(y.op)
        set_sram_scopes(s)
        abo = s[y].op.axis[0]
        abo1, abo2 = s[y].split(abo, nparts=2)
        s[y].bind(abo1, te.thread_axis("cthread"))
        # Compute every intermediate stage inside each virtual thread.
        for stage in (y_gem, y_shf, y_max, y_min):
            s[stage].compute_at(s[y], abo1)
        tensorize_gemm(s)
        mark_alu_stages(s)
        load_inputs_per_ko(s)
        s[y].pragma(abo2, env.dma_copy)
        verify(s, name="smt")

    test_schedule1()
    test_smt()
# Hand ``_run`` to VTA's test harness; presumably the harness supplies the
# ``env`` and ``remote`` arguments (verify against vta.testing.run).
vta.testing.run(_run)