# 注册插件 (Registering a TVM extension type)
# Notebook setup: put the repository's tests/ directory on the import path,
# create a scratch directory, and import TVM / NumPy for the cells below.
import sys
from pathlib import Path
# Repository root: the directory four levels above the current working
# directory (parents[0] is the immediate parent). Raises IndexError if cwd is
# not nested that deeply. NOTE(review): depends on where the notebook is run.
ROOT = Path(".").resolve().parents[3]
# print(ROOT)
# Make <ROOT>/tests importable; presumably that is where the `tools` package
# imported below lives. Must run before `import tools`.
sys.path.extend([f"{ROOT}/tests"])
# from tools.tag_span import _create_span, _set_span, _verify_structural_equal_with_span
# NOTE(review): `tools` is not referenced in the visible code; presumably
# imported for side effects or for later cells - verify.
import tools
from d2py.utils.file import mkdir
# Scratch directory, relative to the current working directory.
# NOTE(review): assumes d2py's mkdir creates it if missing - confirm.
root_dir = ".temp"
mkdir(root_dir )
import tvm
from tvm import te
import numpy as np
@tvm.register_extension
class MyTensorView:
    """Wrapper that lets an array be passed to TVM functions as a DLTensor.

    The class is registered via ``tvm.register_extension``. Its two hooks are
    the ``_tvm_tcode`` class attribute (the FFI argument type code) and the
    ``_tvm_handle`` property (the raw handle to pass). Both simply forward to
    the wrapped array.
    """

    # Argument type code: marshal instances as DLTensor handles.
    _tvm_tcode = tvm._ffi.runtime_ctypes.ArgTypeCode.DLTENSOR_HANDLE

    def __init__(self, arr):
        # Wrapped array. In this notebook it is a tvm.nd.array (see usage below).
        self.arr = arr

    # Expose the wrapped array's own handle unchanged.
    _tvm_handle = property(lambda self: self.arr._tvm_handle)
# DLTensor compatibility: passing an extension-type wrapper where a DLTensor argument is expected
# Build a TIR kernel that fills a 1-D buffer by the recurrence
# A[i + 1] = A[i] + 1 and compile it for the "stackvm" target. Then call it
# through a MyTensorView wrapper to check that a registered extension type is
# accepted in place of a DLTensor argument.
dtype = "int64"
n = te.var("n")  # symbolic length of the buffer
Ab = tvm.tir.decl_buffer((n,), dtype)

ib = tvm.tir.ir_builder.create()
A = ib.buffer_ptr(Ab)
# for_range creates and yields its own loop variable `i`; no separate
# te.var("i") is needed (it would just be shadowed here).
with ib.for_range(0, n - 1, name="i") as i:
    A[i + 1] = A[i] + 1
stmt = ib.get()

mod = tvm.IRModule.from_expr(
    tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "arange")
)
f = tvm.build(mod, target="stackvm")

# The kernel never writes A[0], so the input must start at zero. With
# np.zeros the recurrence turns the buffer into 0, 1, ..., 9 in place.
a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a)
f(aview)  # writes into `a` through the wrapper's handle
np.testing.assert_equal(a.numpy(), np.arange(a.shape[0]))