Pass 基础设施#

Both Relax and TVM IR contain a series of optimization passes which improve performance metrics of models such as mean inference, memory footprint, or power consumption for specific devices. There is a suite of standard optimizations as well as machine learning-specific optimizations including constant folding, dead code elimination, operator layout alteration, operator fusion, buffer handling, and loop transformation, etc. Each of these passes is structured as a ir-to-ir transformation using the analysis result collected during and/or before traversal.

However, as TVM evolves quickly, the need for a more systematic and efficient way to manage these passes is becoming apparent. In addition, a generic framework that manages the passes across different layers of the TVM stack (e.g. Relax and tir) paves the way for developers to quickly prototype and plug the implemented passes into the system.

本文档描述了这样一种基础设施的设计,它利用了生产编译器管理优化 Pass 的方式以及现代深度学习框架构建层次结构的风格。

例如,许多现有的生产编译器(如 GCC 和 LLVM)使用 Pass 管理器来有效管理 Pass 的执行。最初,由于 Pass 数量较少,管理 Pass 相对简单,但成熟的编译器将包含数百个独立的 Pass。通常,外部用户希望能够在不修改手动编写的 Pass 顺序的情况下,正确调度自定义 Pass。

同样,现代深度学习框架(如 Pytorch 和 MXNet Gluon)也倾向于通过 SequentialBlock 分别启用 Pass 风格的层次构建方案。通过这些构造,这些现代框架能够方便地将模块/层添加到它们的容器中,并轻松构建神经网络。

The design of the TVM pass infra is largely inspired by the hierarchical pass manager used in LLVM and the block-style containers used in the popular deep learning frameworks. The major goals of the pass infra include:

  1. 实现更好的优化程序化编排。这使得用户能够灵活地定制和构建自己的优化管道。

  2. 提供一种用户友好的方式来调试优化 Pass。

  3. 减轻开发者手动和分别解决 Pass 之间依赖关系的负担。

  4. 简化开发者实现新 Pass 的过程。例如,允许用户在 Python 中实现 Pass,并让 Pass 基础设施管理其执行。

设计#

专注于用户的扩展便利性,使用户能够快速添加新 Pass 而不损失向后兼容性。该设计包含后端和前端。后端实现了 Pass 基础设施的核心逻辑。前端为用户提供了简单的 API 进行交互,即允许用户快速创建自己的优化管道。

C++ 后端#

提供了 PassInfo 对象来包含 Pass 所需的基本信息。name 是 Pass 的名称,opt_level 表示 Pass 将在哪个优化级别启用,required 表示执行某个 Pass 所需的其他 Pass(更多详细信息请参阅 include/tvm/ir/transform.h)。例如,在 Pass 注册期间(将在后面介绍),Pass 开发者可以指定 Pass 的名称、将在哪个优化级别执行以及/或所需的 Pass。opt_level 可用于帮助 Pass 基础设施识别在用户提供的优化级别下是否需要执行某个 Pass。Pass 基础设施可以使用 required 字段来解决 Pass 之间的依赖关系。

class PassInfoNode : public Object {
  String name;
  int opt_level;
  Array<String> required;
};

PassContext#

PassContext 携带了优化 Pass 所需的有用信息。例如,它包含了错误报告系统,因此优化作者可以提供有关优化失败原因的诊断信息。PassContext 还被设计用来取代旧的 BuildConfig,后者用于帮助用户配置编译选项,包括优化级别和所需/禁用的Pass等。例如,可能有配置,它在 opt_level=3 下执行所有 Pass,同时使用 PassContext 提供的 disabled_pass=xx 禁用某些 Pass。现在,可以全局获取 opt_level=3 下的所有 Pass,并排除禁用 Pass 列表中的那些 Pass。PassContext 还提供了一种方法来检测所有 Pass。请参阅 Pass 检测 部分。

这个类旨在让用户方便地编写 Python 的 with 语法,以在特定配置下执行优化。此外,用户可以通过 PassContext::Current() 以线程安全的方式获取在某个程序范围内可用的上下文,因为使用了线程本地存储 PassContextThreadLocalStore 来保存创建的 Pass 上下文对象。稍后将提供示例,展示如何使用 C++ 和 Python API 来创建使用 Pass 上下文的编译管道。

class PassContextNode : public Object {
 public:
  int opt_level{2};
  tvm::Array<tvm::Expr> required_pass;
  tvm::Array<tvm::Expr> disabled_pass;
  mutable Optional<DiagnosticContext> diag_ctx;
  Map<String, ObjectRef> config;
  Array<instrument::PassInstrument> instruments;
};

class PassContext : public NodeRef {
 public:
  TVM_DLL static PassContext Create();
  TVM_DLL static PassContext Current();
  TVM_DLL void InstrumentEnterPassContext();
  TVM_DLL void InstrumentExitPassContext();
  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
  /* Other fields are omitted. */

 private:
  // The entry of a pass context scope.
  TVM_DLL void EnterWithScope();
  // The exit of a pass context scope.
  TVM_DLL void ExitWithScope();

  // Classes to get the Python `with` like syntax.
  friend class tvm::With<PassContext>;
};

struct PassContextThreadLocalEntry {
  /*! \brief The default pass context. */
  PassContext default_context;
  /*! \brief The current pass context. */
  std::stack<PassContext> context_stack;
  PassContextThreadLocalEntry() {
    default_context = PassContext(make_node<PassContextNode>());
  }
};

/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
     PassContextThreadLocalStore;

Pass 构造函数#

The pass infra is designed in a hierarchical manner, and it could work at different granularities of Relax/tir programs. A pure virtual class PassNode is introduced to serve as the base of the different optimization passes. This class contains several virtual methods that must be implemented by the subclasses at the level of modules, functions, or sequences of passes.

class PassNode : Object {
  virtual PassInfo Info() const = 0;
  virtual Module operator()(const IRModule& mod
                            const PassContext& pass_ctx) const = 0;
};

该函子展示了如何实现 Pass,即它总是在某个上下文下对 IRModule 进行操作。所有 Pass 都以 ModuleModule 的方式设计。因此,由 Pass 基础设施管理的优化将始终更新整个模块。

Several subclasses have been created to implement different types of optimization passes, e.g., function-level passes, module-level passes, and sequential passes. Each subclass itself could act as a pass manager. For instance, they could collect the required passes and execute them or build a dependency graph based on the given metadata. The full definition of them can be found in src/ir/transform.cc.

模块级 Pass#

Module level passes are geared mainly for global and inter-procedural optimizations (IPO), which are similar to the module pass used in LLVM. Some typical passes in Relax that need the global picture of a module, such as A-normal form conversion and lambda lifting, etc., fall into this set. At this level, users can even add and/or delete functions in a module. Note that all passes

class ModulePassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  // Other members/methods are omitted
};

pass_info 维护了模块级 Pass 所需的信息。pass_func 描述了实际的优化过程。例如,可能需要对模块执行死代码消除。可以在 pass_func 中实现该算法,并让它在模块上运行。然后,它将删除死代码,包括模块中未使用的函数。请注意,该字段被设计为打包函数,这使得优化可以在 C++ 和 Python 中实现。

函数级 Pass#

Function-level passes are used to implement various intra-function level optimizations for a given Relax/tir module. It fetches one function at a time from the function list of a module for optimization and yields a rewritten Relax Function or tir PrimFunc. Most of passes can be classified into this category, such as common subexpression elimination and inference simplification in Relax as well as vectorization and flattening storage in tir, etc.

Note that the scope of passes at this level is either a Relax function or a tir primitive function. Therefore, we cannot add or delete a function through these passes as they are not aware of the global information.

class FunctionPassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  bool SkipFunction(const Function& func) const;
  // Other members/methods are omitted...
};

pass_info 与我刚刚在模块级 Pass 中描述的内容相同。pass_func 接受函数进行优化,它还需要模块,因为可能用它来报告错误。函数可以用“SkipOptimization”进行注释,以便在优化期间忽略它。

Pass 序列#

SequentialPass 类似于 Pytorch 的 nn.Sequential,它包含一系列要执行的 Pass。

class SequentialPassNode : PassNode {
  PassInfo pass_info;
  // Passes need to be executed.
  Array<Pass> passes;
  bool PassEnabled(const PassInfo& info) const;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};

以下代码展示了如何在顺序 Pass 中调用各个 Pass。本质上,按照它们被添加到 Pass 列表中的顺序依次执行顺序 Pass 中的每个 Pass。

Module SequentialNode::operator()(const Module& module,
                                  const PassContext& pass_ctx) const {
  Module mod = module;
  for (const Pass& pass : passes) {
    ICHECK(pass.defined()) << "Found undefined pass for optimization.";
    const PassInfo& pass_info = pass->Info();
    if (!PassEnabled(pass_info))  continue;
    for (const auto& it : pass_info->required) {
      const auto* name = it.as<tvm::ir::StringImm>();
      ICHECK(name);
      mod = GetPass(name->value)(mod, pass_ctx);
    }
    mod = pass(mod, pass_ctx);
  }
  return mod;
}

在调用 Pass 时,首先检查该 Pass 是否启用。这是通过首先检查用户是否显式禁用了该 Pass,然后检查用户是否将其指定为必需 Pass 来完成的。如果仍然无法确定该 Pass 是否启用,将检查其 opt_level。只有当其优化级别不低于 Pass 上下文中配置的优化级别时,该 Pass 才会被启用并执行。

要执行 Pass,需要首先使用 Pass 名称从 TVM 打包函数注册表中检索已注册的 Pass。这是可能的,因为每个 Pass 都注册了 API 端点,稍后会展示这一点。

Pass GetPass(const std::string& pass_name) {
  using tvm::runtime::Registry;
  std::string fpass_name = "relax.transform." + pass_name;
  const auto* f = Registry::Get(fpass_name);
  ICHECK(f != nullptr) << "Cannot find " << fpass_name
                      << "to create the pass " << pass_name;
  return (*f)();
}

提供了一些辅助函数来创建上述每种类型的 Pass。这些辅助函数也暴露给 Python 前端,以便用户方便地使用 Python API 创建特定的 Pass 对象。

Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreatePrimFuncPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreateModulePass(
    const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

Pass 注册#

We've covered the concept of different level of passes and the context used for compilation. It would be interesting to see how easily users can register a pass. Let's take const folding as an example. This pass has already been implemented to fold constants in a Relax function (found in src/relax/transforms/fold_constant.cc).

提供了 API 来执行 ExprExpr 的变换。

Expr FoldConstant(const Expr& expr);

为了将这个 Pass 注册到 Pass 基础设施中,首先需要确定该 Pass 将在哪个级别执行。由于常量折叠发生在单个函数上,应该直观地通过 CreateFunctionPass 为其创建 FunctionPasspass_func 作为打包函数返回,它在 IRModule 中的每个函数上调用 ExprExpr 的API。{} 表示该 Pass 没有先决条件。否则,Pass 开发者必须识别并列出它们。

Meanwhile, a pass API endpoint is registered with the name "relax.transform.FoldConstant. This pass, therefore, becomes an entry in the registry that can be accessed by both C++ (e.g. the GetPass above) and Python when needed.

namespace transform {

Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) { return ConstantFolder::Fold(f, m); };
  return CreateFunctionPass(pass_func, 0, "FoldConstant", {});
}

TVM_REGISTER_GLOBAL("relax.transform.FoldConstant")
.set_body_typed(FoldConstant);

}  // namespace transform

To allow other C++ modules to apply this pass, we declare a free function in include/tvm/relax/transform.h as the following:

TVM_DLL Pass FoldConstant();

Pass 检测#

Pass 检测(Instrument)是一种分析 Pass 本身的机制。例如,可以使用该基础设施来了解 Pass 需要多少时间和内存,或者 Pass 如何变换 IR 模块。

PassContext 的生命周期中引入了四个检测点。

TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;

InstrumentEnterPassContext 在进入 PassContext 实例的范围时立即调用。

InstrumentExitPassContext is called when leaving the scope of PassContext, or exceptions occur during the execution of passes. This method is also called when instruments is being overridden by override_instruments in tvm.transform.PassContext. See 在当前 PassContext 中覆盖检测器.

InstrumentBeforePass 在执行前调用。如果 Pass 应该运行,InstrumentAfterPass 在执行后调用。行为如下:

if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) {
  new_ir_module = run_pass(ir_module, pass_ctx);
  pass_ctx.InstrumentAfterPass(new_ir_module, pass_info);
  return new_ir_module;
}

PassInstrument 接口允许你在上述四个方法中运行任意代码。多个 PassInstrument 实例可以注册到 PassContext 中。PassInstrument 实例按照传递给 PassContextinstruments 参数的顺序依次调用。

PassInstrument 提供以下接口:

namespace instrument {

class PassInstrumentNode : public Object {
 public:
  String name;
  virtual void EnterPassContext() const = 0;
  virtual void ExitPassContext() const = 0;
  virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;
  virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0;
  virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;
  /* Other fields are omitted. */
};

class PassInstrument : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode);
};

}  // namespace instrument

提供了 Python 前端来快速实现 PassInstrument。请参阅 Pass 检测

PassContext 中,PassInstrument 实例的调用顺序如下:

with PassContext(instruments=[pi]) # pi = a PassInstrument implementation.
    pi.EnterPassContext()

    if pi.ShouldRun(Pass1):
        pi.RunBeforePass()
        Pass1()
        pi.RunAfterPass()

    if pi.ShouldRun(Pass2):
        pi.RunBeforePass()
        Pass2()
        pi.RunAfterPass()

    pi.ExitPassContext()

以下是 PassInstrument 接口与 PassContext 方法之间关系的简要介绍。更多详细信息请参阅(src/ir/transform.cc)。

  • InstrumentEnterPassContext

    • EnterPassContext() 按照传递给 PassContextinstruments 的顺序执行。

    • 当异常发生时,PassContext 通过清除所有注册的 PassInstrument 实例来禁用 Pass 检测。

    • 然后,PassContext 执行每个成功完成 EnterPassContext()PassInstrument 实例的 ExitPassContext() 方法。

    • 例如,如果 PassInstrument A、B 和C 注册到 PassContext 中,并且 A 完成了 EnterPassContext(),而 B 抛出异常,那么 C 永远不会执行;A 的 ExitPassContext() 会被执行。

  • InstrumentExitPassContext

    • 每个 PassInstrument 实例的 ExitPassContext() 按照传递给 PassContextinstruments 的顺序执行。

    • 当异常发生时,instruments 会被清除。

    • 在抛出异常的实例之后注册的 PassInstrument 实例不会执行 ExitPassContext

  • InstrumentBeforePass

    • 如果 Pass 未列为必需 Pass,则执行 ShouldRun

    • 如果 Pass 未被 ShouldRun 阻止,则按照 instruments 的顺序执行 RunBeforePass

    • 请注意,InstrumentBeforePass 返回布尔值,指示是否应运行该 Pass。

    • 当异常发生时,它会立即抛出。依赖 Python 的上下文管理器来安全地退出 PassContext (意味着每个检测器的 ExitPassContext 将被运行。对于 C++,请参阅 include/tvm/support/with.h。)

  • InstrumentAfterPass

    • RunAfterPass 按照传递给 PassContextinstruments 的顺序执行。

    • 当异常发生时,它会立即抛出。依赖 Python 的上下文管理器或 With 类(include/tvm/support/with.h)来安全地退出 PassContext

内置检测器#

有几个内置的检测器。那些标记为 TODO 的尚未实现。

  • PassTimingInstrument (see src/ir/instrument.cc)

    • 分析 Pass 的执行时间。

  • PrintIRBefore(TODO)

    • 在 Pass 变换之前打印 IR 模块。如果在 Pass 周围插入 tvm.transform.PrintIR(),也可以实现此目的。然而,使用 PassInstrument,不需要修改 Pass 的顺序。

  • PrintAfter(TODO)

    • 在 Pass 变换后打印 IR 模块。

Python 前端#

Only some simple APIs are needed for the frontend side. For example, we can provide users the following APIs to create and execute a pass (full implementation is provided in python/tvm/relax/transform/transform.py and python/tvm/ir/transform.py). The backend receives the information and decides which function it should use to create a Pass object.

PassContext#

Python 前端为 PassContext 提供了包装器,通过重写 __enter____exit__ 来启用 with 语法。提供了 current 静态方法,供用户获取在某个范围内正在使用的上下文。

@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
    def __enter__(self):
        _transform.EnterPassContext(self)
        return self

    def __exit__(self, ptype, value, trace, config):
        _transform.ExitPassContext(self)

    @staticmethod
    def current():
        """Return the current pass context."""
        return _transform.GetCurrentPassContext()

PassContext 用于配置编译选项,包括优化级别和所需/禁用的 Pass。它还可以接受配置字典,以便不同的 Pass 可以方便地获取传递的数据,例如回退设备信息和循环展开的步长/深度等。为了能够获取所需的配置,必须通过 TVM_REGISTER_PASS_CONFIG_OPTION 注册键。例如,以下内容用于循环展开 Pass。

TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

更多详细信息请参阅 src/tir/transforms/unroll_loop.cc

Pass 检测#

可以通过在实现以下方法的类上使用 pass_instrument 装饰器(python/tvm/ir/instrument.py)来实现 PassInstrument。请注意,建议使用 pass_instrument 装饰器来实现 PassInstrument,而不是重写或子类化。

  • enter_pass_ctx

    • 此方法在进入 PassContext 时运行。

  • exit_pass_ctx

    • 此方法在退出 PassContext 时运行。

  • should_run

    • 此方法在 Pass 执行之前运行,返回布尔值,指示是否应运行该 Pass。

  • run_before_pass

    • 如果 Pass 应该运行,此方法在 Pass 执行之前运行。

  • run_after_pass

    • 此方法在 Pass 执行后立即运行。

PassInstrument 实例可以通过 tvm.transform.PassContext 中的 instruments 参数注册。

use pass instrument 教程提供了如何使用 Python API 实现 PassInstrument 的示例。

在当前 PassContext 中覆盖检测器#

提供了 override_instruments 方法来覆盖当前 PassContextinstruments。例如,如果 Pass 在没有显式创建新 PassContext 的情况下运行,仍然可以通过以下方式将 PassInstrument 注册到全局 PassContext 中:

cur_pass_ctx = tvm.transform.PassContext.current()
# override PassInstrument instances
cur_pass_ctx.override_instruments([pass_inst])
mod = pass_seq(mod)
result = pass_inst.get_result()

请注意,当调用 override_instruments 时,会调用旧 PassInstrument 实例的 exit_pass_ctx 方法。然后调用新 PassInstrumententer_pass_ctx 方法。