变换概述#

tvm.ir.transform 定义了 IR 变体之间的通用 pass 的基础设施。

tvm.ir.transform.PassInfo 类包含 pass 所需的元数据。它是运行优化或分析所需信息的容器。当需要更多元数据时,可以通过添加新成员来扩展这个类。

  • namestr)是 pass 名称

  • opt_levelint) 表示在哪个优化级别将启用传递

  • requiredlist[str]) 表示执行某个传递所需的依赖

所有 Pass 的基类。这里的所有方法都只是在后端实现的简单包装器。它们的定义是为了方便用户与基类进行交互。

tvm.ir.transform.PassContext 表示优化/分析运行的基础。

每个 pass 上下文都包含许多辅助信息,用于帮助优化 pass。这些信息包括记录优化过程中误差的误差报告器等。

Relax/tir 程序的优化可以应用在不同的粒度上,即

模块级 Pass#

tvm.ir.transform.ModulePass 是在 IRModule 上工作的 pass。用户不需要直接与该类交互。相反,应该通过 tvm.ir.transform.module_pass() 创建模块级传递,因为 module_pass API 的设计足够灵活,以不同的方式处理模块级 pass 的创建。此外,可以从基类访问模块 pass 的所有成员。

类模式#

import numpy as np
import tvm
from tvm import relax, tir
from tvm.script import relax as R

@tvm.transform.module_pass(opt_level=2)
class CustomPipeline:
    def __init__(self, enable_fold):
        self.enable_fold = enable_fold
        self.cse = relax.transform.EliminateCommonSubexpr()
        self.const_fold = relax.transform.FoldConstant()

    def transform_module(self, mod, ctx):
        mod = self.cse(mod)
        if self.enable_fold:
            mod = self.const_fold(mod)
        return mod

# 创建定制的 pipeline 实例
pipeline = CustomPipeline(enable_fold=False)
assert isinstance(pipeline, tvm.transform.ModulePass)
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Model:
    @R.function
    def main(
        x: R.Tensor((1, 64, 56, 56), dtype="float32"),
        weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
    ) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        R.func_attr({"num_input": 1})
        c_data = np.empty((1, 64, 54, 54)).astype("float32")
        c = R.const(c_data)
        with R.dataflow():
            conv = R.nn.conv2d(x, weight)
            y = R.add(c, c)
            y = R.multiply(y, R.const(2, "float32"))
            y = R.add(conv, y)
            z = R.add(y, c)
            z1 = R.add(y, c)
            z2 = R.add(z, z1)
            gv = z2
            R.output(gv)
        return gv

m = Model
m.show()
pipeline = CustomPipeline(enable_fold=True)
pipeline(m).show()

Hide code cell output

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 64, 56, 56), dtype="float32"), weight: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        R.func_attr({"num_input": 1})
        c: R.Tensor((1, 64, 54, 54), dtype="float32") = metadata["relax.expr.Constant"][0]
        with R.dataflow():
            conv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(x, weight, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            y: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(c, c)
            y_1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.multiply(y, R.const(2.0, "float32"))
            y_2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(conv, y_1)
            z: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(y_2, c)
            z1: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(y_2, c)
            z2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(z, z1)
            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = z2
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 64, 56, 56), dtype="float32"), weight: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            conv: R.Tensor((1, 64, 54, 54), dtype="float32") = R.nn.conv2d(x, weight, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            y: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(conv, metadata["relax.expr.Constant"][0])
            z: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(y, metadata["relax.expr.Constant"][1])
            z2: R.Tensor((1, 64, 54, 54), dtype="float32") = R.add(z, z)
            gv: R.Tensor((1, 64, 54, 54), dtype="float32") = z2
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

函数模式#

以下代码通过装饰用户定义的变换函数来创建模块传递。

from tvm.script import ir as I
from tvm.script import relax as R

@tvm.transform.module_pass(opt_level=2)
def transform(mod, ctx):
    @I.ir_module
    class Model:
        @R.function
        def main(
            x: R.Tensor((1, 2), dtype="float32"),
        ) -> R.Tensor((1, 2), dtype="float32"):
            with R.dataflow():
                y = R.abs(x)
                gv = y
                R.output(gv)
            return gv
    new_mod = tvm.IRModule()
    new_mod['abs'] = Model["main"]
    new_mod.update(mod)
    return new_mod

module_pass = transform
assert isinstance(module_pass, tvm.transform.ModulePass)
assert module_pass.info.opt_level == 2
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Model:
    @R.function
    def main(
        x: R.Tensor((1, 2,), dtype="float32"),
    ) -> R.Tensor((1, 2,), dtype="float32"):
        with R.dataflow():
            y = R.add(x, x)
            gv = y
            R.output(gv)
        return gv

# 给定模块 `m`,优化可以如下调用:
m = Model
updated_mod = module_pass(m)
# 现在,函数 `abs` 应该被添加到模块 `m` 中。
updated_mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def abs(x: R.Tensor((1, 2), dtype="float32")) -> R.Tensor((1, 2), dtype="float32"):
        R.func_attr({"global_symbol": "main"})
        with R.dataflow():
            y: R.Tensor((1, 2), dtype="float32") = R.abs(x)
            gv: R.Tensor((1, 2), dtype="float32") = y
            R.output(gv)
        return gv

    @R.function
    def main(x: R.Tensor((1, 2), dtype="float32")) -> R.Tensor((1, 2), dtype="float32"):
        with R.dataflow():
            y: R.Tensor((1, 2), dtype="float32") = R.add(x, x)
            gv: R.Tensor((1, 2), dtype="float32") = y
            R.output(gv)
        return gv

Relax 函数级 Pass#

tvm.relax.transform.function_pass() 用于变换 Relax 函数。