# 创建 TVM 的 NDArray 的子类

::::{dropdown} C++ 源码：
```{literalinclude} src/testing/NDSubClass.cc
:language: C++
```
::::

编译：

In [1]:
!make outputs/libs/libtvm_NDSubClass.so

g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\> -shared -o outputs/libs/libtvm_NDSubClass.so src/testing/NDSubClass.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build


## `tvm_ext.ivec_create`

In [2]:
import tvm
from tvm_book.tvm_ext.libinfo import load_lib

# Load the compiled extension from outputs/libs. Presumably its static
# initialisers register the ``tvm_ext.*`` packed functions, so this has to
# run before any of the lookups below.
_LIB, _LIB_NAME = load_lib(
    name="libtvm_NDSubClass.so",
    search_path=["outputs/libs"],
)
# Bind the registered ``tvm_ext.*`` global functions into this module's
# namespace (``tvm._ffi._init_api`` is a private TVM helper).
tvm._ffi._init_api("tvm_ext", __name__)

# Explicit handles to the IntVector packed functions used below.
ivec_create, ivec_get = (
    tvm.get_global_func("tvm_ext.ivec_create"),
    tvm.get_global_func("tvm_ext.ivec_get"),
)

外部库若要以子类方式扩展 TVM 的 NDArray，应执行以下操作：

1. 在 C++ 中继承 TVM 的 `NDArray` 及其容器类；
2. 遵循新的对象协议，将新的 NDArray 定义为引用类；
3. 在 Python 前端继承 `tvm.nd.NDArrayBase`（`tvm.nd.NDArray` 的基类，下文的 `NDSubClass` 即如此），并使用 `tvm.register_object` 注册该类型。

In [3]:
@tvm.register_object("tvm_ext.IntVector")
class IntVec(tvm.Object):
    """Python-side wrapper for the C++ ``tvm_ext.IntVector`` object.

    Registering the class under that type key lets the FFI hand back
    instances of this class for IntVector objects.
    """

    def __getitem__(self, idx):
        # Element access is delegated to the ``tvm_ext.ivec_get`` packed function.
        element = ivec_get(self, idx)
        return element

    @property
    def _tvm_handle(self):
        # Integer value of the underlying ``self.handle``.
        raw_handle = self.handle
        return raw_handle.value

In [4]:
ivec = ivec_create(1, 2, 3)
# The FFI must return the registered Python wrapper, not a generic Object.
assert isinstance(ivec, IntVec)
# Short-circuit keeps the original order: ivec[1] is only read if ivec[0] matched.
assert ivec[0] == 1 and ivec[1] == 2

def ivec_cb(v2):
    """Check that the argument arrives as an ``IntVec`` whose third element is 3."""
    # ``and`` short-circuits, so indexing only happens once the type check passed.
    assert isinstance(v2, IntVec) and v2[2] == 3

# Convert the Python callback into a TVM function and invoke it with ``ivec``,
# so the IntVec reaches ``ivec_cb`` only after passing through TVM's FFI.
tvm.runtime.convert(ivec_cb)(ivec)

## `tvm_ext.NDSubClass`

In [5]:
# Handles to the packed functions backing the NDSubClass wrapper below.
nd_create, nd_add_two, nd_get_additional_info = (
    tvm.get_global_func("tvm_ext.nd_create"),
    tvm.get_global_func("tvm_ext.nd_add_two"),
    tvm.get_global_func("tvm_ext.nd_get_additional_info"),
)

In [6]:
@tvm.register_object("tvm_ext.NDSubClass")
class NDSubClass(tvm.nd.NDArrayBase):
    """Python-side view of the C++ ``tvm_ext.NDSubClass`` NDArray subclass.

    Because it derives from TVM's NDArray base class, instances can pass
    through TVM's FFI without any extra glue. The extra behaviour is
    provided by packed functions exported from the extension library.
    """

    @property
    def additional_info(self):
        """Value reported by ``tvm_ext.nd_get_additional_info`` for this array."""
        info = nd_get_additional_info(self)
        return info

    def __add__(self, other):
        # ``+`` is routed to the ``tvm_ext.nd_add_two`` packed function.
        total = nd_add_two(self, other)
        return total

    @staticmethod
    def create(additional_info):
        """Build a new instance through the ``tvm_ext.nd_create`` packed function."""
        instance = nd_create(additional_info)
        return instance


In [7]:
a = NDSubClass.create(additional_info=3)
b = NDSubClass.create(additional_info=5)
assert isinstance(a, NDSubClass)
c = a + b
d = a + a
e = b + b
# Arrays produced on the C++ side by ``nd_add_two`` must also come back
# through the FFI wrapped as the registered Python subclass, not only the
# freshly created ones. (The generator keeps ``x`` out of the notebook globals.)
assert all(isinstance(x, NDSubClass) for x in (b, c, d, e))
assert a.additional_info == 3
assert b.additional_info == 5
# Adding two arrays sums their additional_info values.
assert c.additional_info == 8
assert d.additional_info == 6
assert e.additional_info == 10