# VTA 运行时数组
import tvm
import numpy as np
import vta.testing
# Pin NumPy's legacy global RNG to a fixed seed so every np.random.* draw
# made later in this script (e.g. the test tensor in _run) is reproducible.
np.random.seed(seed=0xDEADB)
def _run(env, remote, *, n=100):
    """Round-trip a random int8 array through ``remote.ext_dev(0)`` and verify it.

    Builds a random int8 NumPy array of shape
    ``(n, n, env.BATCH, env.BLOCK_OUT)`` with values drawn from ``[1, 10)``.
    It copies the array into a TVM NDArray on ``remote.ext_dev(0)`` and prints
    the device that NDArray reports. It then asserts that reading the data back
    with ``.numpy()`` reproduces the original values exactly.

    Args:
        env: object exposing ``BATCH`` and ``BLOCK_OUT``. Presumably the VTA
            environment supplied by ``vta.testing.run`` (TODO confirm).
        remote: object providing ``ext_dev(0)``. Presumably an RPC session
            supplied by ``vta.testing.run`` (TODO confirm).
        n: size of the two leading array dimensions. Keyword-only; defaults to
            100, the value the original test hard-coded.

    Raises:
        AssertionError: if the data read back from the device differs from
            the host array.
    """
    dev = remote.ext_dev(0)
    x_np = np.random.randint(1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype("int8")
    # Allocate an NDArray on the target device initialised from the host data.
    x_nd = tvm.nd.array(x_np, dev)
    print(x_nd.device)
    # Copy back to host and require exact equality (integer data, no tolerance).
    np.testing.assert_equal(x_np, x_nd.numpy())
# Hand _run to the VTA test harness. vta.testing.run presumably constructs the
# env/remote pair and calls _run(env, remote) on the configured target — confirm
# against vta.testing.
vta.testing.run(_run)
remote[0]:ext_dev(0)