Pass 基础设施#

Relay 和 TVM IR 都包含一系列优化 Pass,这些 Pass 可以提高模型的性能指标,例如特定设备的平均推理时间、内存占用或功耗。这些优化包括标准优化以及机器学习特定的优化,例如常量折叠、死代码消除、算子布局更改、算子融合、缓冲区处理和循环转换等。每个 Pass 利用在遍历期间和/或之前收集的分析结果,构建为一种 IR 到 IR 的变换。

然而,随着 TVM 的快速发展,对一种更系统、更高效的方式来管理这些 Pass 的需求变得日益明显。此外,通用的框架来管理 TVM 堆栈不同层(例如 Relay 和 TIR)中的 Pass,为开发者快速原型化并将其实现的 Pass 集成到系统中铺平了道路。

本文档描述了这样一种基础设施的设计,它利用了生产编译器管理优化 Pass 的方式以及现代深度学习框架构建层次结构的风格。

例如,许多现有的生产编译器(如 GCC 和 LLVM)使用 Pass 管理器来有效管理 Pass 的执行。最初,由于 Pass 数量较少,管理 Pass 相对简单,但成熟的编译器将包含数百个独立的 Pass。通常,外部用户希望能够在不修改手动编写的 Pass 顺序的情况下,正确调度自定义 Pass。

同样,现代深度学习框架(如 Pytorch 和 MXNet Gluon)也倾向于通过 SequentialBlock 分别启用 Pass 风格的层次构建方案。通过这些构造,这些现代框架能够方便地将模块/层添加到它们的容器中,并轻松构建神经网络。

Relay Pass 基础设施的设计很大程度上受到了 LLVM 中使用的分层 Pass 管理器以及流行深度学习框架中使用的块式容器的启发。Pass 基础设施的主要目标包括:

  1. 实现更好的优化程序化编排。这使得用户能够灵活地定制和构建自己的优化管道。

  2. 提供一种用户友好的方式来调试优化 Pass。

  3. 减轻开发者手动和分别解决 Pass 之间依赖关系的负担。

  4. 简化开发者实现新 Pass 的过程。例如,允许用户在 Python 中实现 Pass,并让 Pass 基础设施管理其执行。

设计#

专注于用户的扩展便利性,使用户能够快速添加新 Pass 而不损失向后兼容性。该设计包含后端和前端。后端实现了 Pass 基础设施的核心逻辑。前端为用户提供了简单的 API 进行交互,即允许用户快速创建自己的优化管道。

C++ 后端#

提供了 PassInfo 对象来包含 Pass 所需的基本信息。name 是 Pass 的名称,opt_level 表示 Pass 将在哪个优化级别启用,required 表示执行某个 Pass 所需的其他 Pass(更多详细信息请参阅 include/tvm/ir/transform.h)。例如,在 Pass 注册期间(将在后面介绍),Pass 开发者可以指定 Pass 的名称、将在哪个优化级别执行以及/或所需的 Pass。opt_level 可用于帮助 Pass 基础设施识别在用户提供的优化级别下是否需要执行某个 Pass。Pass 基础设施可以使用 required 字段来解决 Pass 之间的依赖关系。

class PassInfoNode : public Object {
  String name;
  int opt_level;
  Array<String> required;
};

PassContext#

PassContext 携带了优化 Pass 所需的有用信息。例如,它包含了错误报告系统,因此优化作者可以提供有关优化失败原因的诊断信息。PassContext 还被设计用来取代旧的 BuildConfig,后者用于帮助用户配置编译选项,包括优化级别和所需/禁用的Pass等。例如,可能有配置,它在 opt_level=3 下执行所有 Pass,同时使用 PassContext 提供的 disabled_pass=xx 禁用某些 Pass。现在,可以全局获取 opt_level=3 下的所有 Pass,并排除禁用 Pass 列表中的那些 Pass。PassContext 还提供了一种方法来检测所有 Pass。请参阅 Pass 检测 部分。

这个类旨在让用户方便地编写 Python 的 with 语法,以在特定配置下执行优化。此外,用户可以通过 PassContext::Current() 以线程安全的方式获取在某个程序范围内可用的上下文,因为使用了线程本地存储 PassContextThreadLocalStore 来保存创建的 Pass 上下文对象。稍后将提供示例,展示如何使用 C++ 和 Python API 来创建使用 Pass 上下文的编译管道。

class PassContextNode : public Object {
 public:
  int opt_level{2};
  tvm::Array<tvm::Expr> required_pass;
  tvm::Array<tvm::Expr> disabled_pass;
  mutable Optional<DiagnosticContext> diag_ctx;
  Map<String, ObjectRef> config;
  Array<instrument::PassInstrument> instruments;
};

class PassContext : public NodeRef {
 public:
  TVM_DLL static PassContext Create();
  TVM_DLL static PassContext Current();
  TVM_DLL void InstrumentEnterPassContext();
  TVM_DLL void InstrumentExitPassContext();
  TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
  TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;
  /* Other fields are omitted. */

 private:
  // The entry of a pass context scope.
  TVM_DLL void EnterWithScope();
  // The exit of a pass context scope.
  TVM_DLL void ExitWithScope();

  // Classes to get the Python `with` like syntax.
  friend class tvm::With<PassContext>;
};

struct PassContextThreadLocalEntry {
  /*! \brief The default pass context. */
  PassContext default_context;
  /*! \brief The current pass context. */
  std::stack<PassContext> context_stack;
  PassContextThreadLocalEntry() {
    default_context = PassContext(make_node<PassContextNode>());
  }
};

/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
     PassContextThreadLocalStore;

Pass 构造函数#

Pass 基础设施以分层方式设计,可以在不同粒度的 Relay/TIR 程序中工作。引入了纯虚类 PassNode 作为不同优化 Pass 的基类。该类包含几个虚方法,必须由模块、函数或 Pass 序列级别的子类实现。

class PassNode : Object {
  virtual PassInfo Info() const = 0;
  virtual Module operator()(const IRModule& mod
                            const PassContext& pass_ctx) const = 0;
};

该函子展示了如何实现 Pass,即它总是在某个上下文下对 IRModule 进行操作。所有 Pass 都以 ModuleModule 的方式设计。因此,由 Pass 基础设施管理的优化将始终更新整个模块。

已经创建了几个子类来实现不同类型的优化 Pass,例如函数级 Pass、模块级 Pass 和 Pass 序列。每个子类本身都可以充当 Pass 管理器。例如,它们可以收集所需的 Pass 并执行它们,或者根据给定的元数据构建依赖图。它们的完整定义可以在 src/relay/ir/transform.ccsrc/ir/transform.cc 中找到。

模块级 Pass#

模块级 Pass 主要用于全局和过程间优化(inter-procedural optimizations,简称 IPO),类似于 LLVM 中使用的模块 Pass。Relay 中一些需要模块全局信息的典型 Pass,例如 A 范式转换和 lambda 提升等,都属于这一类。在这个级别,用户甚至可以在模块中添加和/或删除函数。请注意,所有 Pass

class ModulePassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  // Other members/methods are omitted
};

pass_info 维护了模块级 Pass 所需的信息。pass_func 描述了实际的优化过程。例如,可能需要对模块执行死代码消除。可以在 pass_func 中实现该算法,并让它在模块上运行。然后,它将删除死代码,包括模块中未使用的函数。请注意,该字段被设计为打包函数,这使得优化可以在 C++ 和 Python 中实现。

函数级 Pass#

函数级 Pass 用于为给定的 Relay/TIR 模块实现各种函数内级别的优化。它从模块的函数列表中一次获取函数进行优化,并生成重写后的 Relay Function 或TIR PrimFunc。大多数 Pass 可以归类到这一类别中,例如 Relay 中的公共子表达式消除和推理简化,以及 TIR 中的向量化和存储扁平化等。

请注意,此级别 Pass 的范围是 Relay 函数或 TIR 原始函数。因此,不能通过这些 Pass 添加或删除函数,因为它们不了解全局信息。

class FunctionPassNode : PassNode {
  PassInfo pass_info;
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  bool SkipFunction(const Function& func) const;
  // Other members/methods are omitted...
};

pass_info 与我刚刚在模块级 Pass 中描述的内容相同。pass_func 接受函数进行优化,它还需要模块,因为可能用它来报告错误。函数可以用“SkipOptimization”进行注释,以便在优化期间忽略它。

Pass 序列#

SequentialPass 类似于 Pytorch 的 nn.Sequential,它包含一系列要执行的 Pass。

class SequentialPassNode : PassNode {
  PassInfo pass_info;
  // Passes need to be executed.
  Array<Pass> passes;
  bool PassEnabled(const PassInfo& info) const;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};

目前 Relay 中只有少数 Pass 被归入这一类别。例如,FoldScaleAxis 需要在内部调度 ForwardFoldScaleAxisBackwardFoldScaleAxis。此外,建议首先完成 BackwardFoldScaleAxis。因此,这个 Pass 是 SequentialPass 的理想候选者。

以下代码展示了如何在顺序 Pass 中调用各个 Pass。本质上,按照它们被添加到 Pass 列表中的顺序依次执行顺序 Pass 中的每个 Pass。

Module SequentialNode::operator()(const Module& module,
                                  const PassContext& pass_ctx) const {
  Module mod = module;
  for (const Pass& pass : passes) {
    ICHECK(pass.defined()) << "Found undefined pass for optimization.";
    const PassInfo& pass_info = pass->Info();
    if (!PassEnabled(pass_info))  continue;
    for (const auto& it : pass_info->required) {
      const auto* name = it.as<tvm::ir::StringImm>();
      ICHECK(name);
      mod = GetPass(name->value)(mod, pass_ctx);
    }
    mod = pass(mod, pass_ctx);
  }
  return mod;
}

在调用 Pass 时,首先检查该 Pass 是否启用。这是通过首先检查用户是否显式禁用了该 Pass,然后检查用户是否将其指定为必需 Pass 来完成的。如果仍然无法确定该 Pass 是否启用,将检查其 opt_level。只有当其优化级别不低于 Pass 上下文中配置的优化级别时,该 Pass 才会被启用并执行。

要执行 Pass,需要首先使用 Pass 名称从 TVM 打包函数注册表中检索已注册的 Pass。这是可能的,因为每个 Pass 都注册了 API 端点,稍后会展示这一点。

Pass GetPass(const std::string& pass_name) {
  using tvm::runtime::Registry;
  std::string fpass_name = "relay._transform." + pass_name;
  const auto* f = Registry::Get(fpass_name);
  ICHECK(f != nullptr) << "Cannot find " << fpass_name
                      << "to create the pass " << pass_name;
  return (*f)();
}

提供了一些辅助函数来创建上述每种类型的 Pass。这些辅助函数也暴露给 Python 前端,以便用户方便地使用 Python API 创建特定的 Pass 对象。

Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreatePrimFuncPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreateModulePass(
    const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

Pass 注册#

已经介绍了不同级别 Pass 的概念以及用于编译的上下文。看看用户可以如何轻松地注册 Pass 会很有趣。以常量折叠为例。这个 Pass 已经实现,用于折叠 Relay 函数中的常量(可以在 src/relay/transforms/fold_constant.cc 中找到)。

提供了 API 来执行 ExprExpr 的变换。

Expr FoldConstant(const Expr& expr);

为了将这个 Pass 注册到 Pass 基础设施中,首先需要确定该 Pass 将在哪个级别执行。由于常量折叠发生在单个函数上,应该直观地通过 CreateFunctionPass 为其创建 FunctionPasspass_func 作为打包函数返回,它在 IRModule 中的每个函数上调用 ExprExpr 的API。{} 表示该 Pass 没有先决条件。否则,Pass 开发者必须识别并列出它们。

同时,Pass API 端点以名称 relay._transform.FoldConstant 注册。因此,这个 Pass 成为注册表中的条目,可以在需要时通过 C++(例如上面的 GetPass)和 Python 访问。

namespace transform {

Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
      return Downcast<Function>(FoldConstant(f));
  };
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}

TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);

}  // namespace transform

为了允许其他 C++ 模块应用这个 Pass,在 include/tvm/relay/transform.h 中声明了自由函数,如下所示:

TVM_DLL Pass FoldConstant();

Pass 检测#

Pass 检测(Instrument)是一种分析 Pass 本身的机制。例如,可以使用该基础设施来了解 Pass 需要多少时间和内存,或者 Pass 如何变换 IR 模块。

PassContext 的生命周期中引入了四个检测点。

TVM_DLL void InstrumentEnterPassContext();
TVM_DLL void InstrumentExitPassContext();
TVM_DLL bool InstrumentBeforePass(const IRModule& mod, const PassInfo& info) const;
TVM_DLL void InstrumentAfterPass(const IRModule& mod, const PassInfo& info) const;

InstrumentEnterPassContext 在进入 PassContext 实例的范围时立即调用。

InstrumentExitPassContext 在离开 PassContext 的范围时调用,或者在 Pass 执行期间发生异常时调用。当检测器被 tvm.transform.PassContext 中的 override_instruments 覆盖时,也会调用此方法。请参阅 在当前 PassContext 中覆盖检测器

InstrumentBeforePass 在执行前调用。如果 Pass 应该运行,InstrumentAfterPass 在执行后调用。行为如下:

if (pass_ctx.InstrumentBeforePass(ir_module, pass_info)) {
  new_ir_module = run_pass(ir_module, pass_ctx);
  pass_ctx.InstrumentAfterPass(new_ir_module, pass_info);
  return new_ir_module;
}

PassInstrument 接口允许你在上述四个方法中运行任意代码。多个 PassInstrument 实例可以注册到 PassContext 中。PassInstrument 实例按照传递给 PassContextinstruments 参数的顺序依次调用。

PassInstrument 提供以下接口:

namespace instrument {

class PassInstrumentNode : public Object {
 public:
  String name;
  virtual void EnterPassContext() const = 0;
  virtual void ExitPassContext() const = 0;
  virtual bool ShouldRun(const IRModule& mod, const transform::PassInfo& info) const = 0;
  virtual void RunBeforePass(const IRModule& mod, const transform::PassInfo& info) const = 0;
  virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;
  /* Other fields are omitted. */
};

class PassInstrument : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(PassInstrument, ObjectRef, PassInstrumentNode);
};

}  // namespace instrument

提供了 Python 前端来快速实现 PassInstrument。请参阅 Pass 检测

PassContext 中,PassInstrument 实例的调用顺序如下:

with PassContext(instruments=[pi]) # pi = a PassInstrument implementation.
    pi.EnterPassContext()

    if pi.ShouldRun(Pass1):
        pi.RunBeforePass()
        Pass1()
        pi.RunAfterPass()

    if pi.ShouldRun(Pass2):
        pi.RunBeforePass()
        Pass2()
        pi.RunAfterPass()

    pi.ExitPassContext()

以下是 PassInstrument 接口与 PassContext 方法之间关系的简要介绍。更多详细信息请参阅(src/ir/transform.cc)。

  • InstrumentEnterPassContext

    • EnterPassContext() 按照传递给 PassContextinstruments 的顺序执行。

    • 当异常发生时,PassContext 通过清除所有注册的 PassInstrument 实例来禁用 Pass 检测。

    • 然后,PassContext 执行每个成功完成 EnterPassContext()PassInstrument 实例的 ExitPassContext() 方法。

    • 例如,如果 PassInstrument A、B 和C 注册到 PassContext 中,并且 A 完成了 EnterPassContext(),而 B 抛出异常,那么 C 永远不会执行;A 的 ExitPassContext() 会被执行。

  • InstrumentExitPassContext

    • 每个 PassInstrument 实例的 ExitPassContext() 按照传递给 PassContextinstruments 的顺序执行。

    • 当异常发生时,instruments 会被清除。

    • 在抛出异常的实例之后注册的 PassInstrument 实例不会执行 ExitPassContext

  • InstrumentBeforePass

    • 如果 Pass 未列为必需 Pass,则执行 ShouldRun

    • 如果 Pass 未被 ShouldRun 阻止,则按照 instruments 的顺序执行 RunBeforePass

    • 请注意,InstrumentBeforePass 返回布尔值,指示是否应运行该 Pass。

    • 当异常发生时,它会立即抛出。依赖 Python 的上下文管理器来安全地退出 PassContext (意味着每个检测器的 ExitPassContext 将被运行。对于 C++,请参阅 include/tvm/support/with.h。)

  • InstrumentAfterPass

    • RunAfterPass 按照传递给 PassContextinstruments 的顺序执行。

    • 当异常发生时,它会立即抛出。依赖 Python 的上下文管理器或 With 类(include/tvm/support/with.h)来安全地退出 PassContext

内置检测器#

有几个内置的检测器。那些标记为 TODO 的尚未实现。

  • PassTimingInstrument (see src/ir/instrument.cc)

    • 分析 Pass 的执行时间。

  • PrintIRBefore(TODO)

    • 在 Pass 变换之前打印 IR 模块。如果在 Pass 周围插入 tvm.transform.PrintIR(),也可以实现此目的。然而,使用 PassInstrument,不需要修改 Pass 的顺序。

  • PrintAfter(TODO)

    • 在 Pass 变换后打印 IR 模块。

Python 前端#

前端只需要一些简单的 API。例如,可以为用户提供以下 API 来创建和执行 Pass(完整实现见 python/tvm/relay/transform/transform.pypython/tvm/ir/transform.py)。后端接收信息并决定应该使用哪个函数来创建 Pass 对象。

PassContext#

Python 前端为 PassContext 提供了包装器,通过重写 __enter____exit__ 来启用 with 语法。提供了 current 静态方法,供用户获取在某个范围内正在使用的上下文。

@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
    def __enter__(self):
        _transform.EnterPassContext(self)
        return self

    def __exit__(self, ptype, value, trace, config):
        _transform.ExitPassContext(self)

    @staticmethod
    def current():
        """Return the current pass context."""
        return _transform.GetCurrentPassContext()

PassContext 用于配置编译选项,包括优化级别和所需/禁用的 Pass。它还可以接受配置字典,以便不同的 Pass 可以方便地获取传递的数据,例如回退设备信息和循环展开的步长/深度等。为了能够获取所需的配置,必须通过 TVM_REGISTER_PASS_CONFIG_OPTION 注册键。例如,以下内容用于循环展开 Pass。

TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

更多详细信息请参阅 src/tir/transforms/unroll_loop.cc

Pass 对象#

Pass 是所有 Pass 对象的基类。这里的所有方法都是在后端实现的简单包装器。它们是为了让用户能够方便地在 Python 中与基类交互而定义的。Pass 基类中只定义了 __call__,以使子类成为可调用对象,从而可以轻松调用它们(例如,pass_xx(arg))以执行。

@register_relay_node
class Pass(RelayNode):
   def __call__(self, mod):
       return _transform.RunPass(self, mod)

提供了一些辅助 API,以便从 Python 前端轻松创建 Pass,并让 Pass 基础设施控制执行。例如,向用户提供了 module_passfunction_passsequential,以便他们可以自定义自己的 Pass 或 Pass 管道。

对于所有在 C++ 后端实现的 Pass,分别在 python/tvm/ir/transform.pypython/tvm/relay/transform/transform.py 中提供了相应的 Python API。例如,常量折叠有如下的 Python API:

def FoldConstant():
    return _transform.FoldConstant()

用户可以通过装饰器构建 Pass,如下所示:

 @relay.transform.module_pass(opt_level=2)
 def transform(mod, ctx):
    tp = relay.TensorType((10,), "float32")
    x = relay.var("x", tp)
    gv = relay.GlobalVar("abs")
    func = relay.Function([x], relay.abs(x))
    new_mod = tvm.IRModule({gv: func})
    new_mod.update(mod)
    return new_mod

module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2

这里的 transform 函数向输入模块添加了 abs 函数,但它可以是模块级别的任何自定义优化。创建这个 module_pass 后,用户可以将其应用于任何 Relay 模块。例如,可以构建空模块并应用这个 Pass 来添加 abs 函数。

mod = tvm.IRModule()
mod = module_pass(mod)

相应地,也为 function_pass 提供了这样的函数。例如,示例函数级 Pass 可以写成如下形式:

@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
   def __init__(self, new_func):
      self.new_func = new_func
      def transform_function(self, func, mod, ctx):
         # Just for demo purposes
         # Transform func to new_func
         return self.new_func

x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# Now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)

或者,用户也可以直接注册 Pass 而不使用装饰器,然后调用它。有关如何自定义优化管道以及调试 Relay 和 TIR Pass 的更多示例,请参阅 use pass infra 教程。

Pass 检测#

可以通过在实现以下方法的类上使用 pass_instrument 装饰器(python/tvm/ir/instrument.py)来实现 PassInstrument。请注意,建议使用 pass_instrument 装饰器来实现 PassInstrument,而不是重写或子类化。

  • enter_pass_ctx

    • 此方法在进入 PassContext 时运行。

  • exit_pass_ctx

    • 此方法在退出 PassContext 时运行。

  • should_run

    • 此方法在 Pass 执行之前运行,返回布尔值,指示是否应运行该 Pass。

  • run_before_pass

    • 如果 Pass 应该运行,此方法在 Pass 执行之前运行。

  • run_after_pass

    • 此方法在 Pass 执行后立即运行。

PassInstrument 实例可以通过 tvm.transform.PassContext 中的 instruments 参数注册。

use pass instrument 教程提供了如何使用 Python API 实现 PassInstrument 的示例。

在当前 PassContext 中覆盖检测器#

提供了 override_instruments 方法来覆盖当前 PassContextinstruments。例如,如果 Pass 在没有显式创建新 PassContext 的情况下运行,仍然可以通过以下方式将 PassInstrument 注册到全局 PassContext 中:

cur_pass_ctx = tvm.transform.PassContext.current()
# override PassInstrument instances
cur_pass_ctx.override_instruments([pass_inst])
mod = pass_seq(mod)
result = pass_inst.get_result()

请注意,当调用 override_instruments 时,会调用旧 PassInstrument 实例的 exit_pass_ctx 方法。然后调用新 PassInstrumententer_pass_ctx 方法。