Transformation#

In this section, we will dive into transformations of Relax programs. Transformations are a key component of the compilation flow, used for optimization and for integration with hardware backends.

First, let us create a simple Relax program, as we did in the previous section.

import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn


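# A simple two-layer MLP (784 -> 128 -> 10), the same model used in the previous section.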
class NNModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


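# Export the nn.Module to a Relax IRModule plus its parameters; the batch
# dimension "n" is kept symbolic in the tensor spec.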
origin_mod, params = NNModule().export_tvm(
    {"forward": {"x": nn.spec.Tensor(("n", 784), "float32")}}
)
origin_mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv

Apply Transformations#

Passes are the primary way to apply transformations to a program. As a first step, let us apply the built-in LegalizeOps pass, which lowers high-level operators into low-level operators.

mod = tvm.relax.transform.LegalizeOps()(origin_mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def add(var_matmul: T.handle, fc1_bias: T.Buffer((T.int64(128),), "float32"), var_T_add: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(128)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + fc1_bias[v_ax1]

    @T.prim_func(private=True)
    def add1(var_matmul1: T.handle, fc2_bias: T.Buffer((T.int64(10),), "float32"), var_T_add: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        matmul1 = T.match_buffer(var_matmul1, (n, T.int64(10)))
        T_add = T.match_buffer(var_T_add, (n, T.int64(10)))
        # with T.block("root"):
        for ax0, ax1 in T.grid(n, T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul1[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = matmul1[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def matmul(var_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), var_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        x = T.match_buffer(var_x, (n, T.int64(784)))
        matmul = T.match_buffer(var_matmul, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]

    @T.prim_func(private=True)
    def matmul1(var_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), var_matmul: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        relu = T.match_buffer(var_relu, (n, T.int64(128)))
        matmul = T.match_buffer(var_matmul, (n, T.int64(10)))
        # with T.block("root"):
        for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0.0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]

    @T.prim_func(private=True)
    def relu(var_add: T.handle, var_compute: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        add = T.match_buffer(var_add, (n, T.int64(128)))
        compute = T.match_buffer(var_compute, (n, T.int64(128)))
        # with T.block("root"):
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(add[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.max(add[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
            matmul = R.call_tir(cls.matmul, (x, permute_dims), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            add = R.call_tir(cls.add, (matmul, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            relu = R.call_tir(cls.relu, (add,), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
            matmul1 = R.call_tir(cls.matmul1, (relu, permute_dims1), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            add1 = R.call_tir(cls.add1, (matmul1, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv

As the output shows, the high-level operators in the program (i.e., relax.op) have been replaced by their corresponding low-level calls (i.e., relax.call_tir).
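
Since each lowered operator is now a standalone TIR prim_func in the module, it can be inspected (and later scheduled) on its own. A minimal sketch using the mod produced above:

# List every global function in the legalized module: the Relax entry
# point "forward" plus one TIR prim_func per lowered operator.
print([gv.name_hint for gv in mod.get_global_vars()])

# Individual TIR kernels can be looked up by name and printed.
mod["matmul"].show()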

Next, let us try operator fusion, an optimization technique widely used in machine learning compilers. Note that in Relax, fusion is carried out by several passes working together, which we can apply in sequence.

mod = tvm.ir.transform.Sequential(
    [
        tvm.relax.transform.AnnotateTIROpPattern(),
        tvm.relax.transform.FuseOps(),
        tvm.relax.transform.FuseTIR(),
    ]
)(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def fused_matmul1_add1(p_relu: T.handle, permute_dims1: T.Buffer((T.int64(128), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        relu = T.match_buffer(p_relu, (n, T.int64(128)))
        T_add_intermediate = T.match_buffer(p_output0, (n, T.int64(10)))
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((n, T.int64(10)))
        for i0, i1, k in T.grid(n, T.int64(10), T.int64(128)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]
        for ax0, ax1 in T.grid(n, T.int64(10)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc2_bias[v_ax1]

    @T.prim_func(private=True)
    def fused_matmul_add_relu(p_x: T.handle, permute_dims: T.Buffer((T.int64(784), T.int64(128)), "float32"), fc1_bias: T.Buffer((T.int64(128),), "float32"), p_output0: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        n = T.int64()
        x = T.match_buffer(p_x, (n, T.int64(784)))
        compute_intermediate = T.match_buffer(p_output0, (n, T.int64(128)))
        # with T.block("root"):
        matmul_intermediate = T.alloc_buffer((n, T.int64(128)))
        T_add_intermediate = T.alloc_buffer((n, T.int64(128)))
        for i0, i1, k in T.grid(n, T.int64(128), T.int64(784)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
                T.writes(matmul_intermediate[v_i0, v_i1])
                with T.init():
                    matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
                matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]
        for ax0, ax1 in T.grid(n, T.int64(128)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])
                T.writes(T_add_intermediate[v_ax0, v_ax1])
                T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc1_bias[v_ax1]
        for i0, i1 in T.grid(n, T.int64(128)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_add_intermediate[v_i0, v_i1])
                T.writes(compute_intermediate[v_i0, v_i1])
                compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))

    @T.prim_func(private=True)
    def transpose(fc1_weight: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc1_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]

    @T.prim_func(private=True)
    def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
            with T.block("T_transpose"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(fc2_weight[v_ax1, v_ax0])
                T.writes(T_transpose[v_ax0, v_ax1])
                T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 128), dtype="float32"))
            lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo=R.Tensor((n, 128), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((128, 10), dtype="float32"))
            gv = R.call_tir(cls.fused_matmul1_add1, (lv, permute_dims1, fc2_bias), out_sinfo=R.Tensor((n, 10), dtype="float32"))
            R.output(gv)
        return gv

The result shows that the matmul, add, and relu operators have been fused into a single kernel (i.e., a single call_tir).
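
To double-check this programmatically, the fused kernels can be found by name. A quick sketch using the mod above:

# The fused TIR kernels record the original operator sequence in their
# names, e.g. "fused_matmul_add_relu" and "fused_matmul1_add1".
print([gv.name_hint for gv in mod.get_global_vars() if gv.name_hint.startswith("fused_")])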

For details on all built-in passes, please refer to tvm.relax.transform.

Custom Passes#

We can also define our own passes. Below is an example that rewrites the relu operator into the gelu operator.

First, we need to write a Relax IR mutator to perform the rewriting.

from tvm.relax.expr_functor import PyExprMutator, mutator


@mutator
class ReluRewriter(PyExprMutator):
    def __init__(self, mod):
        super().__init__(mod)

    def visit_call_(self, call: relax.Call) -> relax.Expr:
        # visit the relax.Call expr, and only handle the case when op is relax.nn.relu
        if call.op.name == "relax.nn.relu":
            return relax.op.nn.gelu(call.args[0])

        return super().visit_call_(call)

Then, we can write a pass that applies the mutator to the whole module.

@tvm.transform.module_pass(opt_level=0, name="ReluToGelu")
class ReluToGelu:  # pylint: disable=too-few-public-methods
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """IRModule-level transformation"""
        rewriter = ReluRewriter(mod)
        for g_var, func in mod.functions_items():
            if isinstance(func, relax.Function):
                func = rewriter.visit_expr(func)
                rewriter.builder_.update_func(g_var, func)
        return rewriter.builder_.get()


mod = ReluToGelu()(origin_mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor(("n", 784), dtype="float32"), fc1_weight: R.Tensor((128, 784), dtype="float32"), fc1_bias: R.Tensor((128,), dtype="float32"), fc2_weight: R.Tensor((10, 128), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor(("n", 10), dtype="float32"):
        n = T.int64()
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 128), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((n, 128), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((n, 128), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((n, 128), dtype="float32") = R.nn.gelu(add)
            permute_dims1: R.Tensor((128, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((n, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            add1: R.Tensor((n, 10), dtype="float32") = R.add(matmul1, fc2_bias)
            gv: R.Tensor((n, 10), dtype="float32") = add1
            R.output(gv)
        return gv

The printed output shows that the relax.nn.relu operator has been rewritten into the relax.nn.gelu operator.
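
Because ReluToGelu is an ordinary module pass, it composes with the built-in passes. For example, it can be placed in a Sequential pipeline (a sketch based on the definitions above):

# Run the custom rewrite first, then legalize the remaining high-level
# operators into TIR kernels, all in one pipeline.
pipeline = tvm.ir.transform.Sequential(
    [
        ReluToGelu(),
        tvm.relax.transform.LegalizeOps(),
    ]
)
lowered_mod = pipeline(origin_mod)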

For details about mutators, please refer to PyExprMutator.

Summary#

In this section, we have shown how to apply transformations to Relax programs, and how to define and apply custom transformations.