解读 ExprFunctor
#
tvm/include/tvm/relay/expr_functor.h
是名为 expr_functor
的函数访问者(visitor),它具有更强大的动态分派功能,可以定义具有基于第一个参数的类型分派的任意函数签名。
在计算机编程中,访问者模式是一种设计模式,用于处理不同类型的对象结构。通过使用访问者模式,可以将对不同对象的操作集中在一个或多个访问者类中,从而实现统一的接口和逻辑。
ExprFunctor
#
template <typename FType>
class ExprFunctor;
// functions to be overriden.
#define EXPR_FUNCTOR_DEFAULT \
{ return VisitExprDefault_(op, std::forward<Args>(args)...); }
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
});
这段代码是 C++ 模板类的定义,用于实现动态函数对象(functional object),该函数对象可以根据第一个表达式参数的类型进行分派。具体来说,这个类名为 ExprFunctor
,它是一个模板类,使用了类型参数 FType
来表示函数签名。根据注释中的描述,FType
应该具有函数签名 R(const Expr&, Args...)
,其中 R
是返回类型,Expr
是第一个参数的类型,Args
是其他参数的类型。
在代码中,看到两个宏定义:
EXPR_FUNCTOR_DEFAULT
:这是默认的函数体,用于处理没有特定重载版本的函数调用。它使用VisitExprDefault_
函数来处理传入的表达式,并将结果返回。RELAY_EXPR_FUNCTOR_DISPATCH(OP)
:这是用于分发函数调用的宏定义。它使用了虚函数表(vtable)和set_dispatch
方法来实现基于算子(OP)的类型分派。当调用该函数对象时,会根据传入的算子类型选择相应的重载版本进行处理。
ExprFunctor
类的主要目的是在表达式树中进行访问操作。它通过重载 operator()
函数来实现对表达式节点的调用,并通过 VisitExpr
函数来处理不同类型的节点。EXPR_FUNCTOR_DEFAULT
宏用于生成默认的可调用对象,而 RELAY_EXPR_FUNCTOR_DISPATCH
宏用于设置节点分派的函数对象。
在
ExprFunctor
类的实现中,首先定义了私有成员变量vtable
,它是类型为FType
的函数对象。然后,通过调用InitVTable
函数来初始化vtable
。InitVTable
函数使用RELAY_EXPR_FUNCTOR_DISPATCH
宏来设置不同类型节点的分派函数对象。每个分派函数对象都接受ConstantNode
指针作为参数,并返回结果。最后,
ExprFunctor
类的构造函数是虚析构函数,确保当删除ExprFunctor
对象时,能够正确地调用其析构函数。
总的来说,这段代码实现了灵活强大的函数对象,可以在表达式树中进行访问操作,并根据节点的类型选择相应的处理方式。
ExprVisitor
#
ExprVisitor
是 tvm::relay::ExprFunctor
的子类。ExprVisitor
将 Expr
视为数据流图,并且每个 Expr
节点只访问一次。
ExprVisitor
类中包含了多个重载的 VisitExpr
函数,每个函数都接受 const Expr&
类型的参数,用于处理不同类型的 Expr
节点。这些重载函数根据节点的类型调用相应的 VisitExpr_
函数进行处理。除了处理 Expr
节点外,ExprVisitor
还定义了一些其他的虚函数,如 VisitType
、VisitClause
、VisitPattern
和 VisitSpan
,用于处理其他类型的节点。在 ExprVisitor
类中还定义了受保护的成员变量 visit_counter_
,它是无序的哈希表,用于记录每个节点被访问的次数。
MixedModeVisitor
#
MixedModeVisitor
是 tvm::relay::ExprVisitor
的子类。MixedModeVisitor
将 Expr
视为数据流图,并按照后序深度优先搜索(DFS)的顺序进行访问。MixedModeVisitor
提供了与 ExprVisitor
相同的递归 API,并使用递归来遍历 IR 的大多数形式,但在底层,它会展开图中嵌套的数据流区域,并以迭代的方式处理它们,以防止堆栈溢出。
在 MixedModeVisitor
类中还定义了一些受保护的成员变量和函数。其中,visit_limit_
表示允许访问节点的最大次数,通常为 1,有时为 2(例如用于消除死代码),但限制为 10 作为合理性检查。
VisitLeaf
是虚函数,当到达图的叶子节点时调用,以非递归方式应用。CheckVisited
是虚函数,用于确定表达式是否已经被访问过或者需要重新访问。
VisitExpr
函数被声明为 final,以保留数据流区域的调用扩展。它还重载了多个版本的 VisitExpr_
函数,用于处理不同类型的节点。
ExprMutator
#
ExprMutator
类是 tvm::relay::ExprFunctor
的子类。ExprMutator
将 Expr
视为数据流图,并且每个 Expr
只进行一次变更。ExprMutator
类中包含了多个重载的 VisitExpr
函数,每个函数都接受 const Expr&
类型的参数,用于处理不同类型的 Expr
节点。这些重载函数根据节点的类型调用相应的 VisitExpr_
函数进行处理。除了处理 Expr
节点外,ExprMutator
还定义了一些其他的虚函数,如 VisitType
、VisitClause
和 VisitPattern
,用于处理其他类型的节点。
在 ExprMutator
类中还定义了受保护的成员变量 memo_
,它是无序的哈希表,用于记录每个节点被访问的次数。这个哈希表用于实现结果的缓存,以提高后续相同表达式的访问效率。
MixedModeMutator
#
MixedModeMutator
是tvm::relay::ExprMutator
的子类。MixedModeMutator
将 Expr
视为数据流图,并只重写每个 Expr
一次。重写后的结果被缓存在映射中并重复使用,以便数据流上的局部转换保持图结构。
MixedModeMutator
提供了与 ExprMutator
相同的递归 API,并使用递归来遍历IR的大多数形式,但在实际实现中,它会展开图中嵌套的数据流区域,并以迭代的方式处理它们,以防止堆栈溢出。
该类使用了 ExprRewriter
的 Rewrite_
API,以实现递归和非递归行为之间的更清晰的分离。
在 MixedModeMutator
类中还定义了一些受保护的成员变量和函数。其中,pre_
表示是否为预处理模式。
VisitExpr
函数被声明为final,以保留数据流区域重写的调用扩展。它还重载了多个版本的 VisitExpr_
函数,用于处理不同类型的节点。
DispatchVisitExpr
函数是一个虚拟函数,用于分发访问表达式节点的操作。
Rewrite_
函数是用户应该重写的虚函数,用于实现他们的传递。这些重写函数应该能够仅使用原始节点 pre
的数据以及具有修改输入的相同节点 post
进行重写,并且不应递归。
VisitLeaf
和 CheckVisited
是受保护的虚函数,用于在叶子节点上进行处理和检查是否已访问。
ExprRewriter
#
ExprRewriter
类是非迭代式的表达式重写器。
ExprRewriter
提供了重写接口,用于以后序 DFS 顺序修改图。预期是,ExprRewriter
对象将被传递给 PostOrderRewrite
,它将非递归地展开图并调用重写输入。然后,它将传递原始节点(称为 pre
)和使用任何更改的输入重新创建的节点(称为 post
)给 ExprRewriter
。然后,ExprRewriter
可以使用这两个节点中的信息执行更复杂的图重写。
在私有成员中,它定义了类型为 FType
的静态成员变量 vtable
,并通过调用 InitVTable
函数进行初始化。InitVTable
函数返回 lambda 表达式,该表达式调用了 Relay_Expr_Rewriter_Dispatch
宏来设置分派。
在公共成员中,它定义了一个虚析构函数,以及重载的括号运算符 operator()
,该运算符调用了 Rewrite
函数。它还定义了一些可以被子类覆盖的虚函数,这些函数不应递归。
最后,它还定义了一些重写的虚函数,这些函数默认不执行任何操作,但可以在子类中被覆盖以执行更复杂的重写逻辑。
PostOrderRewrite
#
PostOrderRewrite
函数,它执行对图的非递归后序 DFS 遍历,并在输入被重写后调用 ExprRewriter
的 Rewrite
函数。在每次重写调用时,PostOrderRewrite
提供原始节点和具有更改的输入的节点,供 ExprRewriter
使用。
该函数接受两个参数:Expr
类型的 expr
,表示要遍历的表达式;ExprRewriter*
类型的 rewriter
,表示用于重写的表达式重写器。
函数的返回类型是 Expr
,表示经过重写后的表达式。
PostOrderVisit
#
PostOrderVisit
函数,它以后序 DFS 顺序递归地访问 IR(中间表示),并对每个节点应用 fvisit
访问者函数。
该函数接受两个参数:Expr
类型的 node
,表示要访问的 IR 节点;std::function<void(const Expr&)>
类型的 fvisit
,表示要应用的访问者函数。
函数没有返回值。
该函数的具体实现并未给出,但从函数注释中可以了解到它的大致功能和用途。它的作用是按照后序 DFS 顺序递归地访问IR中的每个节点,并对每个节点应用给定的访问者函数 fvisit
。由于每个节点只被访问一次,因此可以确保节点的访问是正确且高效的。
ExpandDataflow
和 ExpandANormalForm
#
ExpandDataflow
是一个模板函数,用于以深度优先顺序遍历一个表达式的 IR(中间表示)数据流区域。它接受四个参数:要遍历的表达式 expr
、一个检查节点是否被访问过的函数 fcheck_visited
、一个访问叶子节点的函数 fvisit_leaf
以及一个扩展表达式的函数 fexpand_expr
。
该函数使用一个栈来管理遍历过程中的数据流节点。它首先将输入表达式压入栈中,然后进入循环,直到栈为空为止。在每次迭代中,它从栈顶取出一个节点,并检查该节点是否满足数据流类型。如果满足,则将该节点的子节点压入栈中;如果不满足或者该节点的所有输入都已经被处理过,则调用 fvisit_leaf
函数访问当前叶子节点。
ExpandDataflow
函数通过模板参数 FCheckVisited
、FVisitLeaf
和 FExpandExpr
来实现重用。这些参数是类型别名,分别对应于检查节点是否被访问过的函数、访问叶子节点的函数和扩展表达式的函数的类型。这样,用户可以根据需要提供不同的实现,以便在不同的场景下进行遍历分析。
ExpandANormalForm
函数是一个辅助函数,用于展开一个正常的 LetNode 表达式。它接受三个参数:要展开的表达式 op
、一个在访问 LetNode 之前执行的函数 pre_visit
和一个在访问LetNode之后执行的函数 post_visit
。
ExpandANormalForm
的作用是在展开表达式之前和之后执行一些额外的操作,例如预处理或后处理。