def _run(env, remote):
    """Build, upload and run a copy-in / no-op / copy-out kernel, then check it.

    The kernel copies ``x`` into ``x_buf``, applies ``>> 0`` to produce
    ``y_buf``, and casts the result to ``env.inp_dtype`` as ``y``. The
    compiled module is saved as ``load_act.o``, uploaded through ``remote``
    and executed on ``remote.ext_dev(0)``. The output must equal the random
    input cast to the output dtype.

    Parameters
    ----------
    env
        Supplies the dtypes, scope, pragma names, block sizes, target host
        and ``TARGET`` read below.
    remote
        Used to upload and load the module and to obtain the device.

    Raises
    ------
    AssertionError
        If ``env.TARGET`` is not ``"sim"`` or ``"tsim"``, or if the output
        differs from the expected array.
    """
    n = 6
    shape = (n, n, env.BATCH, env.BLOCK_OUT)

    # Compute declaration.
    x = te.placeholder(shape, name="x", dtype=env.acc_dtype)
    x_buf = te.compute(shape, lambda *idx: x(*idx), name="x_buf")
    # Insert a no-op that will not be optimized away.
    y_buf = te.compute(shape, lambda *idx: x_buf(*idx) >> 0, name="y_buf")
    y = te.compute(shape, lambda *idx: y_buf(*idx).astype(env.inp_dtype), name="y")

    # Schedule.
    sched = te.create_schedule(y.op)
    sched[x_buf].set_scope(env.acc_scope)
    sched[x_buf].pragma(x_buf.op.axis[0], env.dma_copy)
    sched[y_buf].set_scope(env.acc_scope)
    sched[y_buf].pragma(y_buf.op.axis[0], env.alu)
    sched[y].pragma(y.op.axis[0], env.dma_copy)

    # Build the library.
    target = tvm.target.Target("ext_dev", host=env.target_host)
    with vta.build_config():
        module = vta.build(sched, [x, y], target)

    # Save the object file locally, then ship it to the remote and load it.
    obj_name = "load_act.o"
    workdir = tempdir()
    obj_path = workdir.relpath(obj_name)
    module.save(obj_path)
    remote.upload(obj_path)
    kernel = remote.load_module(obj_name)

    # Verify.
    dev = remote.ext_dev(0)
    x_np = np.random.randint(1, 10, size=shape).astype(x.dtype)
    y_np = x_np.astype(y.dtype)  # expected output: the input cast to y's dtype
    x_nd = tvm.nd.array(x_np, dev)
    y_nd = tvm.nd.empty(y_np.shape, device=dev, dtype=y_np.dtype)

    assert env.TARGET in ("sim", "tsim")
    simulator.clear_stats()
    kernel(x_nd, y_nd)
    np.testing.assert_equal(y_np, y_nd.numpy())

    # Report the statistics gathered since clear_stats().
    row = "\t{:<16}: {:>16}"
    print("Save load execution statistics:")
    for key, value in simulator.stats().items():
        print(row.format(key, value))
# Hand the test body to the VTA testing runner.
# NOTE(review): presumably the runner supplies `env` and `remote` — confirm against vta.testing.
vta.testing.run(_run)