Customize Optimization#

One of Apache TVM's main design goals is to make the optimization pipeline easy to customize, whether for research and development purposes or for iterative engineering optimization.

import os
import tempfile
import numpy as np
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn

Composable IRModule Optimization#

Apache TVM provides a flexible way to optimize an IRModule. All passes centered around IRModule optimization can be composed with existing pipelines. Note that each pass can focus on part of the computation graph, enabling partial lowering or partial optimization, as sketched below.
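
A minimal sketch of what "composable" means in practice (the no-op pass below is purely illustrative and not part of this tutorial's pipeline): a custom pass can be chained with built-in Relax passes through tvm.ir.transform.Sequential.

# A hypothetical no-op pass, just to show how a custom pass plugs into a pipeline
@tvm.transform.module_pass(opt_level=0, name="MyNoOpPass")
class MyNoOpPass:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # A real pass would rewrite (part of) the module here
        return mod


# Custom and built-in passes compose into a single pipeline
pipeline = tvm.ir.transform.Sequential(
    [
        MyNoOpPass(),
        relax.transform.LegalizeOps(),  # built-in pass that lowers Relax ops to TensorIR
    ]
)
# optimized_mod = pipeline(some_mod)  # apply once an IRModule is available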

Prepare a Relax Module#

First, we prepare a Relax module. The module can be imported from other frameworks, or constructed with the NN module frontend or TVMScript. Here we use a simple neural network model as an example.


class RelaxModel(nn.Module):
    def __init__(self):
        super(RelaxModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 10, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        return x


input_shape = (1, 784)
mod, params = RelaxModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}})
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
            matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
            add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
            relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
            permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
            matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
            gv: R.Tensor((1, 10), dtype="float32") = matmul1
            R.output(gv)
        return gv
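
Besides the printed IR, export_tvm also returns the model parameters. A quick sanity check (mirroring how params is consumed later in this tutorial) is to print each parameter's name, shape, and dtype:

# Each entry of `params` is a (name, parameter) pair
for name, param in params:
    print(name, param.shape, param.dtype)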

Library Dispatch#

Suppose we want to quickly try a variant library optimization for a specific platform (e.g., GPU). We can write a dedicated dispatch pass for that platform and set of operators. Here we show how to dispatch certain patterns to the CUBLAS library.

Note

This tutorial only demonstrates dispatching a single operator to CUBLAS, to highlight the flexibility of the optimization pipeline. In real-world cases, we can import multiple patterns and dispatch them to different kernels.


# Import cublas pattern
import tvm.relax.backend.contrib.cublas as _cublas


# Define a new pass for CUBLAS dispatch
@tvm.transform.module_pass(opt_level=0, name="CublasDispatch")
class CublasDispatch:
    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        # Check if CUBLAS is enabled
        if not tvm.get_global_func("relax.ext.cublas", True):
            raise Exception("CUBLAS is not enabled.")

        # Get interested patterns
        patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]
        # Note in real-world cases, we usually get all patterns
        # patterns = relax.backend.get_patterns_with_prefix("cublas")

        # Fuse ops by patterns and then run codegen
        mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
        mod = relax.transform.RunCodegen()(mod)
        return mod


mod = CublasDispatch()(mod)
mod.show()

After the Dispatch Pass#

We can see that the first nn.Linear and nn.ReLU are fused and rewritten into a call_dps_packed function that calls the CUBLAS library. Notably, the other parts of the graph are left unchanged, which means we can selectively dispatch optimizations for certain computations.
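
If you prefer a programmatic check over reading the printed script, a small inspection sketch (not part of the original pipeline) is to enumerate the functions left in the module after dispatch:

# List the global functions remaining in the IRModule after the dispatch pass
for gvar, func in mod.functions.items():
    print(gvar.name_hint, type(func).__name__)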

Auto-Tuning#

Building on the previous example, we can further optimize the remaining computation of the model with auto-tuning. Here we show how to use MetaSchedule to auto-tune the model.

We can use the MetaScheduleTuneTIR pass to simplify model tuning, and the MetaScheduleApplyDatabase pass to apply the best configuration to the model. The tuning step generates the search space and tunes the model, and the follow-up step applies the best configuration back to the model. Before running these passes, we need to lower the Relax operators into TensorIR functions via LegalizeOps.
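
In the code below, this lowering is handled by the "zero" pipeline, which (at the time of writing) bundles LegalizeOps together with a few other standard passes. If you only need the lowering step on its own, a minimal sketch would be:

# Lower every Relax operator in the module to a TensorIR PrimFunc
lowered_mod = relax.transform.LegalizeOps()(mod)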

device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
if os.getenv("CI", "") != "true":
    trials = 2000
    with target, tempfile.TemporaryDirectory() as tmp_dir:
        mod = tvm.ir.transform.Sequential(
            [
                relax.get_pipeline("zero"),
                relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
                relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
            ]
        )(mod)

    mod.show()

DLight Rules#

DLight rules are a set of default rules for scheduling and optimizing kernels. They are designed for fast compilation and fair performance. In some cases, such as language models, DLight delivers excellent performance, while for general models it strikes a balance between performance and compilation time.

from tvm import dlight as dl

# Apply DLight rules
with target:
    mod = tvm.ir.transform.Sequential(
        [
            relax.get_pipeline("zero"),
            dl.ApplyDefaultSchedule(  # pylint: disable=not-callable
                dl.gpu.Matmul(),
                dl.gpu.GEMV(),
                dl.gpu.Reduction(),
                dl.gpu.GeneralReduction(),
                dl.gpu.Fallback(),
            ),
        ]
    )(mod)

mod.show()

Note

This tutorial focuses on demonstrating the optimization pipeline rather than pushing performance to its limit, so the current optimizations may not be optimal.

Deploy the Optimized Model#

We can build the optimized model and deploy it to the TVM runtime.

ex = tvm.compile(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# Need to allocate data and params on GPU device
data = tvm.runtime.tensor(np.random.rand(*input_shape).astype("float32"), dev)
gpu_params = [tvm.runtime.tensor(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params]
gpu_out = vm["forward"](data, *gpu_params).numpy()
print(gpu_out)
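
Beyond checking the numerical output, a rough latency number can be obtained with plain Python timing. This is a hand-rolled sketch: warm-up and explicit device synchronization matter because GPU kernel launches are asynchronous.

import time

for _ in range(3):  # warm up
    vm["forward"](data, *gpu_params)
dev.sync()  # wait for pending GPU work before starting the timer
runs = 10
start = time.perf_counter()
for _ in range(runs):
    vm["forward"](data, *gpu_params)
dev.sync()
print(f"mean latency: {(time.perf_counter() - start) / runs * 1e3:.3f} ms")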

Summary#

This tutorial has demonstrated how to customize the optimization pipeline for machine learning models in Apache TVM. We can easily compose optimization passes and customize the optimization for different parts of the computation graph. The flexibility of the optimization pipeline enables us to iterate on optimizations quickly and improve model performance.