env_func
#
源码:tvm/src/ir/env_func.cc
EnvFuncNode
和 EnvFunc
#
/*!
* \brief A serializable function backed by TVM's global environment.
*
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
* \sa EnvFunc
*/
class EnvFuncNode : public Object {
public:
/*! \brief Unique name of the global function */
String name;
/*! \brief The internal packed function */
runtime::PackedFunc func;
/*! \brief constructor */
EnvFuncNode() {}
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
// name uniquely identifies the env function.
return name == other->name;
}
void SHashReduce(SHashReducer hash_reduce) const {
// Name uniquely identifies the env function.
hash_reduce(name);
}
static constexpr const char* _type_key = "EnvFunc";
static constexpr bool _type_has_method_sequal_reduce = true;
static constexpr bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};
EnvFuncNode
是继承自 Object
的类,它包含字符串类型的成员变量 name
和 runtime::PackedFunc
类型的成员变量 func
。这个类的主要目的是作为 TVM 全局环境的包装器,使得函数可以被序列化。在加载时,通过名称在全局注册表中查找相同的函数。此外,它还提供了一些方法来访问和操作这些成员变量。例如,VisitAttrs
方法允许访问 name
属性,而 SEqualReduce
和 SHashReduce
方法则用于比较两个 EnvFuncNode
对象是否相等以及计算它们的哈希值。在类的声明中,还使用了一些宏来定义一些常量和类型信息。例如,_type_key
常量被定义为 "EnvFunc"
,表示这个类的类型键; _type_has_method_sequal_reduce
和 _type_has_method_shash_reduce
常量被定义为 true
,表示这个类支持相等性和哈希性的计算方法。最后,TVM_DECLARE_FINAL_OBJECT_INFO
宏用于声明这个类的最终对象信息。
/*!
* \brief Managed reference to EnvFuncNode.
* \sa EnvFuncNode
*/
class EnvFunc : public ObjectRef {
public:
EnvFunc() {}
explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }
/*!
* \brief Invoke the function.
* \param args The arguments
* \returns The return value.
*/
template <typename... Args>
runtime::TVMRetValue operator()(Args&&... args) const {
const EnvFuncNode* n = operator->();
ICHECK(n != nullptr);
return n->func(std::forward<Args>(args)...);
}
/*!
* \brief Get a global function based on the name.
* \param name The name of the global function.
* \return The created global function.
* \note The function can be unique
*/
TVM_DLL static EnvFunc Get(const String& name);
/*! \brief specify container node */
using ContainerType = EnvFuncNode;
};
EnvFunc
是 EnvFuncNode
的引用类型,它提供了一种方法来调用内部存储的函数。这个类继承自 ObjectRef
类,并提供了以下功能:
构造函数:
EnvFunc()
和explicit EnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
。这两个构造函数分别用于创建空的EnvFunc
对象和用给定的ObjectPtr<Object>
初始化的EnvFunc
对象。获取内部全局函数指针:
const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }
。这个方法返回指向内部全局函数指针的常量指针。调用函数:
template <typename... Args> runtime::TVMRetValue operator()(Args&&... args) const
。这个方法接受一系列参数,并使用这些参数调用内部全局函数。它返回内部全局函数的返回值。根据名称获取全局函数:
TVM_DLL static EnvFunc Get(const String& name);
。这个方法根据给定的名称在全局环境中查找并返回对应的全局函数。指定容器节点:
using ContainerType = EnvFuncNode;
。这行代码声明了EnvFunc
类可以作为EnvFuncNode
类型的容器节点。
总的来说,EnvFuncNode
和 EnvFunc
提供了一种机制,可以将 TVM 中的函数封装为可序列化的全局环境,并提供了方便的方法来调用这些函数。
TypedEnvFunc
#
/*!
* \brief Please refer to \ref TypedEnvFuncAnchor "TypedEnvFunc<R(Args..)>"
*/
template <typename FType>
class TypedEnvFunc;
/*!
* \anchor TypedEnvFuncAnchor
* \brief A typed version of EnvFunc.
* It is backed by a GlobalFuncNode internally.
*
* \tparam R The return value of the function.
* \tparam Args The argument signature of the function.
* \sa EnvFunc
*/
template <typename R, typename... Args>
class TypedEnvFunc<R(Args...)> : public ObjectRef {
public:
/*! \brief short hand for this function type */
using TSelf = TypedEnvFunc<R(Args...)>;
TypedEnvFunc() {}
explicit TypedEnvFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Assign global function to a TypedEnvFunc
* \param other Another global function.
* \return reference to self.
*/
TSelf& operator=(const EnvFunc& other) {
ObjectRef::operator=(other);
return *this;
}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const { return static_cast<const EnvFuncNode*>(get()); }
/*!
* \brief Invoke the function.
* \param args The arguments
* \returns The return value.
*/
R operator()(Args... args) const {
const EnvFuncNode* n = operator->();
ICHECK(n != nullptr);
return runtime::detail::typed_packed_call_dispatcher<R>::run(n->func,
std::forward<Args>(args)...);
}
/*! \brief specify container node */
using ContainerType = EnvFuncNode;
};
这段代码定义了名为 TypedEnvFunc
的模板类,它是对 EnvFunc
类的泛型版本。TypedEnvFunc
类的主要目的是将全局函数封装为类型安全的函数对象。
TypedEnvFunc
类有两个模板参数:R
和 Args
。其中,R
表示函数的返回值类型,Args
表示函数的参数类型。TypedEnvFunc<R(Args...)>
表示接受 Args...
类型参数并返回 R
类型的函数对象。
TypedEnvFunc
类继承自 ObjectRef
类,因此它具有引用计数功能。它提供了一些成员函数,如 operator=
、operator->
和 operator()
,分别用于赋值、获取内部全局函数指针和调用函数。
在 operator()
函数中,首先通过 operator-()
获取内部全局函数指针,然后使用 runtime::detail::typed_packed_call_dispatcher<R>::run()
函数调用全局函数,并将结果返回。
此外,TypedEnvFunc
类还定义了名为 ContainerType
的类型别名,用于指定容器节点类型为 EnvFuncNode
。
EnvFunc
的实现#
/*!
* \file env_func.cc
*/
#include <tvm/ir/env_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
namespace tvm {
using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EnvFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EnvFuncNode*>(node.get());
p->stream << "EnvFunc(" << op->name << ")";
});
ObjectPtr<Object> CreateEnvNode(const std::string& name) {
auto* f = runtime::Registry::Get(name);
ICHECK(f != nullptr) << "Cannot find global function \'" << name << '\'';
ObjectPtr<EnvFuncNode> n = make_object<EnvFuncNode>();
n->func = *f;
n->name = name;
return n;
}
EnvFunc EnvFunc::Get(const String& name) { return EnvFunc(CreateEnvNode(name)); }
TVM_REGISTER_GLOBAL("ir.EnvFuncGet").set_body_typed(EnvFunc::Get);
TVM_REGISTER_GLOBAL("ir.EnvFuncCall").set_body([](TVMArgs args, TVMRetValue* rv) {
EnvFunc env = args[0];
ICHECK_GE(args.size(), 1);
env->func.CallPacked(TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), rv);
});
TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc").set_body_typed([](const EnvFunc& n) {
return n->func;
});
TVM_REGISTER_NODE_TYPE(EnvFuncNode)
.set_creator(CreateEnvNode)
.set_repr_bytes([](const Object* n) -> std::string {
return static_cast<const EnvFuncNode*>(n)->name;
});
} // namespace tvm
这段代码用于处理环境函数。环境函数是一种特殊类型的函数,它们在运行时被调用,而不是在编译时。
代码中定义了两个主要的函数:CreateEnvNode
和 EnvFunc::Get
。
CreateEnvNode
函数接受字符串参数 name
,这个字符串应该是全局函数的名称。然后,它从注册表中获取这个函数,并创建新的 EnvFuncNode
对象,将这个函数和它的名称存储在这个对象中。最后,它返回这个新创建的对象。
EnvFunc::Get
函数接受字符串参数 name
,并使用 CreateEnvNode
函数来创建对应的 EnvFuncNode
对象。然后,它返回这个新创建的对象。
此外,代码还注册了几个全局函数,包括 ir.EnvFuncGet
、ir.EnvFuncCall
和 ir.EnvFuncGetPackedFunc
。这些函数分别用于获取环境函数、调用环境函数和获取环境函数的打包函数。
最后,代码还注册了节点类型 EnvFuncNode
,并设置了它的创建函数和表示函数。创建函数是 CreateEnvNode
,表示函数是字符串,表示环境函数的名称。