自定义优化#
Apache TVM 的一个主要设计目标是使优化流水线易于定制,无论是研究或开发目的,还是迭代工程优化。
import set_env
import os
import tempfile
import numpy as np
import tvm
from tvm import IRModule, relax
from tvm.relax.frontend import nn
可组合IRModule优化#
Apache TVM提供了一种灵活的方式来优化 IRModule。围绕 IRModule 优化的所有运算都可以与现有流水线组合。请注意,每个优化可以聚焦于 部分计算图,实现局部 lower 或者局部优化。
准备 Relax 模块#
我们首先准备一个Relax模块。这个模块可以从其他框架导入,用 NN 模块前端或 TVMScript 构建。这里我们使用一个简单的神经网络模型作为例子。
class RelaxModel(nn.Module):
def __init__(self):
super(RelaxModel, self).__init__()
self.fc1 = nn.Linear(784, 256)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(256, 10, bias=False)
def forward(self, x):
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
return x
input_shape = (1, 784)
mod, params = RelaxModel().export_tvm({"forward": {"x": nn.spec.Tensor(input_shape, "float32")}})
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
gv: R.Tensor((1, 10), dtype="float32") = matmul1
R.output(gv)
return gv
库调度#
我们希望快速尝试针对特定平台(例如 GPU)的变体库优化。我们可以为特定平台和算子编写一个特定的调度过程。这里我们展示如何为某些模式调度 CUBLAS 库。
备注
本教程仅演示了针对 CUBLAS 的单个算子调度,突出显示了优化流水线的灵活性。在真实案例中,我们可以导入多个模式并将它们调度到不同的内核。
# Import cublas pattern
import tvm.relax.backend.contrib.cublas as _cublas
# Define a new pass for CUBLAS dispatch
@tvm.transform.module_pass(opt_level=0, name="CublasDispatch")
class CublasDispatch:
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
# Check if CUBLAS is enabled
if not tvm.get_global_func("relax.ext.cublas", True):
raise Exception("CUBLAS is not enabled.")
# Get interested patterns
patterns = [relax.backend.get_pattern("cublas.matmul_transposed_bias_relu")]
# Note in real-world cases, we usually get all patterns
# patterns = relax.backend.get_patterns_with_prefix("cublas")
# Fuse ops by patterns and then run codegen
mod = relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True)(mod)
mod = relax.transform.RunCodegen()(mod)
return mod
mod = CublasDispatch()(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
@R.function
def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(lv, permute_dims1, out_dtype="void")
gv: R.Tensor((1, 10), dtype="float32") = matmul1
R.output(gv)
return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
调度过程之后#
我们可以看到第一个 nn.Linear
和 nn.ReLU
被融合并重写为一个 call_dps_packed
函数,该函数调用CUBLAS库。值得注意的是,其他部分没有改变,这意味着我们可以有选择地为某些计算调度优化。
自动调优#
在之前的例子基础上,我们可以通过自动调优进一步优化模型的 其余计算部分。这里我们展示如何使用元调度来自动调优模型。
我们可以使用 MetaScheduleTuneTIR
过程来简化模型调优,而 MetaScheduleApplyDatabase
过程则将最佳配置应用到模型上。调优过程将生成搜索空间,调优模型,接下来的步骤将把最佳配置应用到模型上。在运行这些过程之前,我们需要通过 LegalizeOps
将 Relax 算子降低为 TensorIR 函数。
device = tvm.cuda(0)
target = tvm.target.Target.from_device(device)
if os.getenv("CI", "") != "true":
trials = 2000
with target, tempfile.TemporaryDirectory() as tmp_dir:
mod = tvm.ir.transform.Sequential(
[
relax.get_pipeline("zero"),
relax.transform.MetaScheduleTuneTIR(work_dir=tmp_dir, max_trials_global=trials),
relax.transform.MetaScheduleApplyDatabase(work_dir=tmp_dir),
]
)(mod)
mod.show()
DLight 规则#
DLight 规则是一组用于调度和优化内核的默认规则。DLight规则旨在实现快速编译和公平的性能。在某些情况下,例如语言模型,DLight提供出色的性能,而对于通用模型,它在性能和编译时间之间取得平衡。
from tvm import dlight as dl
# Apply DLight rules
with target:
mod = tvm.ir.transform.Sequential(
[
relax.get_pipeline("zero"),
dl.ApplyDefaultSchedule( # pylint: disable=not-callable
dl.gpu.Matmul(),
dl.gpu.GEMV(),
dl.gpu.Reduction(),
dl.gpu.GeneralReduction(),
dl.gpu.Fallback(),
),
]
)(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
I.module_attrs({"external_mods": [metadata["runtime.Module"][0]]})
@T.prim_func(private=True)
def matmul(lv: T.Buffer((T.int64(1), T.int64(256)), "float32"), permute_dims1: T.Buffer((T.int64(256), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
T.func_attr({"op_pattern": 4, "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
# with T.block("root"):
matmul_local = T.alloc_buffer((T.int64(1), T.int64(10)), scope="local")
lv_shared = T.alloc_buffer((T.int64(1), T.int64(256)), scope="shared")
permute_dims1_shared = T.alloc_buffer((T.int64(256), T.int64(10)), scope="shared")
for i0_0_i1_0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 1024, "pragma_unroll_explicit": 1}):
for i0_1_i1_1_fused in T.thread_binding(T.int64(1), thread="vthread.x"):
for i0_2_i1_2_fused in T.thread_binding(T.int64(10), thread="threadIdx.x"):
for i0_3_init, i1_3_init, i0_4_init, i1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
with T.block("matmul_init"):
v_i0 = T.axis.spatial(T.int64(1), i0_3_init + i0_4_init)
v_i1 = T.axis.spatial(T.int64(10), i0_2_i1_2_fused + i1_3_init + i1_4_init)
T.reads()
T.writes(matmul_local[v_i0, v_i1])
T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 1, "meta_schedule.tiling_structure": "SSSRRSRS"})
matmul_local[v_i0, v_i1] = T.float32(0.0)
for k_0 in range(T.int64(1)):
for ax0_ax1_fused_0 in range(T.int64(7)):
for ax0_ax1_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
for ax0_ax1_fused_2 in T.vectorized(T.int64(4)):
with T.block("lv_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
v1 = T.axis.spatial(T.int64(256), ax0_ax1_fused_0 * T.int64(40) + ax0_ax1_fused_1 * T.int64(4) + ax0_ax1_fused_2)
T.where((ax0_ax1_fused_0 * T.int64(10) + ax0_ax1_fused_1) * T.int64(4) + ax0_ax1_fused_2 < T.int64(256))
T.reads(lv[v0, v1])
T.writes(lv_shared[v0, v1])
lv_shared[v0, v1] = lv[v0, v1]
for ax0_ax1_fused_0 in range(T.int64(64)):
for ax0_ax1_fused_1 in T.thread_binding(T.int64(10), thread="threadIdx.x"):
for ax0_ax1_fused_2 in T.vectorized(T.int64(4)):
with T.block("permute_dims1_shared"):
v0 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(40) + ax0_ax1_fused_1 * T.int64(4) + ax0_ax1_fused_2) // T.int64(10))
v1 = T.axis.spatial(T.int64(10), (ax0_ax1_fused_0 * T.int64(40) + ax0_ax1_fused_1 * T.int64(4) + ax0_ax1_fused_2) % T.int64(10))
T.reads(permute_dims1[v0, v1])
T.writes(permute_dims1_shared[v0, v1])
permute_dims1_shared[v0, v1] = permute_dims1[v0, v1]
for k_1, i0_3, i1_3, k_2, i0_4, i1_4 in T.grid(T.int64(8), T.int64(1), T.int64(1), T.int64(32), T.int64(1), T.int64(1)):
with T.block("matmul_update"):
v_i0 = T.axis.spatial(T.int64(1), i0_3 + i0_4)
v_i1 = T.axis.spatial(T.int64(10), i0_2_i1_2_fused + i1_3 + i1_4)
v_k = T.axis.reduce(T.int64(256), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2)
T.reads(matmul_local[v_i0, v_i1], lv_shared[v_i0, v_k], permute_dims1_shared[v_k, v_i1])
T.writes(matmul_local[v_i0, v_i1])
T.block_attr({"meta_schedule.thread_extent_high_inclusive": 1024, "meta_schedule.thread_extent_low_inclusive": 1, "meta_schedule.tiling_structure": "SSSRRSRS"})
matmul_local[v_i0, v_i1] = matmul_local[v_i0, v_i1] + lv_shared[v_i0, v_k] * permute_dims1_shared[v_k, v_i1]
for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
with T.block("matmul_local"):
v0 = T.axis.spatial(T.int64(1), ax0)
v1 = T.axis.spatial(T.int64(10), i0_2_i1_2_fused + ax1)
T.reads(matmul_local[v0, v1])
T.writes(matmul[v0, v1])
matmul[v0, v1] = matmul_local[v0, v1]
@T.prim_func(private=True)
def transpose(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
T.func_attr({"op_pattern": 2, "tir.is_scheduled": T.bool(True), "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0_ax1_fused_0 in T.thread_binding(T.int64(3), thread="blockIdx.x"):
for ax0_ax1_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
with T.block("T_transpose"):
v_ax0 = T.axis.spatial(T.int64(256), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) // T.int64(10))
v_ax1 = T.axis.spatial(T.int64(10), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1) % T.int64(10))
T.where(ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 < T.int64(2560))
T.reads(fc2_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]
@R.function
def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
lv = R.call_dps_packed("fused_relax_permute_dims_relax_matmul_relax_add_relax_nn_relu_cublas", (fc1_weight, x, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
permute_dims1 = R.call_tir(cls.transpose, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
gv = R.call_tir(cls.matmul, (lv, permute_dims1), out_sinfo=R.Tensor((1, 10), dtype="float32"))
R.output(gv)
return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
备注
本教程重点在于展示优化流水线的演示,而不是将性能推向极限。当前的优化可能不是最佳的。
部署优化后的模型#
我们可以构建并将优化后的模型部署到 TVM 运行时。
ex = relax.build(mod, target="cuda")
dev = tvm.device("cuda", 0)
vm = relax.VirtualMachine(ex, dev)
# Need to allocate data and params on GPU device
data = tvm.nd.array(np.random.rand(*input_shape).astype("float32"), dev)
gpu_params = [tvm.nd.array(np.random.rand(*p.shape).astype(p.dtype), dev) for _, p in params]
gpu_out = vm["forward"](data, *gpu_params).numpy()
print(gpu_out)
[[24598.252 24066.623 25208.867 25324.975 25332.447 24816.111 24261.271
25795.818 25539.488 24348.896]]
总结#
本教程展示了如何为 Apache TVM 中的机器学习模型自定义优化流水线。我们可以容易地组合优化过程,并为计算图的不同部分自定义优化。优化流水线的灵活性使我们能够快速迭代优化并提高模型性能。