创建 TVM 的 NDArray 的子类
`tvm_ext.ivec_create`
import tvm
from tvm_book.tvm_ext.libinfo import _load_lib
# Load the compiled extension library from the local build directory.
# Keeping a module-level reference (_LIB) ties the library's lifetime to
# this module; _LIB_NAME is the second value returned by _load_lib
# (presumably the resolved library name/path - confirm in libinfo).
_LIB, _LIB_NAME = _load_lib(
    name="libtvm_ext.so",
    search_path=["outputs/libs"],
)

# Bind the "tvm_ext" namespaced global functions into this module's
# namespace, which is what lets the code below call ivec_create, ivec_get,
# nd_create, ... unqualified.
# NOTE(review): search_path is relative, so this depends on the current
# working directory.
tvm._ffi._init_api("tvm_ext", __name__)
要使用此插件，外部库应执行以下操作：

- 继承 TVM 的 NDArray 和 NDArray 容器；
- 遵循新的对象协议，将新的 NDArray 定义为引用类。

在 Python 前端上，继承 `tvm.nd.NDArray`（下面的 `NDSubClass` 示例直接继承 `tvm.nd.NDArrayBase`），并使用 `tvm.register_object` 注册类型。
@tvm.register_object("tvm_ext.IntVector")
class IntVec(tvm.Object):
    """Python-side class for the C++ ``tvm_ext.IntVector`` object.

    Registering this class under the type key ``"tvm_ext.IntVector"`` is
    what allows objects of that type returned through the FFI (for
    example from ``ivec_create``) to be seen as ``IntVec`` instances in
    Python.
    """

    @property
    def _tvm_handle(self):
        """Return ``self.handle.value``.

        NOTE(review): presumably ``handle`` is a ctypes ``c_void_p``, so
        this is the raw address of the underlying C++ object as an int.
        Confirm for the FFI backend in use.
        """
        return self.handle.value

    def __getitem__(self, idx):
        """Return element ``idx`` by delegating to the ``ivec_get`` function.

        ``ivec_get`` is bound into this module by ``_init_api`` (presumably
        the ``tvm_ext.ivec_get`` global function). Bounds and negative-index
        handling are whatever the C++ side implements.
        """
        return ivec_get(self, idx)
# Build an IntVector on the C++ side. Because IntVec is registered for its
# type key, the result is expected to arrive as an IntVec instance.
ivec = ivec_create(1, 2, 3)
assert isinstance(ivec, IntVec)
# Check the first two elements in order, stopping at the first mismatch;
# element access goes through IntVec.__getitem__.
assert all(ivec[pos] == want for pos, want in enumerate((1, 2)))
def ivec_cb(v2):
    """Callback used to check that an IntVec survives an FFI round trip.

    Args:
        v2: The argument received through the TVM function call. It is
            expected to be an ``IntVec`` whose element at index 2 is 3.

    Raises:
        AssertionError: If ``v2`` is not an ``IntVec`` or ``v2[2] != 3``.
    """
    assert isinstance(v2, IntVec)
    assert v2[2] == 3


# Convert the Python callback into a TVM function and invoke it with ``ivec``.
# The argument crosses the FFI boundary and, per the checks in ivec_cb,
# should come back typed as IntVec.
tvm.runtime.convert(ivec_cb)(ivec)
`tvm_ext.NDSubClass`
@tvm.register_object("tvm_ext.NDSubClass")
class NDSubClass(tvm.nd.NDArrayBase):
    """Example for subclassing TVM's NDArray infrastructure.

    By inheriting from TVM's NDArray base class and registering the type
    key ``"tvm_ext.NDSubClass"``, an external library can pass its own
    array type through TVM's FFI without any modification to TVM.
    Each instance exposes an ``additional_info`` value that is stored on
    and read back from the underlying C++ object.
    """

    @staticmethod
    def create(additional_info):
        """Create a new ``NDSubClass`` holding ``additional_info``.

        Delegates to ``nd_create``, which is bound into this module by
        ``_init_api`` (presumably the ``tvm_ext.nd_create`` global function).
        """
        return nd_create(additional_info)

    @property
    def additional_info(self):
        """The metadata value stored in the C++ object, read via ``nd_get_additional_info``."""
        return nd_get_additional_info(self)

    def __add__(self, other):
        """Combine ``self`` and ``other`` via the ``nd_add_two`` function.

        NOTE(review): presumably the C++ side sums the operands'
        ``additional_info`` (e.g. 3 + 5 -> 8). Confirm against the C++
        implementation.
        """
        return nd_add_two(self, other)
# Two independent instances carrying different metadata values.
a = NDSubClass.create(additional_info=3)
b = NDSubClass.create(additional_info=5)
assert isinstance(a, NDSubClass)

# Each sum is dispatched through NDSubClass.__add__ to the C++ side.
c = a + b
d = a + a
e = b + b

# Verify every instance's metadata in order, stopping at the first mismatch.
# The sums are expected to carry the added additional_info of their operands.
assert all(
    obj.additional_info == want
    for obj, want in ((a, 3), (b, 5), (c, 8), (d, 6), (e, 10))
)