自定义优化#

Apache TVM 的一个主要设计目标是使优化流水线易于定制,无论是研究或开发目的,还是迭代工程优化。

import set_env
import os
import tempfile
import numpy as np
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn

可组合IRModule优化#

Apache TVM提供了一种灵活的方式来优化 IRModule。围绕 IRModule 优化的所有运算都可以与现有流水线组合。请注意,每个优化可以聚焦于 部分计算图,实现局部 lower 或者局部优化。

准备 Relax 模块#

我们首先准备一个Relax模块。这个模块可以从其他框架导入,用 NN 模块前端或 TVMScript 构建。这里我们使用一个简单的神经网络模型作为例子。

class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


input_shape = (1, 784)
mod, params = RelaxModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}})
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv

库调度#

我们希望快速尝试针对特定平台(例如 GPU)的变体库优化。我们可以为特定平台和算子编写一个特定的调度过程。这里我们展示如何为某些模式调度 CUBLAS 库。

备注

本教程仅演示了针对 CUBLAS 的单个算子调度,突出显示了优化流水线的灵活性。在真实案例中,我们可以导入多个模式并将它们调度到不同的内核。

# Import cublas pattern
import tvm.relax.backend.contrib.cublas as _cublas


# Define a new pass for CUBLAS dispatch
@tvm.transform.module_pass(opt_level=0, name="CublasDispatch")
class CublasDispatch:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Check if CUBLAS is enabled
        if not tvm.get_global_func("relax.ext.cublas", True):
            raise Exception("CUBLAS is not enabled.")

        # Get interested patterns
        patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]
        # Note in real-world cases, we usually get all patterns
        # patterns = relax.backend.get_patterns_with_prefix("cublas")

        # Fuse ops by patterns and then run codegen
        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
        mod = relax.transform.RunCodegen()(mod)
        return mod


mod = CublasDispatch()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(lv, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

调度过程之后#

我们可以看到第一个 nn.Linearnn.ReLU 被融合并重写为一个 call_dps_packed 函数,该函数调用CUBLAS库。值得注意的是,其他部分没有改变,这意味着我们可以有选择地为某些计算调度优化。

自动调优#

在之前的例子基础上,我们可以通过自动调优进一步优化模型的 其余计算部分。这里我们展示如何使用元调度来自动调优模型。

我们可以使用 MetaScheduleTuneTIR 过程来简化模型调优,而 MetaScheduleApplyDatabase 过程则将最佳配置应用到模型上。调优过程将生成搜索空间,调优模型,接下来的步骤将把最佳配置应用到模型上。在运行这些过程之前,我们需要通过 LegalizeOps 将 Relax 算子降低为 TensorIR 函数。

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
if os.getenv("CI", "") != "true":
    trials = 2000
    with target, tempfile.TemporaryDirectory() as tmp_dir:
        mod = tvm.ir.transform.Sequential(
            [
                relax.get_pipeline("zero"),
                relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
                relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
            ]
        )(mod)

    mod.show()

DLight 规则#

DLight 规则是一组用于调度和优化内核的默认规则。DLight规则旨在实现快速编译和公平的性能。在某些情况下,例如语言模型,DLight提供出色的性能,而对于通用模型,它在性能和编译时间之间取得平衡。

from tvm import dlight as dl

# Apply DLight rules
with target:
    mod = tvm.ir.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )(mod)

mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
    @T.prim_func(private=True)
    def matmul(lv: T.Buffer((T.int64(1), T.int64(256)), "float32"), permute_dims1: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_local = T.alloc_buffer((T.int64(1), T.int64(10)), scope="local")
        lv_shared = T.alloc_buffer((T.int64(1), T.int64(256)), scope="shared")
        permute_dims1_shared = T.alloc_buffer((T.int64(256), T.int64(10)), scope="shared")
        for i0_0_i1_0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}):
            for i0_1_i1_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"):
                for i0_2_i1_2_fused in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                    for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
                        with T.block("matmul_init"):
                            v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init)
                            v_i1 = T.axis.spatial(T.int64(10), i0_2_i1_2_fused + i1_3_init + i1_4_init)
                            T.reads()
                            T.writes(matmul_local[v_i0, v_i1])
                            T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 1, "meta_schedule.tiling_structure": "SSSRRSRS"})
                            matmul_local[v_i0, v_i1] = T.float32(0.0)
                    for k_0 in range(T.int64(1)):
                        for ax0_ax1_fused_0 in range(T.int64(7)):
                            for ax0_ax1_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                                for ax0_ax1_fused_2 in T.vectorized(T.int64(4)):
                                    with T.block("lv_shared"):
                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                        v1 = T.axis.spatial(T.int64(256), ax0_ax1_fused_0 * T.int64(40) + ax0_ax1_fused_1 * T.int64(4) + ax0_ax1_fused_2)
                                        T.where((ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) * T.int64(4) + ax0_ax1_fused_2 < T.int64(256))
                                        T.reads(lv[v0, v1])
                                        T.writes(lv_shared[v0, v1])
                                        lv_shared[v0, v1] = lv[v0, v1]
                        for ax0_ax1_fused_0 in range(T.int64(64)):
                            for ax0_ax1_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                                for ax0_ax1_fused_2 in T.vectorized(T.int64(4)):
                                    with T.block("permute_dims1_shared"):
                                        v0 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(40) + ax0_ax1_fused_1 * T.int64(4) + ax0_ax1_fused_2) // T.int64(10))
                                        v1 = T.axis.spatial(T.int64(10), (ax0_ax1_fused_0 * T.int64(40) + ax0_ax1_fused_1 * T.int64(4) + ax0_ax1_fused_2) % T.int64(10))
                                        T.reads(permute_dims1[v0, v1])
                                        T.writes(permute_dims1_shared[v0, v1])
                                        permute_dims1_shared[v0, v1] = permute_dims1[v0, v1]
                        for k_1, i0_3, i1_3, k_2, i0_4, i1_4 in T.grid(T.int64(8), T.int64(1), T.int64(1), T.int64(32), T.int64(1), T.int64(1)):
                            with T.block("matmul_update"):
                                v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4)
                                v_i1 = T.axis.spatial(T.int64(10), i0_2_i1_2_fused + i1_3 + i1_4)
                                v_k = T.axis.reduce(T.int64(256), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2)
                                T.reads(matmul_local[v_i0, v_i1], lv_shared[v_i0, v_k], permute_dims1_shared[v_k, v_i1])
                                T.writes(matmul_local[v_i0, v_i1])
                                T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 1, "meta_schedule.tiling_structure": "SSSRRSRS"})
                                matmul_local[v_i0, v_i1] = matmul_local[v_i0, v_i1] + lv_shared[v_i0, v_k] * permute_dims1_shared[v_k, v_i1]
                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                        with T.block("matmul_local"):
                            v0 = T.axis.spatial(T.int64(1), ax0)
                            v1 = T.axis.spatial(T.int64(10), i0_2_i1_2_fused + ax1)
                            T.reads(matmul_local[v0, v1])
                            T.writes(matmul[v0, v1])
                            matmul[v0, v1] = matmul_local[v0, v1]

    @T.prim_func(private=True)
    def transpose(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(T.int64(3), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_transpose"):
                    v_ax0 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(10))
                    v_ax1 = T.axis.spatial(T.int64(10), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(10))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(2560))
                    T.reads(fc2_weight[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            gv = R.call_tir(cls.matmul, (lv, permute_dims1), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

备注

本教程重点在于展示优化流水线的演示,而不是将性能推向极限。当前的优化可能不是最佳的。

部署优化后的模型#

我们可以构建并将优化后的模型部署到 TVM 运行时。

ex = relax.build(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(np.random.rand(*input_shape).astype("float32"), dev)
gpu_params = [tvm.nd.array(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params]
gpu_out = vm["forward"](data, *gpu_params).numpy()
print(gpu_out)
[[24598.252 24066.623 25208.867 25324.975 25332.447 24816.111 24261.271
  25795.818 25539.488 24348.896]]

总结#

本教程展示了如何为 Apache TVM 中的机器学习模型自定义优化流水线。我们可以容易地组合优化过程,并为计算图的不同部分自定义优化。优化流水线的灵活性使我们能够快速迭代优化并提高模型性能。