Pass 基础设施#
Relay 和 TVM IR 都包含一系列优化 Pass,这些 Pass 可以提高模型的性能指标,例如特定设备的平均推理时间、内存占用或功耗。这些优化包括标准优化以及机器学习特定的优化,例如常量折叠、死代码消除、算子布局更改、算子融合、缓冲区处理和循环转换等。每个 Pass 利用在遍历期间和/或之前收集的分析结果,构建为一种 IR 到 IR 的变换。
然而,随着 TVM 的快速发展,对一种更系统、更高效的方式来管理这些 Pass 的需求变得日益明显。此外,通用的框架来管理 TVM 堆栈不同层(例如 Relay 和 TIR)中的 Pass,为开发者快速原型化并将其实现的 Pass 集成到系统中铺平了道路。
本文档描述了这样一种基础设施的设计,它利用了生产编译器管理优化 Pass 的方式以及现代深度学习框架构建层次结构的风格。
例如,许多现有的生产编译器(如 GCC 和 LLVM)使用 Pass 管理器来有效管理 Pass 的执行。最初,由于 Pass 数量较少,管理 Pass 相对简单,但成熟的编译器将包含数百个独立的 Pass。通常,外部用户希望能够在不修改手动编写的 Pass 顺序的情况下,正确调度自定义 Pass。
同样,现代深度学习框架(如 Pytorch 和 MXNet Gluon)也倾向于通过 Sequential 和 Block 分别启用 Pass 风格的层次构建方案。通过这些构造,这些现代框架能够方便地将模块/层添加到它们的容器中,并轻松构建神经网络。
Relay Pass 基础设施的设计很大程度上受到了 LLVM 中使用的分层 Pass 管理器以及流行深度学习框架中使用的块式容器的启发。Pass 基础设施的主要目标包括:
实现更好的优化程序化编排。这使得用户能够灵活地定制和构建自己的优化管道。
提供一种用户友好的方式来调试优化 Pass。
减轻开发者手动和分别解决 Pass 之间依赖关系的负担。
简化开发者实现新 Pass 的过程。例如,允许用户在 Python 中实现 Pass,并让 Pass 基础设施管理其执行。
设计#
专注于用户的扩展便利性,使用户能够快速添加新 Pass 而不损失向后兼容性。该设计包含后端和前端。后端实现了 Pass 基础设施的核心逻辑。前端为用户提供了简单的 API 进行交互,即允许用户快速创建自己的优化管道。
C++ 后端#
提供了 PassInfo
对象来包含 Pass 所需的基本信息。name
是 Pass 的名称,opt_level
表示 Pass 将在哪个优化级别启用,required
表示执行某个 Pass 所需的其他 Pass(更多详细信息请参阅 include/tvm/ir/transform.h)。例如,在 Pass 注册期间(将在后面介绍),Pass 开发者可以指定 Pass 的名称、将在哪个优化级别执行以及/或所需的 Pass。opt_level
可用于帮助 Pass 基础设施识别在用户提供的优化级别下是否需要执行某个 Pass。Pass 基础设施可以使用 required
字段来解决 Pass 之间的依赖关系。
class PassInfoNode : public Object {
String name;
int opt_level;
Array<String> required;
};
PassContext#
PassContext
携带了优化 Pass 所需的有用信息。例如,它包含了错误报告系统,因此优化作者可以提供有关优化失败原因的诊断信息。PassContext
还被设计用来取代旧的 BuildConfig
,后者用于帮助用户配置编译选项,包括优化级别和所需/禁用的Pass等。例如,可能有配置,它在 opt_level=3
下执行所有 Pass,同时使用 PassContext
提供的 disabled_pass=xx
禁用某些 Pass。现在,可以全局获取 opt_level=3
下的所有 Pass,并排除禁用 Pass 列表中的那些 Pass。PassContext
还提供了一种方法来检测所有 Pass。请参阅 Pass 检测 部分。
这个类旨在让用户方便地编写 Python 的 with
语法,以在特定配置下执行优化。此外,用户可以通过 PassContext::Current()
以线程安全的方式获取在某个程序范围内可用的上下文,因为使用了线程本地存储 PassContextThreadLocalStore
来保存创建的 Pass 上下文对象。稍后将提供示例,展示如何使用 C++ 和 Python API 来创建使用 Pass 上下文的编译管道。
class PassContextNode : public Object {
public:
int opt_level{2};
tvm::Array<tvm::Expr> required_pass;
tvm::Array<tvm::Expr> disabled_pass;
mutable Optional<DiagnosticContext> diag_ctx;
Map<String, ObjectRef> config;
Array<instrument::PassInstrument> instruments;
};
class PassContext : public NodeRef {
public:
TVM_DLL static PassContext Create();
TVM_DLL static PassContext Current();
TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
/* Other fields are omitted. */
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Classes to get the Python `with` like syntax.
friend class tvm::With<PassContext>;
};
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
/*! \brief The current pass context. */
std::stack<PassContext> context_stack;
PassContextThreadLocalEntry() {
default_context = PassContext(make_node<PassContextNode>());
}
};
/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
PassContextThreadLocalStore;
Pass 构造函数#
Pass 基础设施以分层方式设计,可以在不同粒度的 Relay/TIR 程序中工作。引入了纯虚类 PassNode
作为不同优化 Pass 的基类。该类包含几个虚方法,必须由模块、函数或 Pass 序列级别的子类实现。
class PassNode : Object {
virtual PassInfo Info() const = 0;
virtual Module operator()(const IRModule& mod
const PassContext& pass_ctx) const = 0;
};
该函子展示了如何实现 Pass,即它总是在某个上下文下对 IRModule
进行操作。所有 Pass 都以 Module
到 Module
的方式设计。因此,由 Pass 基础设施管理的优化将始终更新整个模块。
已经创建了几个子类来实现不同类型的优化 Pass,例如函数级 Pass、模块级 Pass 和 Pass 序列。每个子类本身都可以充当 Pass 管理器。例如,它们可以收集所需的 Pass 并执行它们,或者根据给定的元数据构建依赖图。它们的完整定义可以在 src/relay/ir/transform.cc 和 src/ir/transform.cc 中找到。
模块级 Pass#
模块级 Pass 主要用于全局和过程间优化(inter-procedural optimizations,简称 IPO),类似于 LLVM 中使用的模块 Pass。Relay 中一些需要模块全局信息的典型 Pass,例如 A 范式转换和 lambda 提升等,都属于这一类。在这个级别,用户甚至可以在模块中添加和/或删除函数。请注意,所有 Pass
class ModulePassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
// Other members/methods are omitted
};
pass_info
维护了模块级 Pass 所需的信息。pass_func
描述了实际的优化过程。例如,可能需要对模块执行死代码消除。可以在 pass_func
中实现该算法,并让它在模块上运行。然后,它将删除死代码,包括模块中未使用的函数。请注意,该字段被设计为打包函数,这使得优化可以在 C++ 和 Python 中实现。
函数级 Pass#
函数级 Pass 用于为给定的 Relay/TIR 模块实现各种函数内级别的优化。它从模块的函数列表中一次获取函数进行优化,并生成重写后的 Relay Function
或TIR PrimFunc
。大多数 Pass 可以归类到这一类别中,例如 Relay 中的公共子表达式消除和推理简化,以及 TIR 中的向量化和存储扁平化等。
请注意,此级别 Pass 的范围是 Relay 函数或 TIR 原始函数。因此,不能通过这些 Pass 添加或删除函数,因为它们不了解全局信息。
class FunctionPassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
bool SkipFunction(const Function& func) const;
// Other members/methods are omitted...
};
pass_info
与我刚刚在模块级 Pass 中描述的内容相同。pass_func
接受函数进行优化,它还需要模块,因为可能用它来报告错误。函数可以用“SkipOptimization”进行注释,以便在优化期间忽略它。
Pass 序列#
SequentialPass
类似于 Pytorch 的 nn.Sequential
,它包含一系列要执行的 Pass。
class SequentialPassNode : PassNode {
PassInfo pass_info;
// Passes need to be executed.
Array<Pass> passes;
bool PassEnabled(const PassInfo& info) const;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};
目前 Relay 中只有少数 Pass 被归入这一类别。例如,FoldScaleAxis
需要在内部调度 ForwardFoldScaleAxis
和 BackwardFoldScaleAxis
。此外,建议首先完成 BackwardFoldScaleAxis
。因此,这个 Pass 是 SequentialPass
的理想候选者。
以下代码展示了如何在顺序 Pass 中调用各个 Pass。本质上,按照它们被添加到 Pass 列表中的顺序依次执行顺序 Pass 中的每个 Pass。
Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const {
Module mod = module;
for (const Pass& pass : passes) {
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue;
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImm>();
ICHECK(name);
mod = GetPass(name->value)(mod, pass_ctx);
}
mod = pass(mod, pass_ctx);
}
return mod;
}
在调用 Pass 时,首先检查该 Pass 是否启用。这是通过首先检查用户是否显式禁用了该 Pass,然后检查用户是否将其指定为必需 Pass 来完成的。如果仍然无法确定该 Pass 是否启用,将检查其 opt_level
。只有当其优化级别不低于 Pass 上下文中配置的优化级别时,该 Pass 才会被启用并执行。
要执行 Pass,需要首先使用 Pass 名称从 TVM 打包函数注册表中检索已注册的 Pass。这是可能的,因为每个 Pass 都注册了 API 端点,稍后会展示这一点。
Pass GetPass(const std::string& pass_name) {
using tvm::runtime::Registry;
std::string fpass_name = "relay._transform." + pass_name;
const auto* f = Registry::Get(fpass_name);
ICHECK(f != nullptr) << "Cannot find " << fpass_name
<< "to create the pass " << pass_name;
return (*f)();
}
提供了一些辅助函数来创建上述每种类型的 Pass。这些辅助函数也暴露给 Python 前端,以便用户方便地使用 Python API 创建特定的 Pass 对象。
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass CreateModulePass(
const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
Pass 注册#
已经介绍了不同级别 Pass 的概念以及用于编译的上下文。看看用户可以如何轻松地注册 Pass 会很有趣。以常量折叠为例。这个 Pass 已经实现,用于折叠 Relay 函数中的常量(可以在 src/relay/transforms/fold_constant.cc 中找到)。
提供了 API 来执行 Expr
到 Expr
的变换。
Expr FoldConstant(const Expr& expr);
为了将这个 Pass 注册到 Pass 基础设施中,首先需要确定该 Pass 将在哪个级别执行。由于常量折叠发生在单个函数上,应该直观地通过 CreateFunctionPass
为其创建 FunctionPass
。pass_func
作为打包函数返回,它在 IRModule 中的每个函数上调用 Expr
到 Expr
的API。{}
表示该 Pass 没有先决条件。否则,Pass 开发者必须识别并列出它们。
同时,Pass API 端点以名称 relay._transform.FoldConstant
注册。因此,这个 Pass 成为注册表中的条目,可以在需要时通过 C++(例如上面的 GetPass
)和 Python 访问。
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
} // namespace transform
为了允许其他 C++ 模块应用这个 Pass,在 include/tvm/relay/transform.h 中声明了自由函数,如下所示:
TVM_DLL Pass FoldConstant();
Pass 检测#
Pass 检测(Instrument)是一种分析 Pass 本身的机制。例如,可以使用该基础设施来了解 Pass 需要多少时间和内存,或者 Pass 如何变换 IR 模块。
在 PassContext
的生命周期中引入了四个检测点。
TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
InstrumentEnterPassContext
在进入 PassContext
实例的范围时立即调用。
InstrumentExitPassContext
在离开 PassContext
的范围时调用,或者在 Pass 执行期间发生异常时调用。当检测器被 tvm.transform.PassContext
中的 override_instruments
覆盖时,也会调用此方法。请参阅 在当前 PassContext 中覆盖检测器。
InstrumentBeforePass
在执行前调用。如果 Pass 应该运行,InstrumentAfterPass
在执行后调用。行为如下:
if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) {
new_ir_module = run_pass(ir_module, pass_ctx);
pass_ctx.InstrumentAfterPass(new_ir_module, pass_info);
return new_ir_module;
}
PassInstrument
接口允许你在上述四个方法中运行任意代码。多个 PassInstrument
实例可以注册到 PassContext
中。PassInstrument
实例按照传递给 PassContext
的 instruments
参数的顺序依次调用。
PassInstrument
提供以下接口:
namespace instrument {
class PassInstrumentNode : public Object {
public:
String name;
virtual void EnterPassContext() const = 0;
virtual void ExitPassContext() const = 0;
virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;
virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0;
virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;
/* Other fields are omitted. */
};
class PassInstrument : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode);
};
} // namespace instrument
提供了 Python 前端来快速实现 PassInstrument
。请参阅 Pass 检测。
在 PassContext
中,PassInstrument
实例的调用顺序如下:
with PassContext(instruments=[pi]) # pi = a PassInstrument implementation.
pi.EnterPassContext()
if pi.ShouldRun(Pass1):
pi.RunBeforePass()
Pass1()
pi.RunAfterPass()
if pi.ShouldRun(Pass2):
pi.RunBeforePass()
Pass2()
pi.RunAfterPass()
pi.ExitPassContext()
以下是 PassInstrument
接口与 PassContext
方法之间关系的简要介绍。更多详细信息请参阅(src/ir/transform.cc)。
InstrumentEnterPassContext
EnterPassContext()
按照传递给PassContext
的instruments
的顺序执行。当异常发生时,
PassContext
通过清除所有注册的PassInstrument
实例来禁用 Pass 检测。然后,
PassContext
执行每个成功完成EnterPassContext()
的PassInstrument
实例的ExitPassContext()
方法。例如,如果
PassInstrument
A、B 和C 注册到PassContext
中,并且 A 完成了EnterPassContext()
,而 B 抛出异常,那么 C 永远不会执行;A 的ExitPassContext()
会被执行。
InstrumentExitPassContext
每个
PassInstrument
实例的ExitPassContext()
按照传递给PassContext
的instruments
的顺序执行。当异常发生时,
instruments
会被清除。在抛出异常的实例之后注册的
PassInstrument
实例不会执行ExitPassContext
。
InstrumentBeforePass
如果 Pass 未列为必需 Pass,则执行
ShouldRun
。如果 Pass 未被
ShouldRun
阻止,则按照instruments
的顺序执行RunBeforePass
。请注意,
InstrumentBeforePass
返回布尔值,指示是否应运行该 Pass。当异常发生时,它会立即抛出。依赖 Python 的上下文管理器来安全地退出
PassContext
(意味着每个检测器的ExitPassContext
将被运行。对于 C++,请参阅 include/tvm/support/with.h。)
InstrumentAfterPass
RunAfterPass
按照传递给PassContext
的instruments
的顺序执行。当异常发生时,它会立即抛出。依赖 Python 的上下文管理器或
With
类(include/tvm/support/with.h)来安全地退出PassContext
。
内置检测器#
有几个内置的检测器。那些标记为 TODO 的尚未实现。
PassTimingInstrument (see src/ir/instrument.cc)
分析 Pass 的执行时间。
PrintIRBefore(TODO)
在 Pass 变换之前打印 IR 模块。如果在 Pass 周围插入
tvm.transform.PrintIR()
,也可以实现此目的。然而,使用PassInstrument
,不需要修改 Pass 的顺序。
PrintAfter(TODO)
在 Pass 变换后打印 IR 模块。
Python 前端#
前端只需要一些简单的 API。例如,可以为用户提供以下 API 来创建和执行 Pass(完整实现见 python/tvm/relay/transform/transform.py 和 python/tvm/ir/transform.py)。后端接收信息并决定应该使用哪个函数来创建 Pass 对象。
PassContext#
Python 前端为 PassContext
提供了包装器,通过重写 __enter__
和 __exit__
来启用 with
语法。提供了 current
静态方法,供用户获取在某个范围内正在使用的上下文。
@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
def __enter__(self):
_transform.EnterPassContext(self)
return self
def __exit__(self, ptype, value, trace, config):
_transform.ExitPassContext(self)
@staticmethod
def current():
"""Return the current pass context."""
return _transform.GetCurrentPassContext()
PassContext
用于配置编译选项,包括优化级别和所需/禁用的 Pass。它还可以接受配置字典,以便不同的 Pass 可以方便地获取传递的数据,例如回退设备信息和循环展开的步长/深度等。为了能够获取所需的配置,必须通过 TVM_REGISTER_PASS_CONFIG_OPTION
注册键。例如,以下内容用于循环展开 Pass。
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
更多详细信息请参阅 src/tir/transforms/unroll_loop.cc。
Pass 对象#
Pass
是所有 Pass 对象的基类。这里的所有方法都是在后端实现的简单包装器。它们是为了让用户能够方便地在 Python 中与基类交互而定义的。Pass 基类中只定义了 __call__
,以使子类成为可调用对象,从而可以轻松调用它们(例如,pass_xx(arg)
)以执行。
@register_relay_node
class Pass(RelayNode):
def __call__(self, mod):
return _transform.RunPass(self, mod)
提供了一些辅助 API,以便从 Python 前端轻松创建 Pass,并让 Pass 基础设施控制执行。例如,向用户提供了 module_pass
、function_pass
和 sequential
,以便他们可以自定义自己的 Pass 或 Pass 管道。
对于所有在 C++ 后端实现的 Pass,分别在 python/tvm/ir/transform.py 和 python/tvm/relay/transform/transform.py 中提供了相应的 Python API。例如,常量折叠有如下的 Python API:
def FoldConstant():
return _transform.FoldConstant()
用户可以通过装饰器构建 Pass,如下所示:
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("abs")
func = relay.Function([x], relay.abs(x))
new_mod = tvm.IRModule({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2
这里的 transform
函数向输入模块添加了 abs
函数,但它可以是模块级别的任何自定义优化。创建这个 module_pass
后,用户可以将其应用于任何 Relay 模块。例如,可以构建空模块并应用这个 Pass 来添加 abs
函数。
mod = tvm.IRModule()
mod = module_pass(mod)
相应地,也为 function_pass
提供了这样的函数。例如,示例函数级 Pass 可以写成如下形式:
@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
def __init__(self, new_func):
self.new_func = new_func
def transform_function(self, func, mod, ctx):
# Just for demo purposes
# Transform func to new_func
return self.new_func
x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# Now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)
或者,用户也可以直接注册 Pass 而不使用装饰器,然后调用它。有关如何自定义优化管道以及调试 Relay 和 TIR Pass 的更多示例,请参阅 use pass infra 教程。
Pass 检测#
可以通过在实现以下方法的类上使用 pass_instrument
装饰器(python/tvm/ir/instrument.py)来实现 PassInstrument
。请注意,建议使用 pass_instrument
装饰器来实现 PassInstrument
,而不是重写或子类化。
enter_pass_ctx
此方法在进入
PassContext
时运行。
exit_pass_ctx
此方法在退出
PassContext
时运行。
should_run
此方法在 Pass 执行之前运行,返回布尔值,指示是否应运行该 Pass。
run_before_pass
如果 Pass 应该运行,此方法在 Pass 执行之前运行。
run_after_pass
此方法在 Pass 执行后立即运行。
PassInstrument
实例可以通过 tvm.transform.PassContext
中的 instruments
参数注册。
use pass instrument 教程提供了如何使用 Python API 实现 PassInstrument
的示例。
在当前 PassContext 中覆盖检测器#
提供了 override_instruments
方法来覆盖当前 PassContext
的 instruments
。例如,如果 Pass 在没有显式创建新 PassContext
的情况下运行,仍然可以通过以下方式将 PassInstrument
注册到全局 PassContext
中:
cur_pass_ctx = tvm.transform.PassContext.current()
# override PassInstrument instances
cur_pass_ctx.override_instruments([pass_inst])
mod = pass_seq(mod)
result = pass_inst.get_result()
请注意,当调用 override_instruments
时,会调用旧 PassInstrument
实例的 exit_pass_ctx
方法。然后调用新 PassInstrument
的 enter_pass_ctx
方法。