向 Relay 添加编译器 Pass#
编译器 pass 是扩展 Relay 功能集以及对 Relay 程序执行优化的主要接口。通过编写编译器 pass,您可以根据目标修改抽象语法树(AST)或收集有关 AST 的信息。事实上,Relay 中一些最重要的内置功能(例如自动微分和类型推断)也不过是“标准”的编译器 pass。
从高层次来看,编写 pass 有两个关键组成部分:
创建一个或多个遍历程序的 C++ 类
将遍历实现及其元数据封装在 pass 管理器 API 中,以便与 Pass 基础设施 无缝对接
首先,将概述编写编译器 pass 的关键机制。接着,将通过 Relay 中的常量折叠 pass 这一具体示例进行详细讲解。
AST遍历器#
用于遍历 Relay 程序的基础类是 ExprFunctor
。它提供的公共接口是 VisitExpr
方法,该方法接受表达式以及零个或多个参数,并返回某种类型的实例。当您扩展此类时,可以通过为每种表达式类型重写 VisitExpr_
的实现来定义 AST 遍历模式。
VisitExpr
和 VisitExpr_
之间的关系与调度有关。每个 VisitExpr_
定义都针对特定类型的表达式,但您并不总是知道将要访问哪种节点类型。为了解决这个问题,ExprFunctor
提供了 VisitExpr
函数,它从给定的表达式路由到处理它的 VisitExpr_
情况。尽管 C++ 已经提供了动态调度,ExprFunctor
定义了自己的虚函数表(vtable),VisitExpr
使用它。通过定义我们自己的虚函数表,可以更好地控制调度。例如,如果想定义 PrintVisitor
遍历器,在每次访问之前打印“Here”,可以重写 VisitExpr
:
void PrintVisitor::VisitExpr(const Expr& expr) {
std::cout << "Here" << std::endl;
ExprFunctor::VisitExpr(expr);
}
ExprFunctor
本身是非常通用的类,这就是为什么您通常会扩展 ExprVisitor
或 ExprMutator
。这些类扩展了 ExprFunctor
,并为每种表达式类型提供了 VisitExpr_
的默认实现,这些实现捕获了常见的遍历模式。拥有这些默认实现意味着只需要为希望行为不同的表达式类型提供重写实现。将在接下来的部分中分别描述每个子类。
表达式访问器#
ExprVisitor
用于不修改程序而是执行程序分析并收集信息的传递。使用此类时,VisitExpr
及其私有对应方法不返回任何内容。该类提供的 VisitExpr_
实现仅访问表达式的所有字段。下面展示了 IfNode
的默认实现。
void ExprVisitor::VisitExpr_(const IfNode* op) {
this->VisitExpr(op->cond);
this->VisitExpr(op->true_branch);
this->VisitExpr(op->false_branch);
}
请注意,在这里调用的是 VisitExpr
而不是 VisitExpr_
,因此可以使用 ExprFunctor
中的虚函数表进行路由。
现在,如果想编写类 CallChecker
来检查程序中是否出现任何函数调用,只需要扩展 ExprVisitor
并定义以下 VisitExpr_
方法:
void VisitExpr_(const CallNode* n) final {
result_ = true;
}
其中 result_
是字段。在这种情况下,不需要进一步递归访问 CallNode
的字段,因为 result_
已经为真,现在知道原始表达式包含调用。为了使这个访问器可用,将提供以下公共方法:
bool Check(const Expr& expr) final {
result_ = false;
VisitExpr(expr);
return result_;
}
这就是所需要的。通常,会定义公共接口,在调用顶层递归之前执行一些簿记工作。当然,还可以通过创建独立的函数来进一步封装API,该函数创建 CallChecker
实例并调用其上的 Check
方法,但关键是已经以极少的努力实现了目标。
表达式变换器#
ExprMutator
用于以某种方式变换程序的传递。使用此类时,VisitExpr
及其私有对应方法返回 Expr
。该类提供的默认 VisitExpr_
实现访问表达式的所有字段,并将这些字段设置为访问它们的结果。下面展示了 TupleGetItemNode
的默认实现。
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItem(t, g->index);
}
}
这里有几件事需要注意。首先,Mutate
是 ExprMutator
中 VisitExpr
的别名。其次,只有在调用 Mutate
修改了 tuple
字段时,才返回新节点。这种更新方法称为函数式更新,这样做可以避免不必要的分配。
ExprMutator
有 ExprVisitor
没有的特性,那就是内置的 memo_
字段,用于缓存结果。ExprMutator
拥有记忆器(memoizer)是合理的,因为知道缓存的是哪种类型的结果(即 Expr
),而 ExprVisitor
的访问方法不返回任何内容。通常,当需要在 ExprVisitor
的子类中缓存结果时,需要自己定义缓存。
现在,如果想编写类 IfCollapser
,用其真分支替换每个 if 语句,将为 IfNode
重写 VisitExpr_
:
Expr ExprMutator::VisitExpr_(const IfNode* op) {
return this->Mutate(op->true_branch);
}
请注意,返回的表达式不一定是 IfNode
,这是可以的,因为返回类型是 Expr
。现在,创建公共接口:
Expr CollapseIfs(const Expr& expr) final {
return this->Mutate(expr);
}
使用这个变换器,不需要进行任何簿记工作,但仍然希望遵循将描述性方法作为接口的惯例。
示例:常量折叠#
为了更好地理解编写 pass 的过程,将参考常量折叠 pass(位于 src/relay/transforms/fold_constant.cc),因为它是相对简单的 pass,结合了两种类型的遍历。
常量折叠涉及评估程序中仅涉及常量值的表达式,然后用评估结果替换这些表达式。此 pass 的目标是提前进行所有可能的计算。为了实现这一点,常量折叠 pass 使用了访问者(ConstantChecker
)和变换器(ConstantFolder
)。
ConstantChecker
访问者#
此访问者用于检查表达式是否为常量。在 Relay 中,将表达式定义为常量,如果它是 ConstantNode
,或者它是仅包含常量字段的 TupleNode
。
使用 memo_
字段来从节点映射到它们是否为常量,并缓存这些结果。以下是 ConstantChecker
中的 VisitExpr_
定义。
void VisitExpr_(const ConstantNode* n) final {
memo_[GetRef<Constant>(n)] = true;
}
void VisitExpr_(const TupleNode* n) final {
bool result = true;
for (const auto& field : n->fields) {
if (!Check(field)) {
result = false;
break;
}
}
memo_[GetRef<Tuple>(n)] = result;
}
用于协调这些定义的簿记是 Check
方法,它返回给定的表达式是否被视为常量。
bool Check(const Expr& expr) {
const auto it = memo_.find(expr);
if (it != memo_.end())
return it->second;
VisitExpr(expr);
return memo_[expr];
}
不会为遇到的每个节点修改 memo_
;相反,只在遇到的节点可能为常量时修改 memo_
。然后,当 memo_
中不包含 expr
时,依赖于默认值为 false。
ConstantFolder
变换器#
这个变换器执行了常量折叠传递的大部分工作,并在内部使用了 ConstantChecker
。在 Relay 中,有三种节点类型参与了常量折叠:LetNode
、TupleItemGetNode
和 CallNode
。在接下来的段落中,将解释每种节点在此传递中的作用。
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
memo_[op->var] = value;
return this->Mutate(op->body);
} else {
Var var = Downcast<Var>(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Let(var, value, body);
}
}
}
在 LetNode
的情况下,首先尝试对表达式中绑定的值进行常量折叠。如果能够成功折叠,那么会填充 memo_
并返回访问主体(body)的结果——本质上,是将绑定的值传播到主体中的使用位置。如果我们无法对绑定的值进行常量折叠,则会模拟默认的实现方式。
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as<TupleGetItemNode>();
if (const auto* tuple = op->tuple.as<TupleNode>()) {
return tuple->fields[op->index];
} else {
return res;
}
}
在 TupleItemGetNode
的情况下,会检查 op->tuple
字段是否为 TupleNode
。如果是,会将元组获取算子替换为由 op->index
指向的元组字段。需要进行检查的原因是,op->tuple
可能本身不是元组,但其求值结果可能是元组。
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
// We don't constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
bool all_const_args = true;
for (Expr arg : call->args) {
if (!checker_.Check(arg)) {
all_const_args = false;
}
}
if (all_const_args) {
return ConstEvaluate(res);
} else {
return res;
}
}
在 CallNode
的情况下,首先使用 ExprMutator
的 VisitExpr_
方法来访问调用节点,这会对调用的所有字段进行常量折叠。使用 ExprMutator::VisitExpr_
而不是 VisitExpr
,是因为希望绕过虚函数表(vtable,以避免无限循环)并使用 ExprMutator
提供的默认实现。然后,仅在所有参数都是常量时(使用 ConstantChecker
)才对调用进行求值。对调用求值会生成 值,因此使用辅助方法 ValueToExpr
,将求值后的表达式重新放回抽象语法树(AST)中。
现在,为常量折叠器构建了更便捷的接口 FoldConstant
。FoldConstant
是独立于 ConstantFolder
类之外的函数,它接收表达式并在内部创建并使用 ConstantFolder
实例(完整定义可以在 src/relay/transforms/fold_constant.cc 中找到)。
向 Pass 管理器注册 Pass#
注意:有关此主题的更多具体细节,请参阅 :ref:`pass-infra` 的文档。
编写完 AST 遍历器后,可以通过以下代码将 Pass 注册为 TVM API 的端点:
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
} // namespace transform
如果将由上述代码生成的 Pass
对象传递给 Pass 基础设施,它将确保 AST 遍历应用于给定 Relay 模块中的每个函数,这是常量折叠 Pass 所期望的行为(它应尽可能折叠所有常量)。
函数 CreateFunctionPass
允许注册 Pass 的优化级别(在本例中为2),该级别可用于根据 Pass 的通用功能、Pass 的名称以及 Pass 的任何依赖项将其分组。Pass 的依赖项以列表形式给出,这些依赖项是运行当前 Pass 所必需的其他 Pass 的结果。FoldConstant
没有任何依赖项,但许多 Relay Pass 确实依赖于类型信息,因此 InferType
是常见的依赖项;其他 Pass 可能依赖于通过 ToANormalForm
Pass将程序转换为 A 范式(A-normal form)。
请注意,PassContext
对象包含了 Pass 用于错误报告和配置选项的信息;FoldConstant
不需要这些信息,但其他 Pass 可能会引用它们的 PassContext
对象。
现在可以通过 Pass 基础设施调用该 Pass,不过最好也为该 Pass 添加 Python 绑定,如以下代码片段所示:
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
一旦以上述方式定义了 Pass
对象,就可以使用 Pass 基础设施的 Sequential
构造来调用它们。Sequential
接受 Pass 列表,并按顺序将它们应用于 Relay 模块,从而获得变换后的模块。例如,以下代码将 FoldConstant
和 ToANormalForm
Pass依次应用于 mod
中的每个函数,并生成新模块。
seq = transform.Sequential([
relay.transform.FoldConstant(),
relay.transform.ToANormalForm()
])
new_mod = seq(mod)
有关注册的更多详细信息可以在 TVM运行时系统 中找到,有关 Pass 管理器接口的更多信息可以在 Pass 基础设施 中找到。Relay 的标准 Pass 列在 include/tvm/relay/transform.h 中,并在 src/relay/transforms/ 中实现。