Customize Optimization#

One of the main design goals of Apache TVM is to make the optimization pipeline easy to customize, whether for research or development purposes, or for iterative engineering optimization.

import set_env
import os
import tempfile
import numpy as np
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn

Composable IRModule Optimization#

Apache TVM provides a flexible way to optimize an IRModule. Every transformation around IRModule optimization can be composed with existing pipelines. Note that each optimization can focus on part of the computation graph, enabling partial lowering or partial optimization.
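As a minimal sketch of what composability means here, a pipeline is simply a tvm.ir.transform.Sequential of passes applied to an IRModule. The two passes below are standard Relax transforms picked only for illustration; any custom pass, such as the CublasDispatch pass defined later in this tutorial, can be slotted into the same Sequential.

# Illustrative sketch only: compose two built-in Relax passes into one pipeline.
example_pipeline = tvm.ir.transform.Sequential(
    [
        relax.transform.LegalizeOps(),   # lower high-level Relax ops to TensorIR
        relax.transform.FoldConstant(),  # fold constant subgraphs
    ]
)
# optimized_mod = example_pipeline(some_module)  # `some_module` stands for any IRModule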

Prepare a Relax Module#

We first prepare a Relax module. The module can be imported from other frameworks, or constructed with the NN module frontend or TVMScript. Here we use a simple neural network model as an example.

class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


input_shape = (1, 784)
mod, params = RelaxModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}})
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv

Library Dispatch#

We would like to quickly try out a variant of library optimization for a certain platform (e.g., GPU). We can write a dedicated dispatching pass for the specific platform and operator. Here we show how to dispatch certain patterns to the CUBLAS library.

Note

This tutorial only demonstrates a single-operator dispatch to CUBLAS, to highlight the flexibility of the optimization pipeline. In real-world cases, we can import multiple patterns and dispatch them to different kernels.

# Import cublas pattern
import tvm.relax.backend.contrib.cublas as _cublas


# Define a new pass for CUBLAS dispatch
@tvm.transform.module_pass(opt_level=0, name="CublasDispatch")
class CublasDispatch:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Check if CUBLAS is enabled
        if not tvm.get_global_func("relax.ext.cublas", True):
            raise Exception("CUBLAS is not enabled.")

        # Get interested patterns
        patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]
        # Note in real-world cases, we usually get all patterns
        # patterns = relax.backend.get_patterns_with_prefix("cublas")

        # Fuse ops by patterns and then run codegen
        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
        mod = relax.transform.RunCodegen()(mod)
        return mod


mod = CublasDispatch()(mod)
mod.show()
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
Cell In[4], line 24
     20         mod = relax.transform.RunCodegen()(mod)
     21         return mod
---> 24 mod = CublasDispatch()(mod)
     25 mod.show()

File /media/pc/data/lxw/ai/tvm/python/tvm/ir/transform.py:238, in Pass.__call__(self, mod)
    224 def __call__(self, mod):
    225     """Execute the pass. Note that for sequential pass, the dependency among
    226     different passes will be resolved in the backend.
    227 
   (...)
    236         The updated module after applying this pass.
    237     """
--> 238     return _ffi_transform_api.RunPass(self, mod)

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_ctypes/packed_func.py:245, in PackedFuncBase.__call__(self, *args)
    233 ret_tcode = ctypes.c_int()
    234 if (
    235     _LIB.TVMFuncCall(
    236         self.handle,
   (...)
    243     != 0
    244 ):
--> 245     raise_last_ffi_error()
    246 _ = temp_args
    247 _ = args

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py:481, in raise_last_ffi_error()
    475 # The exception PyObject may contain a large amount of state,
    476 # including all stack frames that may be inspected in a later
    477 # PDB post-mortem.  Therefore, we must make sure to remove the
    478 # underlying PyObject* from the C++ side after we retrieve it.
    479 _LIB.TVMDropLastPythonError()
--> 481 raise py_err

File /media/pc/data/lxw/ai/tvm/python/tvm/ir/transform.py:307, in _wrap_class_module_pass.<locals>.PyModulePass.__init__.<locals>._pass_func(mod, ctx)
    306 def _pass_func(mod, ctx):
--> 307     return inst.transform_module(mod, ctx)

Cell In[4], line 11, in CublasDispatch.transform_module(self, mod, _ctx)
      8 def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
      9     # Check if CUBLAS is enabled
     10     if not tvm.get_global_func("relax.ext.cublas", True):
---> 11         raise Exception("CUBLAS is not enabled.")
     13     # Get interested patterns
     14     patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]

Exception: CUBLAS is not enabled.

After the Dispatching Pass#

We can see that the first nn.Linear and nn.ReLU are fused and rewritten into a call_dps_packed function, which calls the CUBLAS library. Notably, the other parts are left unchanged, which means we can selectively dispatch the optimization for certain computations.
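As a quick sanity check, here is a small inspection sketch (assuming the dispatch succeeded in a build with CUBLAS enabled, unlike the environment captured above): iterating over the module's functions shows which subgraphs were handed to external codegen and which remain as ordinary Relax functions.

# Inspection sketch: list the functions left in the module after dispatching.
# The fused Linear + ReLU subgraph is moved into an external runtime.Module,
# while the remaining Relax functions stay unchanged.
for gvar, func in mod.functions.items():
    print(gvar.name_hint, type(func).__name__)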

Auto Tuning#

Building on the previous example, we can further optimize the model with auto-tuning for the rest of the computation. Here we show how to use meta-schedule to auto-tune the model.

We can use the MetaScheduleTuneTIR pass to simplify model tuning, while the MetaScheduleApplyDatabase pass applies the best configuration to the model. The tuning process generates the search space and tunes the model, and the following step applies the best configuration to the model. Before running these passes, we need to lower the Relax operators into TensorIR functions via LegalizeOps.
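The next cell uses the built-in "zero" pipeline, which already performs this lowering. For reference, a minimal sketch of running the lowering explicitly (optional, since the pipeline below covers it):

# Optional sketch: rewrite high-level Relax ops (matmul, add, relu, ...) into
# TensorIR prim_funcs that meta-schedule can tune.
lowered_mod = relax.transform.LegalizeOps()(mod)
lowered_mod.show()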

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
if os.getenv("CI", "") != "true":
    trials = 2000
    with target, tempfile.TemporaryDirectory() as tmp_dir:
        mod = tvm.ir.transform.Sequential(
            [
                relax.get_pipeline("zero"),
                relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
                relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
            ]
        )(mod)

    mod.show()

DLight Rules#

DLight rules are a set of default rules for scheduling and optimizing kernels. DLight rules are designed for fast compilation and fair performance. In some cases, e.g. language models, DLight provides excellent performance, while for generic models it strikes a balance between performance and compilation time.

from tvm import dlight as dl

# Apply DLight rules
with target:
    mod = tvm.ir.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )(mod)

mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
    @T.prim_func(private=True)
    def matmul(lv: T.Buffer((T.int64(1), T.int64(256)), "float32"), permute_dims1: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 4, "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        matmul_local = T.alloc_buffer((T.int64(1), T.int64(10)), scope="local")
        lv_shared = T.alloc_buffer((T.int64(1), T.int64(256)), scope="shared")
        permute_dims1_shared = T.alloc_buffer((T.int64(256), T.int64(10)), scope="shared")
        for i0_0_i1_0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}):
            for i0_1_i1_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"):
                for i0_2_i1_2_fused in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                    for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
                        with T.block("matmul_init"):
                            v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init)
                            v_i1 = T.axis.spatial(T.int64(10), i0_2_i1_2_fused + i1_3_init + i1_4_init)
                            T.reads()
                            T.writes(matmul_local[v_i0, v_i1])
                            T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 1, "meta_schedule.tiling_structure": "SSSRRSRS"})
                            matmul_local[v_i0, v_i1] = T.float32(0.0)
                    for k_0 in range(T.int64(1)):
                        for ax0_ax1_fused_0 in range(T.int64(7)):
                            for ax0_ax1_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                                for ax0_ax1_fused_2 in T.vectorized(T.int64(4)):
                                    with T.block("lv_shared"):
                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                        v1 = T.axis.spatial(T.int64(256), ax0_ax1_fused_0 * T.int64(40) + ax0_ax1_fused_1 * T.int64(4) + ax0_ax1_fused_2)
                                        T.where((ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) * T.int64(4) + ax0_ax1_fused_2 < T.int64(256))
                                        T.reads(lv[v0, v1])
                                        T.writes(lv_shared[v0, v1])
                                        lv_shared[v0, v1] = lv[v0, v1]
                        for ax0_ax1_fused_0 in range(T.int64(64)):
                            for ax0_ax1_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
                                for ax0_ax1_fused_2 in T.vectorized(T.int64(4)):
                                    with T.block("permute_dims1_shared"):
                                        v0 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(40) + ax0_ax1_fused_1 * T.int64(4) + ax0_ax1_fused_2) // T.int64(10))
                                        v1 = T.axis.spatial(T.int64(10), (ax0_ax1_fused_0 * T.int64(40) + ax0_ax1_fused_1 * T.int64(4) + ax0_ax1_fused_2) % T.int64(10))
                                        T.reads(permute_dims1[v0, v1])
                                        T.writes(permute_dims1_shared[v0, v1])
                                        permute_dims1_shared[v0, v1] = permute_dims1[v0, v1]
                        for k_1, i0_3, i1_3, k_2, i0_4, i1_4 in T.grid(T.int64(8), T.int64(1), T.int64(1), T.int64(32), T.int64(1), T.int64(1)):
                            with T.block("matmul_update"):
                                v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4)
                                v_i1 = T.axis.spatial(T.int64(10), i0_2_i1_2_fused + i1_3 + i1_4)
                                v_k = T.axis.reduce(T.int64(256), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2)
                                T.reads(matmul_local[v_i0, v_i1], lv_shared[v_i0, v_k], permute_dims1_shared[v_k, v_i1])
                                T.writes(matmul_local[v_i0, v_i1])
                                T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 1, "meta_schedule.tiling_structure": "SSSRRSRS"})
                                matmul_local[v_i0, v_i1] = matmul_local[v_i0, v_i1] + lv_shared[v_i0, v_k] * permute_dims1_shared[v_k, v_i1]
                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
                        with T.block("matmul_local"):
                            v0 = T.axis.spatial(T.int64(1), ax0)
                            v1 = T.axis.spatial(T.int64(10), i0_2_i1_2_fused + ax1)
                            T.reads(matmul_local[v0, v1])
                            T.writes(matmul[v0, v1])
                            matmul[v0, v1] = matmul_local[v0, v1]

    @T.prim_func(private=True)
    def transpose(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0_ax1_fused_0 in T.thread_binding(T.int64(3), thread="blockIdx.x"):
            for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
                with T.block("T_transpose"):
                    v_ax0 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(10))
                    v_ax1 = T.axis.spatial(T.int64(10), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(10))
                    T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(2560))
                    T.reads(fc2_weight[v_ax1, v_ax0])
                    T.writes(T_transpose[v_ax0, v_ax1])
                    T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]

    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
            permute_dims1 = R.call_tir(cls.transpose, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
            gv = R.call_tir(cls.matmul, (lv, permute_dims1), out_sinfo=R.Tensor((1, 10), dtype="float32"))
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.

Note

This tutorial focuses on demonstrating the optimization pipeline, rather than pushing performance to the limit. The current optimization may not be the best.

Deploy the Optimized Model#

We can build and deploy the optimized model to the TVM runtime.

ex = relax.build(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(np.random.rand(*input_shape).astype("float32"), dev)
gpu_params = [tvm.nd.array(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params]
gpu_out = vm["forward"](data, *gpu_params).numpy()
print(gpu_out)
[[24598.252 24066.623 25208.867 25324.975 25332.447 24816.111 24261.271
  25795.818 25539.488 24348.896]]
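To get a rough latency number for the deployed model, here is a sketch using the Relax VirtualMachine's time_evaluator helper (the number/repeat values are arbitrary choices for illustration):

# Benchmark sketch: time the "forward" entry of the compiled VM on the GPU.
timer = vm.time_evaluator("forward", dev, number=10, repeat=3)
print(timer(data, *gpu_params))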

Summary#

This tutorial demonstrates how to customize the optimization pipeline for ML models in Apache TVM. We can easily compose optimization passes and customize the optimization for different parts of the computation graph. The flexibility of the optimization pipeline enables us to quickly iterate on the optimization and improve model performance.