node functor#
源码:tvm/include/tvm/node/functor.h
NodeFunctor
#
NodeFunctor
的模板类,它用于根据第一个参数的类型动态分派函数。这个类在构造基于 AST/IR 节点类型的多态分派时非常有用。
/*!
* \brief A dynamically dispatched functor on the type of the first argument.
*
* This is a class that is useful to construct polymorphic dispatching
* base on the AST/IR node's type.
*
* \code
* NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;
* tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) {
* return prefix + "Add";
* });
* tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) {
* return prefix + "IntImm"
* });
*
* Expr x = make_const(1);
* Expr y = x + x;
* // dispatch to IntImm, outputs "MyIntImm"
* LOG(INFO) << tostr(x, "My");
* // dispatch to IntImm, outputs "MyAdd"
* LOG(INFO) << tostr(y, "My");
* \endcode
*
* \tparam FType function signiture
* This type if only defined for FType with function signature
*/
template <typename FType>
class NodeFunctor;
template <typename R, typename... Args>
class NodeFunctor<R(const ObjectRef& n, Args...)> {
private:
/*! \brief internal function pointer type */
typedef R (*FPointer)(const ObjectRef& n, Args...);
/*! \brief refer to itself. */
using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
/*! \brief internal function table */
std::vector<FPointer> func_;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*!
* \brief Whether the functor can dispatch the corresponding Node
* \param n The node to be dispatched
* \return Whether dispatching function is registered for n's type.
*/
bool can_dispatch(const ObjectRef& n) const {
uint32_t type_index = n->type_index();
return type_index < func_.size() && func_[type_index] != nullptr;
}
/*!
* \brief invoke the functor, dispatch on type of n
* \param n The Node argument
* \param args The additional arguments
* \return The result.
*/
R operator()(const ObjectRef& n, Args... args) const {
ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type "
<< n->GetTypeKey();
return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
}
/*!
* \brief set the dispatcher for type TNode
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template <typename TNode>
TSelf& set_dispatch(FPointer f) { // NOLINT(*)
uint32_t tindex = TNode::RuntimeTypeIndex();
if (func_.size() <= tindex) {
func_.resize(tindex + 1, nullptr);
}
ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set";
func_[tindex] = f;
return *this;
}
/*!
* \brief unset the dispatcher for type TNode
*
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template <typename TNode>
TSelf& clear_dispatch() { // NOLINT(*)
uint32_t tindex = TNode::RuntimeTypeIndex();
ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
func_[tindex] = nullptr;
return *this;
}
};
NodeFunctor
类有两个模板参数:R
表示返回值类型,Args...
表示函数的其他参数类型。它的内部包含函数指针类型 FPointer
,用于存储指向特定类型的函数的指针。此外,还有名为 func_
的内部向量,用于存储这些函数指针。
NodeFunctor
类提供了以下成员函数:
can_dispatch(const ObjectRef& n) const
:检查是否可以对给定的节点进行分派。如果节点的类型索引小于func_
的大小且对应的函数指针不为空,则返回true
。operator()(const ObjectRef& n, Args... args) const
:调用分派函数。首先检查是否可以对给定的节点进行分派,然后使用节点的类型索引从func_
中获取相应的函数指针,并调用该函数。set_dispatch(FPointer f)
:为特定类型的节点设置分派函数。首先计算节点类型的运行时类型索引,然后调整func_
的大小以容纳新的函数指针(如果需要),并将新函数指针设置为给定的函数。clear_dispatch()
:清除特定类型的节点的分派函数。首先计算节点类型的运行时类型索引,然后将对应的函数指针设置为nullptr
。
NodeFunctor<std::string (const ObjectRef& n, std::string prefix)> tostr;
tostr.set_dispatch<Add>([](const ObjectRef& op, std::string prefix) {
return prefix + "Add";
});
tostr.set_dispatch<IntImm>([](const ObjectRef& op, std::string prefix) {
return prefix + "IntImm"
});
Expr x = make_const(1);
Expr y = x + x;
// dispatch to IntImm, outputs "MyIntImm"
LOG(INFO) << tostr(x, "My");
// dispatch to IntImm, outputs "MyAdd"
LOG(INFO) << tostr(y, "My");
这段代码定义了名为 tostr
的 NodeFunctor
对象,该对象用于将不同类型的节点转换为字符串。NodeFunctor
是模板类,接受函数签名作为参数,该函数签名表示如何将节点转换为字符串。
在这段代码中,NodeFunctor
的函数签名为std::string (const ObjectRef& n, std::string prefix)
,表示它接受 ObjectRef
类型的节点和字符串前缀,并返回字符串。
接下来,使用 set_dispatch
方法为 NodeFunctor
设置两个分派函数。第一个分派函数处理 Add
类型的节点,它将节点转换为字符串并将前缀添加到字符串末尾。第二个分派函数处理 IntImm
类型的节点,它也将节点转换为字符串并将前缀添加到字符串末尾。
然后,创建两个表达式 x
和 y
,其中 x
是常量节点,值为 1
。通过将这两个表达式传递给 tostr
对象,可以将其转换为字符串。由于 x
是 Add
类型的节点,因此调用 tostr(x, "My")
时,将调用第一个分派函数,输出结果为 "MyAdd"
。同样,由于 y
是 IntImm
类型的节点,因此调用 tostr(y, "My")
时,将调用第二个分派函数,输出结果为 "MyIntImm"
。
TVM_STATIC_IR_FUNCTOR
#
#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName
这段代码是宏定义,用于生成一个名为 __make_functor##_##ClsName
的函数对象。这个函数对象是静态的(static)并且具有 TVM_ATTRIBUTE_UNUSED
属性,表示它不会被使用。
解析如下:
#define
是 C/C++ 预处理器指令,用于定义宏。TVM_REG_FUNC_VAR_DEF(ClsName)
是宏的名称,其中ClsName
是参数,表示类名。static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName
是宏的定义部分。static
表示这是静态成员函数。TVM_ATTRIBUTE_UNUSED
是属性,表示该变量或函数未被使用,编译器不会发出警告。auto&
表示返回值类型为引用到自动类型的变量。__make_functor##_##ClsName
是生成的函数对象的名称,其中##
是连接符,用于将两个字符串连接在一起。
综上所述,这段代码的作用是定义名为 __make_functor##_##ClsName
的静态函数对象,该函数对象具有 TVM_ATTRIBUTE_UNUSED
属性,表示它不会被使用。
/*!
* \brief Useful macro to set NodeFunctor dispatch in a global static field.
*
* \code
* // Use NodeFunctor to implement ReprPrinter similar to Visitor Pattern.
* // vtable allows easy patch of new Node types, without changing
* // interface of ReprPrinter.
*
* class ReprPrinter {
* public:
* std::ostream& stream;
* // the dispatch function.
* void print(Expr e) {
* const static FType& f = *vtable();
* f(e, this);
* }
*
* using FType = NodeFunctor<void (const ObjectRef&, ReprPrinter* )>;
* // function to return global function table
* static FType& vtable();
* };
*
* // in cpp/cc file
* ReprPrinter::FType& ReprPrinter::vtable() { // NOLINT(*)
* static FType inst; return inst;
* }
*
* TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
* .set_dispatch<Add>([](const ObjectRef& ref, ReprPrinter* p) {
* auto* n = static_cast<const Add*>(ref.get());
* p->print(n->a);
* p->stream << '+'
* p->print(n->b);
* });
*
*
* \endcode
*
* \param ClsName The name of the class
* \param FField The static function that returns a singleton of NodeFunctor.
*/
#define TVM_STATIC_IR_FUNCTOR(ClsName, FField) \
TVM_STR_CONCAT(TVM_REG_FUNC_VAR_DEF(ClsName), __COUNTER__) = ClsName::FField()
这段代码定义了名为 TVM_STATIC_IR_FUNCTOR
的宏,用于设置 NodeFunctor
的调度。NodeFunctor
是一种用于实现类似于访问者模式的函数对象。
在这段代码中,ReprPrinter
类使用了 NodeFunctor
来实现打印功能。通过使用 vtable
,可以轻松地为新的节点类型添加新的调度函数,而无需更改 ReprPrinter
接口。
ReprPrinter::FType& ReprPrinter::vtable()
函数返回全局函数表。这个函数表是一个静态成员变量,它存储了 NodeFunctor
的实例。
TVM_STATIC_IR_FUNCTOR(ClsName, FField)
宏的作用是将 ClsName
类的 FField
函数作为 NodeFunctor
的调度函数添加到全局函数表中。这样,当调用 print
方法时,会根据节点的类型选择相应的调度函数进行处理。