register_object
#
tvm.register_object(type_key=None)
实现的关键接口是 _LIB.TVMObjectTypeKey2Index
,函数的作用是根据给定的 key
获取对应的类型索引。(根据注释的说明,具体的实现细节可能在其他地方进行定义。如果你需要使用这个函数,可以在代码中包含该函数的声明,并在需要的地方调用它来获取类型索引。)拿到索引后,调用 _register_object(tindex, cls)
在 Python 端完成注册。
_LIB.TVMObjectTypeKey2Index
的实现如下查找链路:
// src/runtime/object.cc
int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
API_BEGIN();
out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key);
API_END();
}
-> tvm::runtime::ObjectInternal::ObjectTypeKey2Index
定义如下:
static uint32_t ObjectTypeKey2Index(const std::string& type_key) {
return Object::TypeKey2Index(type_key);
}
-> Object::TypeKey2Index
定义如下:
uint32_t Object::TypeKey2Index(const std::string& key) {
return TypeContext::Global()->TypeKey2Index(key);
}
->
uint32_t TypeKey2Index(const std::string& skey) {
auto it = type_key2index_.find(skey);
ICHECK(it != type_key2index_.end())
<< "Cannot find type " << skey
<< ". Did you forget to register the node by TVM_REGISTER_NODE_TYPE ?";
return it->second;
}
->
std::unordered_map<std::string, uint32_t> type_key2index_;
TVMObjectTypeKey2Index
接受两个参数:一个指向字符类型的指针 type_key
和一个指向无符号整数类型的指针 out_tindex
。函数的返回类型是 int
。
函数的参数解释如下:
const char* type_key
:表示类型键的字符串指针。unsigned* out_tindex
:指向无符号整数的指针,用于存储转换后的类型索引。
函数的返回值解释如下:
当成功时,返回
0
。当失败时,返回非零值。
如果你需要使用这个函数,可以在代码中包含该函数的声明,并在需要的地方调用它来将类型键转换为类型索引。
在上述查找过程发现:TVM_REGISTER_NODE_TYPE
宏,用于注册 key2index 的绑定。
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>) \
.set_creator([](const std::string&) -> ObjectPtr<Object> { \
return ::tvm::runtime::make_object<TypeName>(); \
})
TVM_REGISTER_NODE_TYPE
宏用于在 C++ 中注册节点类型。
首先,TVM_REGISTER_NODE_TYPE(TypeName)
宏定义了函数调用,该函数调用了 TVM_REGISTER_OBJECT_TYPE(TypeName)
和 TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>)
两个函数。
TVM_REGISTER_OBJECT_TYPE(TypeName)
函数用于注册对象类型,将给定的类型名称TypeName
与相应的对象类型关联起来。TVM_REGISTER_REFLECTION_VTABLE(TypeName, ::tvm::detail::ReflectionTrait<TypeName>)
函数用于注册反射虚函数表(vtable),将给定的类型名称TypeName
与相应的反射虚函数表关联起来。这个虚函数表中包含了该类型的反射方法。接下来,
.set_creator([](const std::string&) -> ObjectPtr<Object> {...})
是可选的设置函数,用于指定如何创建该类型的对象实例。在这个例子中,使用了 lambda 表达式作为创建函数,它接受字符串参数,并返回新创建的TypeName
类型的对象实例。
综上所述,这段代码的作用是注册节点类型,并提供了创建该类型对象实例的方法。
register_object
示例#
在 src/tvm_ext.cc
中定义 test.BaseObj
:
#include <string.h>
#include <tvm/runtime/object.h>
#include <tvm/node/reflection.h>
namespace tvm {
namespace runtime {
class TestNode :public Object {
public:
// 对象字段
std::string name;
// 对象属性
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "app.TestNode";
// 告诉 TVM 编译器,TestNode 类是 Object 类的子类,
// 并且需要在编译时进行一些特殊的处理。
TVM_DECLARE_BASE_OBJECT_INFO(TestNode, Object);
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
}
};
TVM_REGISTER_NODE_TYPE(TestNode); // 注册节点类型
}
}
在 Python 端调用:
import tvm
from tvm.runtime import Object
from tvm._ffi.base import _LIB
import ctypes
# _LIB.TVMObjectTypeKey2Index
def load_dll(lib_path="lib/libtvm_ext.so"):
"""加载库,函数将被注册到 TVM"""
# 作为全局加载,这样全局 extern symbol 对其他 dll 是可见的。
# curr_path = f"{ROOT}/"
lib = ctypes.CDLL(lib_path, ctypes.RTLD_GLOBAL)
return lib
load_dll("./libs/libtvm_ext.so")
<CDLL './libs/libtvm_ext.so', handle 3b36bf0 at 0x7f1e3c4f68f0>
node = tvm.ir.make_node("app.TestNode", name="A")
node
app.TestNode(0x46bcb20)
或者:
@tvm._ffi.register_object("app.TestNode")
class TestNode(Object):
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
super().__init__(handle)
self.handle = handle
如果想要改变 node
实例的显示内容,也可以在 C++ 端写入:
#include <tvm/node/repr_printer.h>
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TestNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* op = static_cast<const TestNode*>(ref.get());
p->stream << "Test(";
p->stream << "name=" << op->name<< ", ";
p->stream << ")";
});
@tvm._ffi.register_object("app.TestNode")
class TestNode(Object):
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
super().__init__(handle)
self.handle = handle
node = tvm.ir.make_node("app.TestNode", name="A")
node
Test(name=A, )
或者直接在 Python 端改写:
@tvm._ffi.register_object("app.TestNode")
class TestNode(Object):
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
super().__init__(handle)
self.handle = handle
def __repr__(self):
return f"{self.__class__.__name__}_{self.name}"
node = tvm.ir.make_node("app.TestNode", name="A")
node
TestNode_A