Pass 基础设施#

Relax 和 TVM IR 都包含一系列优化pass,这些过程可以提高模型在特定设备上的性能指标,例如平均推理时间、内存占用或功耗。这些优化pass包括标准优化和机器学习特定优化,如常量折叠、死代码消除、算子布局调整、算子融合、缓冲区处理和循环变换等。每个优化pass都使用在遍历期间和/或遍历之前收集的分析结果,作为 ir-to-ir 转换结构。

然而,随着 TVM 的快速发展,管理这些过程的系统性、高效性需求日益明显。此外,一个管理 TVM 堆栈不同层(例如 Relax 和 tir)中优化pass的通用框架,为开发者快速原型设计和将实现的优化pass集成到系统中铺平了道路。

本文档描述了这样的基础设施设计,该设计利用了生产编译器管理优化pass的方式,以及现代深度学习框架构建层所采用的风格。

例如,许多现有的生产编译器,如 GCC 和 LLVM,使用传递管理器来有效管理传递的执行。最初管理传递很简单,因为传递的数量较少,但成熟的编译器将包含数百个独立的传递。通常外部用户希望能够在不修改单个手工制作的传递顺序的情况下,正确地安排自定义传递。

同样,现代深度学习框架(如 Pytorch 和 MXNet Gluon)也倾向于通过 SequentialBlock 分别启用 Pass 风格的层次构建方案。通过这些构造,这些现代框架能够方便地将模块/层添加到它们的容器中,并轻松构建神经网络。

TVM 传递基础设施的设计主要受到 LLVM 中使用的分层传递管理器和流行深度学习框架中使用的块式容器启发。传递基础设施的主要目标包括:

  1. 实现更好的优化程序化编排。这使得用户能够灵活地定制和构建自己的优化pass。

  2. 提供一种用户友好的方式来调试优化pass。

  3. 减轻开发人员手动分别解决管道之间依赖关系的负担。

  4. 简化开发人员实现新管道的操作。例如,允许用户用 Python 实现其管道,并让管道基础设施来管理其执行。

设计#

专注于用户的扩展便利性,使用户能够快速添加新 Pass 而不损失向后兼容性。该设计包含后端和前端。后端实现了 Pass 基础设施的核心逻辑。前端为用户提供了简单的 API 进行交互,即允许用户快速创建自己的优化pass。

C++ 后端#

提供了 PassInfo 对象来包含 Pass 所需的基本信息。name 是 Pass 的名称,opt_level 表示 Pass 将在哪个优化级别启用,required 表示执行某个 Pass 所需的其他 Pass(更多详细信息请参阅 include/tvm/ir/transform.h)。例如,在 Pass 注册期间(将在后面介绍),Pass 开发者可以指定 Pass 的名称、将在哪个优化级别执行以及/或所需的 Pass。opt_level 可用于帮助 Pass 基础设施识别在用户提供的优化级别下是否需要执行某个 Pass。Pass 基础设施可以使用 required 字段来解决 Pass 之间的依赖关系。

class PassInfoNode : public Object {
  String name;
  int opt_level;
  Array<String> required;
};

PassContext#

PassContext 携带了优化 Pass 所需的有用信息。例如,它包含了错误报告系统,因此优化作者可以提供有关优化失败原因的诊断信息。PassContext 还被设计用来取代旧的 BuildConfig,后者用于帮助用户配置编译选项,包括优化级别和所需/禁用的Pass等。例如,可能有配置,它在 opt_level=3 下执行所有 Pass,同时使用 PassContext 提供的 disabled_pass=xx 禁用某些 Pass。现在,可以全局获取 opt_level=3 下的所有 Pass,并排除禁用 Pass 列表中的那些 Pass。PassContext 还提供了一种方法来检测所有 Pass。请参阅 Pass 检测 部分。

这个类旨在让用户方便地编写 Python 的 with 语法,以在特定配置下执行优化。此外,用户可以通过 PassContext::Current() 以线程安全的方式获取在某个程序范围内可用的上下文,因为使用了线程本地存储 PassContextThreadLocalStore 来保存创建的 Pass 上下文对象。稍后将提供示例,展示如何使用 C++ 和 Python API 来创建使用 Pass 上下文的编译管道。

class PassContextNode : public Object {
 public:
  int opt_level{2};
  tvm::Array<tvm::Expr> required_pass;
  tvm::Array<tvm::Expr> disabled_pass;
  mutable Optional<DiagnosticContext> diag_ctx;
  Map<String, Any> config;
  Array<instrument::PassInstrument> instruments;
};

class PassContext : public NodeRef {
 public:
  TVM_DLL static PassContext Create();
  TVM_DLL static PassContext Current();
  TVM_DLL void InstrumentEnterPassContext();
  TVM_DLL void InstrumentExitPassContext();
  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
  /* Other fields are omitted. */

 private:
  // The entry of a pass context scope.
  TVM_DLL void EnterWithScope();
  // The exit of a pass context scope.
  TVM_DLL void ExitWithScope();

  // Classes to get the Python `with` like syntax.
  friend class tvm::With<PassContext>;
};

struct PassContextThreadLocalEntry {
  /*! \brief The default pass context. */
  PassContext default_context;
  /*! \brief The current pass context. */
  std::stack<PassContext> context_stack;
  PassContextThreadLocalEntry() {
    default_context = PassContext(make_node<PassContextNode>());
  }
};

/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
     PassContextThreadLocalStore;

Pass 构造函数#

Pass 基础设施以分层方式设计,可以在不同粒度的 Relax/tir 程序中工作。引入了纯虚拟类 PassNode ,作为不同优化 Pass 的基础。此类包含几个必须在模块、函数或 Pass 序列级别由子类实现的虚拟方法。

class PassNode : Object {
  virtual PassInfo Info() const = 0;
  virtual Module operator()(const IRModule& mod
                            const PassContext& pass_ctx) const = 0;
};

该函子展示了如何实现 Pass,即它总是在某个上下文下对 IRModule 进行操作。所有 Pass 都以 ModuleModule 的方式设计。因此,由 Pass 基础设施管理的优化将始终更新整个模块。

已经创建了几个子类来实现不同类型的优化 passes,例如函数级 passes、模块级 passes和 passes 序列。每个子类本身都可以充当 passes 管理器。例如,它们可以收集所需的传递并执行它们,或者根据给定的元数据构建依赖图。它们的完整定义可以在 src/ir/transform.cc 中找到。

模块级 Pass#

模块级 passes 主要针对全局和跨过程优化(inter-procedural optimizations,简称 IPO),这与 LLVM 中使用的模块传递类似。Relax 中的一些典型 passes,如 A 范式变换和 lambda 抬升等,都属于这一类。在这个级别上,用户甚至可以添加和/或删除模块中的函数。请注意,所有 passes

class ModulePassNode : PassNode {
  PassInfo pass_info;
  std::function<Module(Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  // Other members/methods are omitted
};

pass_info 维护了模块级 Pass 所需的信息。pass_func 描述了实际的优化pass。例如,可能需要对模块执行死代码消除。可以在 pass_func 中实现该算法,并让它在模块上运行。然后,它将删除死代码,包括模块中未使用的函数。请注意,该字段被设计为打包函数,这使得优化可以在 C++ 和 Python 中实现。

函数级 Pass#

函数级别的 pass 用于对给定的 Relax/tir 模块实现各种函数内部级别的优化。它每次从模块的函数列表中获取一个函数进行优化,并生成一个重写的 Relax Function 或 tir PrimFunc 。大多数 pass 都可以归入此类,例如 Relax 中的 common 子表达式消除和推断简化,以及 tir 中的向量化和存储扁平化等。

请注意,此级别 pass 的作用域要么是 Relax 函数,要么是 tir 原始函数。因此,不能通过这些 pass 添加或删除函数,因为它们 unaware of 全局信息。

class FunctionPassNode : PassNode {
  PassInfo pass_info;
  std::function<Function(Function, Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  bool SkipFunction(const Function& func) const;
  // Other members/methods are omitted...
};

pass_info 与刚刚在模块级 Pass 中描述的内容相同。pass_func 接受函数进行优化,它还需要模块,因为可能用它来报告错误。函数可以用“SkipOptimization”进行注释,以便在优化期间忽略它。

Pass 序列#

SequentialPass 类似于 Pytorch 的 nn.Sequential,它包含一系列要执行的 Pass。

class SequentialPassNode : PassNode {
  PassInfo pass_info;
  // Passes need to be executed.
  Array<Pass> passes;
  bool PassEnabled(const PassInfo& info) const;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};

以下代码展示了如何在顺序 Pass 中调用各个 Pass。本质上,按照它们被添加到 Pass 列表中的顺序依次执行顺序 Pass 中的每个 Pass。

Module SequentialNode::operator()(const Module& module,
                                  const PassContext& pass_ctx) const {
  Module mod = module;
  for (const Pass& pass : passes) {
    ICHECK(pass.defined()) << "Found undefined pass for optimization.";
    const PassInfo& pass_info = pass->Info();
    if (!PassEnabled(pass_info))  continue;
    for (const auto& it : pass_info->required) {
      const auto* name = it.as<tvm::ir::StringImm>();
      ICHECK(name);
      mod = GetPass(name->value)(mod, pass_ctx);
    }
    mod = pass(mod, pass_ctx);
  }
  return mod;
}

在调用 Pass 时,首先检查该 Pass 是否启用。这是通过首先检查用户是否显式禁用了该 Pass,然后检查用户是否将其指定为必需 Pass 来完成的。如果仍然无法确定该 Pass 是否启用,将检查其 opt_level。只有当其优化级别不低于 Pass 上下文中配置的优化级别时,该 Pass 才会被启用并执行。

要执行 Pass,需要首先使用 Pass 名称从 TVM 打包函数注册表中检索已注册的 Pass。这是可能的,因为每个 Pass 都注册了 API 端点,稍后会展示这一点。

Pass GetPass(const std::string& pass_name) {
  using tvm::runtime::Registry;
  std::string fpass_name = "relax.transform." + pass_name;
  const std::optional<tvm::ffi::Function> f = tvm::ffi::Function::GetGlobal(fpass_name);
  ICHECK(f.has_value()) << "Cannot find " << fpass_name
                        << "to create the pass " << pass_name;
  return (*f)();
}

提供了一些辅助函数来创建上述每种类型的 Pass。这些辅助函数也暴露给 Python 前端,以便用户方便地使用 Python API 创建特定的 Pass 对象。

Pass CreateFunctionPass(
    std::function<Function(Function, IRModule, PassContext)> pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreatePrimFuncPass(
    std::function<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreateModulePass(
    std::function<IRModule(IRModule, PassContext)> pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

Pass 注册#

经介绍了不同级别的 pass 的概念以及用于编译的上下文。看看用户可以如何轻松地注册 Pass 会很有趣。以 const folding 为例。这个 pass 已经被实现用于在 Relax 函数中折叠常量(位于 src/relax/transforms/fold_constant.cc)。

提供了 API 来执行 ExprExpr 的变换。

Expr FoldConstant(const Expr& expr);

为了将这个 Pass 注册到 Pass 基础设施中,首先需要确定该 Pass 将在哪个级别执行。由于常量折叠发生在单个函数上,应该直观地通过 CreateFunctionPass 为其创建 FunctionPasspass_func 作为打包函数返回,它在 IRModule 中的每个函数上调用 ExprExpr 的API。{} 表示该 Pass 没有先决条件。否则,Pass 开发者必须识别并列出它们。

与此同时,名为 relax.transform.FoldConstant 的 pass API 端点被注册。因此,这个 pass 成为注册表中的一个条目,当需要时,C++(例如上述的 GetPass )和 Python 都可以访问它。

namespace transform {

Pass FoldConstant() {
  auto pass_func =
      [=](Function f, IRModule m, PassContext pc) { return ConstantFolder::Fold(f, m); };
  return CreateFunctionPass(pass_func, 0, "FoldConstant", {});
}

TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("relax.transform.FoldConstant", FoldConstant);
});

}  // namespace transform

为了让其他 C++模块能够应用这个 pass,在 include/tvm/relax/transform.h 中声明了 free 函数,如下所示:

TVM_DLL Pass FoldConstant();

Pass 检测#

Pass 检测(Instrument)是一种分析 Pass 本身的机制。例如,可以使用该基础设施来了解 Pass 需要多少时间和内存,或者 Pass 如何变换 IR 模块。

PassContext 的生命周期中引入了四个检测点。

TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;

InstrumentEnterPassContext 在进入 PassContext 实例的范围时立即调用。

InstrumentExitPassContext 在离开 PassContext 的作用域时,或在 passes 执行过程中发生异常时被调用。当 override_instrumentstvm.transform.PassContext 中覆盖 instruments 时,该方法也会被调用。参见 在当前 PassContext 中覆盖检测器

InstrumentBeforePass 在执行前调用。如果应该运行此 pass,InstrumentAfterPass 在执行后调用。行为如下:

if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) {
  new_ir_module = run_pass(ir_module, pass_ctx);
  pass_ctx.InstrumentAfterPass(new_ir_module, pass_info);
  return new_ir_module;
}

PassInstrument 接口允许你在上述四个方法中运行任意代码。多个 PassInstrument 实例可以注册到 PassContext 中。PassInstrument 实例按照传递给 PassContextinstruments 参数的顺序依次调用。

PassInstrument 提供以下接口:

namespace instrument {

class PassInstrumentNode : public Object {
 public:
  String name;
  virtual void EnterPassContext() const = 0;
  virtual void ExitPassContext() const = 0;
  virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;
  virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0;
  virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;
  /* Other fields are omitted. */
};

class PassInstrument : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode);
};

}  // namespace instrument

提供了 Python 前端来快速实现 PassInstrument。请参阅 Pass 检测

PassContext 中,PassInstrument 实例的调用顺序如下:

with PassContext(instruments=[pi]) # pi = a PassInstrument implementation.
    pi.EnterPassContext()

    if pi.ShouldRun(Pass1):
        pi.RunBeforePass()
        Pass1()
        pi.RunAfterPass()

    if pi.ShouldRun(Pass2):
        pi.RunBeforePass()
        Pass2()
        pi.RunAfterPass()

    pi.ExitPassContext()

以下是 PassInstrument 接口与 PassContext 方法之间关系的简要介绍。更多详细信息请参阅(src/ir/transform.cc)。

  • InstrumentEnterPassContext

    • EnterPassContext() 按照传递给 PassContextinstruments 的顺序执行。

    • 当异常发生时,PassContext 通过清除所有注册的 PassInstrument 实例来禁用 Pass 检测。

    • 然后,PassContext 执行每个成功完成 EnterPassContext()PassInstrument 实例的 ExitPassContext() 方法。

    • 例如,如果 PassInstrument A、B 和C 注册到 PassContext 中,并且 A 完成了 EnterPassContext(),而 B 抛出异常,那么 C 永远不会执行;A 的 ExitPassContext() 会被执行。

  • InstrumentExitPassContext

    • 每个 PassInstrument 实例的 ExitPassContext() 按照传递给 PassContextinstruments 的顺序执行。

    • 当异常发生时,instruments 会被清除。

    • 在抛出异常的实例之后注册的 PassInstrument 实例不会执行 ExitPassContext

  • InstrumentBeforePass

    • 如果 Pass 未列为必需 Pass,则执行 ShouldRun

    • 如果 Pass 未被 ShouldRun 阻止,则按照 instruments 的顺序执行 RunBeforePass

    • 请注意,InstrumentBeforePass 返回布尔值,指示是否应运行该 Pass。

    • 当异常发生时,它会立即抛出。依赖 Python 的上下文管理器来安全地退出 PassContext (意味着每个检测器的 ExitPassContext 将被运行。对于 C++,请参阅 include/tvm/support/with.h。)

  • InstrumentAfterPass

    • RunAfterPass 按照传递给 PassContextinstruments 的顺序执行。

    • 当异常发生时,它会立即抛出。依赖 Python 的上下文管理器或 With 类(include/tvm/support/with.h)来安全地退出 PassContext

内置检测器#

有几个内置的检测器。那些标记为 TODO 的尚未实现。

  • PassTimingInstrument (see src/ir/instrument.cc)

    • 分析 Pass 的执行时间。

  • PrintIRBefore(TODO)

    • 在 Pass 变换之前打印 IR 模块。如果在 Pass 周围插入 tvm.transform.PrintIR(),也可以实现此目的。然而,使用 PassInstrument,不需要修改 Pass 的顺序。

  • PrintAfter(TODO)

    • 在 Pass 变换后打印 IR 模块。

Python 前端#

前端方面只需要一些简单的 API。例如,可以为用户提供以下 API 来创建和执行 pass(完整的实现提供在 python/tvm/relax/transform/transform.pypython/tvm/ir/transform.py 中)。后端接收这些信息,并决定它应该使用哪个函数来创建 Pass 对象。

PassContext#

Python 前端为 PassContext 提供了包装器,通过重写 __enter____exit__ 来启用 with 语法。提供了 current 静态方法,供用户获取在某个范围内正在使用的上下文。

@tvm.ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
    def __enter__(self):
        _transform.EnterPassContext(self)
        return self

    def __exit__(self, ptype, value, trace, config):
        _transform.ExitPassContext(self)

    @staticmethod
    def current():
        """Return the current pass context."""
        return _transform.GetCurrentPassContext()

PassContext 用于配置编译选项,包括优化级别和所需/禁用的 Pass。它还可以接受配置字典,以便不同的 Pass 可以方便地获取传递的数据,例如回退设备信息和循环展开的步长/深度等。为了能够获取所需的配置,必须通过 TVM_REGISTER_PASS_CONFIG_OPTION 注册键。例如,以下内容用于循环展开 Pass。

TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

更多详细信息请参阅 src/tir/transforms/unroll_loop.cc

Pass 检测#

可以通过在实现以下方法的类上使用 pass_instrument 装饰器(python/tvm/ir/instrument.py)来实现 PassInstrument。请注意,建议使用 pass_instrument 装饰器来实现 PassInstrument,而不是重写或子类化。

  • enter_pass_ctx

    • 此方法在进入 PassContext 时运行。

  • exit_pass_ctx

    • 此方法在退出 PassContext 时运行。

  • should_run

    • 此方法在 Pass 执行之前运行,返回布尔值,指示是否应运行该 Pass。

  • run_before_pass

    • 如果 Pass 应该运行,此方法在 Pass 执行之前运行。

  • run_after_pass

    • 此方法在 Pass 执行后立即运行。

PassInstrument 实例可以通过 tvm.transform.PassContext 中的 instruments 参数注册。

use pass instrument 教程提供了如何使用 Python API 实现 PassInstrument 的示例。

在当前 PassContext 中覆盖检测器#

提供了 override_instruments 方法来覆盖当前 PassContextinstruments。例如,如果 Pass 在没有显式创建新 PassContext 的情况下运行,仍然可以通过以下方式将 PassInstrument 注册到全局 PassContext 中:

cur_pass_ctx = tvm.transform.PassContext.current()
# override PassInstrument instances
cur_pass_ctx.override_instruments([pass_inst])
mod = pass_seq(mod)
result = pass_inst.get_result()

请注意,当调用 override_instruments 时,会调用旧 PassInstrument 实例的 exit_pass_ctx 方法。然后调用新 PassInstrumententer_pass_ctx 方法。