TVM 节点反射#
tvm/include/tvm/node/reflection.h
是 TVM(Tensor Virtual Machine)库中的一个头文件,用于实现 TVM 中的反射机制。名为 Reflector
的类,它是整个反射机制的核心。Reflector
类的主要作用是通过序列化和反序列化操作,将计算图中的各种节点、参数和数据结构进行转换,以便在不同的硬件平台上进行部署。
Reflector
类中的方法包括:
Reflector::Init()
:初始化Reflector
对象。在构造函数中调用此方法。Reflector::Run()
:执行反射操作。首先对计算图进行序列化,然后根据目标平台对序列化后的数据进行反序列化,最后执行反序列化后的计算图。Reflector::Export()
:导出指定节点的信息。将指定节点的信息导出到一个字符串中。Reflector::Import()
:导入指定节点的信息。从一个字符串中读取节点信息,并将其反序列化为一个ReflectorNode
对象。Reflector::GetAttrs()
:获取指定节点的属性列表。返回一个包含属性名称和值的映射(std::unordered_map<string, AttrValue>
)。Reflector::SetAttrs()
:设置指定节点的属性列表。使用给定的属性值更新节点的属性。Reflector::ResetGraph()
:重置计算图。清除所有节点、参数和数据结构。Reflector::LoadGraph()
:加载计算图。从磁盘或其他存储介质中读取计算图的数据结构,并反序列化为ReflectorNode
对象。Reflector::FindNode()
:查找指定名称的节点。返回一个指向具有指定名称的节点的指针。Reflector::FindOutput(const std::string& name)
:查找具有指定名称的输出节点。返回一个指向具有指定名称的输出节点的指针。Reflector::FindInput(const std::string& name)
:查找具有指定名称的输入节点。返回一个指向具有指定名称的输入节点的指针。Reflector::FindNextNode(const ReflectorNode* node)
:查找给定节点的下一个节点。返回一个指向下一个节点的指针,如果没有找到,则返回nullptr
。Reflector::FindAllNodes(const std::function<bool(const ReflectorNode*)>& filter)
:查找满足给定过滤条件的所有节点。返回一个包含满足条件的节点指针的列表。Reflector::FindSubgraph(const std::vector<const ReflectorNode*>& nodes)
:查找给定节点集合所在的子图。返回一个表示子图的对象,该对象包含了子图中的所有节点和连接关系。Reflector::DumpGraph()
:将计算图以文本形式输出到标准输出(或指定的文件)。
NodeGetAttr
、NodeListAttrNames
和 MakeNode
#
NodeGetAttr
函数用于获取对象的属性值。它接受两个参数:args
和ret
。args
是包含输入参数的数组,ret
是指向返回值的指针。首先,代码检查args[0]
的类型码是否为kTVMObjectHandle
,然后将args[0]
的值转换为Object*
类型。接下来,它调用ReflectionVTable::Global()->GetAttr
函数来获取对象的属性值,并将结果存储在ret
指向的位置。NodeListAttrNames
函数用于列出对象的所有属性名称。它也接受两个参数:args
和ret
。args
是包含输入参数的数组,ret
是指向返回值的指针。首先,代码检查args[0]
的类型码是否为kTVMObjectHandle
,然后将args[0]
的值转换为Object*
类型。接下来,它调用ReflectionVTable::Global()->ListAttrNames
函数来获取对象的属性名称列表,并将其存储在新的std::vector<std::string>
对象中。最后,它创建包装器函数,该函数接受整数参数i
,并根据i
的值返回相应的属性名称或属性名称列表的大小。MakeNode
函数用于创建新的对象。它接受const TVMArgs&
类型的参数args
和指向返回值的指针rv
。首先,代码从args
中提取对象的类型键(type_key
),并创建新的TVMArgs
对象kwargs
,其中包含剩余的参数。然后,它调用ReflectionVTable::Global()->CreateObject
函数来创建新的对象,并将结果存储在rv
指向的位置。
查看对应的 Python 接口示例:
import tvm
# MakeNode -> tvm.ir.make_node
x = tvm.ir.make_node("IntImm", dtype="int32", value=10, span=None)
assert isinstance(x, tvm.tir.IntImm)
assert x.value == 10
其余两个类被打包到 Object
:
tvm.runtime.Object.__getattr__??
Signature: tvm.runtime.Object.__getattr__(self, name)
Docstring: <no docstring>
Source:
def __getattr__(self, name):
# specially check handle since
# this is required for PackedFunc calls
if name == "handle":
raise AttributeError("handle is not set")
try:
return _ffi_node_api.NodeGetAttr(self, name)
except AttributeError:
raise AttributeError(f"{type(self)} has no attribute {name}") from None
File: /media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/runtime/object.py
Type: function
tvm.runtime.Object.__dir__??
Signature: tvm.runtime.Object.__dir__(self)
Docstring: Default dir() implementation.
Source:
def __dir__(self):
class_names = dir(self.__class__)
fnames = _ffi_node_api.NodeListAttrNames(self)
size = fnames(-1)
return sorted([fnames(i) for i in range(size)] + class_names)
File: /media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/runtime/object.py
Type: function