创建 TVM 的 NDArray 的子类

创建 TVM 的 NDArray 的子类

C++ 源码:
/*!
 * \brief Example package that uses TVM.
 * \file tvm_ext.cc
 */
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>

using namespace tvm;
using namespace tvm::tir;
using namespace tvm::runtime;

namespace tvm_ext {
/*!
 * \brief A subclass of TVM's NDArray.
 *
 * To use this extension, an external library should
 *
 * 1) Inherit TVM's NDArray and NDArray container,
 *
 * 2) Follow the new object protocol to define new NDArray as a reference class.
 *
 * 3) On Python frontend, inherit `tvm.nd.NDArray`,
 *    register the type using tvm.register_object
 */
class NDSubClass : public tvm::runtime::NDArray {
 public:
  /*!
   * \brief Container type behind NDSubClass: an NDArray::Container that
   *        additionally stores one integer.
   */
  class SubContainer : public NDArray::Container {
   public:
    /*! \brief Extra integer payload stored next to the inherited container state. */
    int additional_info_{0};

    /*!
     * \brief Construct a container carrying the given payload.
     * \param additional_info Value copied into additional_info_.
     */
    SubContainer(int additional_info) : additional_info_(additional_info) {
      // Tag the object with the runtime type index of SubContainer.
      type_index_ = SubContainer::RuntimeTypeIndex();
    }

    static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
    static constexpr const char* _type_key = "tvm_ext.NDSubClass";
    TVM_DECLARE_FINAL_OBJECT_INFO(SubContainer, NDArray::Container);
  };

  /*!
   * \brief Deleter installed on every SubContainer created by NDSubClass(int).
   * \param obj Object to destroy; it is cast to SubContainer* and deleted.
   */
  static void SubContainerDeleter(Object* obj) { delete static_cast<SubContainer*>(obj); }

  /*! \brief Default constructor; performs no initialization of its own. */
  NDSubClass() {}

  /*!
   * \brief Wrap an existing object pointer.
   * \param n Object pointer forwarded to the NDArray constructor.
   */
  explicit NDSubClass(ObjectPtr<Object> n) : NDArray(n) {}

  /*!
   * \brief Allocate a fresh SubContainer and make this reference point at it.
   * \param additional_info Payload stored in the new container.
   */
  explicit NDSubClass(int additional_info) {
    auto* container = new SubContainer(additional_info);
    // The deleter is set before the raw pointer is handed over to data_.
    container->SetDeleter(SubContainerDeleter);
    data_ = GetObjectPtr<Object>(container);
  }

  /*!
   * \brief Build a new NDSubClass whose payload is the sum of both operands' payloads.
   * \param other Right-hand operand.
   * \return A newly allocated NDSubClass holding the summed payload.
   */
  NDSubClass AddWith(const NDSubClass& other) const {
    auto* lhs = static_cast<SubContainer*>(get_mutable());
    auto* rhs = static_cast<SubContainer*>(other.get_mutable());
    // Neither operand may be an empty reference.
    ICHECK(lhs != nullptr && rhs != nullptr);
    const int sum = lhs->additional_info_ + rhs->additional_info_;
    return NDSubClass(sum);
  }

  /*!
   * \brief Read the payload stored in this reference's container.
   * \return The value of additional_info_.
   */
  int get_additional_info() const {
    auto* container = static_cast<SubContainer*>(get_mutable());
    // An empty reference has no payload to read.
    ICHECK(container != nullptr);
    return container->additional_info_;
  }

  using ContainerType = SubContainer;
};

// Register the SubContainer type (type key "tvm_ext.NDSubClass") through TVM's registration macro.
TVM_REGISTER_OBJECT_TYPE(NDSubClass::SubContainer);

/*!
 * \brief Introduce additional extension data structures
 *        by sub-classing TVM's object system.
 */
class IntVectorObj : public Object {
 public:
  /*! \brief The integers held by this object. */
  std::vector<int> vec;

  /*! \brief Type key; the Python example registers its IntVec class under this same string. */
  static constexpr const char* _type_key = "tvm_ext.IntVector";
  TVM_DECLARE_FINAL_OBJECT_INFO(IntVectorObj, Object);
};

/*!
 * \brief Int vector reference class.
 */
class IntVector : public ObjectRef {
 public:
  // The reference-class members for IntVectorObj are generated by this TVM macro.
  TVM_DEFINE_OBJECT_REF_METHODS(IntVector, ObjectRef, IntVectorObj);
};

// Register IntVectorObj (type key "tvm_ext.IntVector") through TVM's registration macro.
TVM_REGISTER_OBJECT_TYPE(IntVectorObj);

}  // namespace tvm_ext


namespace tvm_ext {

/*!
 * \brief tvm_ext.ivec_create(v0, v1, ...): build an IntVector from the arguments.
 *
 * Every argument is converted to int and appended in call order; the
 * resulting IntVector is stored in the return value.
 */
TVM_REGISTER_GLOBAL("tvm_ext.ivec_create").set_body([](TVMArgs args, TVMRetValue* rv) {
  auto n = tvm::runtime::make_object<IntVectorObj>();
  // The final length is known up front, so allocate once instead of growing in the loop.
  n->vec.reserve(static_cast<size_t>(args.size()));
  for (int i = 0; i < args.size(); ++i) {
    n->vec.push_back(args[i].operator int());
  }
  *rv = IntVector(n);
});

/*!
 * \brief tvm_ext.ivec_get(vec, index): return element `index` of an IntVector.
 *
 * args[0] is converted to IntVector and args[1] to int. Both are validated
 * before the lookup: a null IntVector, a negative index, or an index past the
 * end fails an ICHECK instead of reading outside the underlying std::vector.
 */
TVM_REGISTER_GLOBAL("tvm_ext.ivec_get").set_body([](TVMArgs args, TVMRetValue* rv) {
  IntVector p = args[0];
  ICHECK(p.defined()) << "tvm_ext.ivec_get: IntVector argument must not be null";
  const int idx = args[1].operator int();
  const int size = static_cast<int>(p->vec.size());
  // The index comes from the caller (e.g. Python __getitem__), so it cannot be trusted.
  ICHECK(idx >= 0 && idx < size)
      << "tvm_ext.ivec_get: index " << idx << " is out of range for an IntVector of size " << size;
  *rv = p->vec[idx];
});

/*!
 * \brief tvm_ext.nd_create(additional_info): return a new NDSubClass.
 *
 * args[0] is read as an int and becomes the payload of the new object.
 */
TVM_REGISTER_GLOBAL("tvm_ext.nd_create").set_body([](TVMArgs args, TVMRetValue* rv) {
  const int info = args[0];
  *rv = NDSubClass(info);
  // The subclass must still travel through the FFI with the NDArray type code.
  ICHECK_EQ(rv->type_code(), kTVMNDArrayHandle);
});

/*!
 * \brief tvm_ext.nd_add_two(a, b): return a.AddWith(b).
 *
 * Both arguments are converted to NDSubClass before the call.
 */
TVM_REGISTER_GLOBAL("tvm_ext.nd_add_two").set_body([](TVMArgs args, TVMRetValue* rv) {
  const NDSubClass lhs = args[0];
  const NDSubClass rhs = args[1];
  *rv = lhs.AddWith(rhs);
});

/*!
 * \brief tvm_ext.nd_get_additional_info(a): return the payload of an NDSubClass.
 *
 * args[0] is converted to NDSubClass; its get_additional_info() result is returned.
 */
TVM_REGISTER_GLOBAL("tvm_ext.nd_get_additional_info").set_body([](TVMArgs args, TVMRetValue* rv) {
  const NDSubClass target = args[0];
  const int info = target.get_additional_info();
  *rv = info;
});
}  // namespace tvm_ext

编译:

!make outputs/libs/libtvm_NDSubClass.so
g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\<tvm/runtime/logging.h\> -shared -o outputs/libs/libtvm_NDSubClass.so src/testing/NDSubClass.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build

tvm_ext.ivec_create

import tvm
from tvm_book.tvm_ext.libinfo import load_lib

# Load the compiled extension library from outputs/libs; both returned values
# are kept in module-level names.
_LIB, _LIB_NAME = load_lib(name="libtvm_NDSubClass.so", search_path=["outputs/libs"])
# NOTE(review): presumably binds the "tvm_ext.*" global functions into this
# module's namespace -- confirm against tvm._ffi._init_api.
tvm._ffi._init_api("tvm_ext", __name__)

# Look up the global functions that the C++ source registers under these names.
ivec_create = tvm.get_global_func("tvm_ext.ivec_create")
ivec_get = tvm.get_global_func("tvm_ext.ivec_get")

要使用此扩展(NDArray 子类),外部库应执行以下操作:

  1. 继承 TVM 的 NDArray 和 NDArray 容器;

  2. 遵循新的对象协议,将新的 NDArray 定义为引用类;

  3. 在 Python 前端上,继承 tvm.nd.NDArray,并使用 tvm.register_object 注册类型。

@tvm.register_object("tvm_ext.IntVector")
class IntVec(tvm.Object):
    """Python-side class for the C++ ``tvm_ext.IntVector`` object.

    It is registered under the type key ``"tvm_ext.IntVector"``, the same
    string the C++ ``IntVectorObj`` declares as its ``_type_key``.
    """

    @property
    def _tvm_handle(self):
        """Return ``self.handle.value``."""
        return self.handle.value

    def __getitem__(self, idx):
        """Return element ``idx`` by calling the ``tvm_ext.ivec_get`` global function."""
        return ivec_get(self, idx)
# Build an IntVector in C++ and check that it comes back as the registered
# Python class with the expected elements.
ivec = ivec_create(1, 2, 3)
assert isinstance(ivec, IntVec)
assert ivec[0] == 1
assert ivec[1] == 2

def ivec_cb(v2):
    """Callback that asserts its argument is an IntVec whose element 2 equals 3."""
    assert isinstance(v2, IntVec)
    assert v2[2] == 3

# Wrap the Python callback with tvm.runtime.convert and call it with ivec; the
# assertions inside ivec_cb check the object on the receiving side.
tvm.runtime.convert(ivec_cb)(ivec)

tvm_ext.NDSubClass

# Look up the NDSubClass-related global functions registered by the C++ source.
nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info")
@tvm.register_object("tvm_ext.NDSubClass")
class NDSubClass(tvm.nd.NDArrayBase):
    """Example for subclassing TVM's NDArray infrastructure.

    By inheriting TVM's NDArray, external libraries could
    leverage TVM's FFI without any modification.

    The class is registered under the type key ``"tvm_ext.NDSubClass"``, the
    same string the C++ ``NDSubClass::SubContainer`` declares as its
    ``_type_key``.
    """

    @staticmethod
    def create(additional_info):
        """Return the result of ``tvm_ext.nd_create(additional_info)``."""
        return nd_create(additional_info)

    @property
    def additional_info(self):
        """Return the result of ``tvm_ext.nd_get_additional_info(self)``."""
        return nd_get_additional_info(self)

    def __add__(self, other):
        """Return the result of ``tvm_ext.nd_add_two(self, other)``."""
        return nd_add_two(self, other)
# Create two objects through the C++ factory and check that they arrive as
# instances of the registered Python class.
a = NDSubClass.create(additional_info=3)
b = NDSubClass.create(additional_info=5)
assert isinstance(a, NDSubClass)
# Addition goes through tvm_ext.nd_add_two.
c = a + b
d = a + a
e = b + b
# The operands keep their values; each result carries the sum of its operands' values.
assert a.additional_info == 3
assert b.additional_info == 5
assert c.additional_info == 8
assert d.additional_info == 6
assert e.additional_info == 10