TVM 样例#
Apache TVM 是遵循 Python 优先开发、通用部署原则的机器学习编译框架。它接收预训练的机器学习模型,编译并生成可嵌入和在任何地方运行的部署模块。Apache TVM 还允许自定义优化过程,引入新的优化、库、代码生成等。
Apache TVM 可以帮助:
优化:ML 工作负载、组合库和代码生成的性能。
部署:ML 工作负载部署到一组不同的新环境,包括新运行时和新的硬件。
持续改进和定制:通过快速自定义在 Python 中部署 ML 管道 库调度,引入自定义算子和代码生成。
整体流程包括以下步骤:
构建或导入模型:构建神经网络模型或从其他框架(例如 PyTorch、ONNX)导入预训练模型,并创建 TVM IRModule,其中包含编译所需的所有信息,包括用于计算图的高级 Relax 函数和用于张量程序的低级 TensorIR 函数。
执行可组合优化:执行一系列优化变换,如计算图优化、张量程序优化和库调度。
构建和通用部署:将优化后的模型构建为可部署模块到通用运行时,并在不同设备上执行,如 CPU、GPU 或其他加速器。
构造或导入模型#
使用 TVM Relax 前端直接定义一个两层感知器(MLP)网络,其 API 与 PyTorch 类似。
import tvm
from tvm import relax
from tvm.relax.frontend import nn
class MLPModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
return x
可以将模型导出为 TVM IRModule,这是 TVM 中的核心中间表示。
model = MLPModel()
mod, param_spec = model.export_tvm(
spec={"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
mod.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function
def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
with R.dataflow():
permute_dims: R.Tensor((784, 256), dtype="float32") = R.permute_dims(fc1_weight, axes=None)
matmul: R.Tensor((1, 256), dtype="float32") = R.matmul(x, permute_dims, out_dtype="void")
add: R.Tensor((1, 256), dtype="float32") = R.add(matmul, fc1_bias)
relu: R.Tensor((1, 256), dtype="float32") = R.nn.relu(add)
permute_dims1: R.Tensor((256, 10), dtype="float32") = R.permute_dims(fc2_weight, axes=None)
matmul1: R.Tensor((1, 10), dtype="float32") = R.matmul(relu, permute_dims1, out_dtype="void")
add1: R.Tensor((1, 10), dtype="float32") = R.add(matmul1, fc2_bias)
gv: R.Tensor((1, 10), dtype="float32") = add1
R.output(gv)
return gv
执行优化转换#
Apache TVM 利用 pipeline
变换和优化程序。该管道封装了一系列变换,实现两个目标(在同一层级):
模型优化:例如算子融合、布局重写。
张量程序优化:将算子映射到底层实现(包括库或代码生成)
备注
这两个是目标,而不是 pipeline
的阶段。这两种优化是在同一层级进行的,或者在两个阶段分别进行。
在本教程中,只演示整体流程,通过利用 zero
优化管道,而不是针对任何特定目标进行优化。
mod = relax.get_pipeline("zero")(mod)
mod.show()
Show code cell output
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def fused_matmul1_add1(relu: T.Buffer((T.int64(1), T.int64(256)), "float32"), permute_dims1: T.Buffer((T.int64(256), T.int64(10)), "float32"), fc2_bias: T.Buffer((T.int64(10),), "float32"), T_add_intermediate: T.Buffer((T.int64(1), T.int64(10)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(10)))
for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(256)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(relu[v_i0, v_k], permute_dims1[v_k, v_i1])
T.writes(matmul_intermediate[v_i0, v_i1])
with T.init():
matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + relu[v_i0, v_k] * permute_dims1[v_k, v_i1]
for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(matmul_intermediate[v_ax0, v_ax1], fc2_bias[v_ax1])
T.writes(T_add_intermediate[v_ax0, v_ax1])
T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc2_bias[v_ax1]
@T.prim_func(private=True)
def fused_matmul_add_relu(x: T.Buffer((T.int64(1), T.int64(784)), "float32"), permute_dims: T.Buffer((T.int64(784), T.int64(256)), "float32"), fc1_bias: T.Buffer((T.int64(256),), "float32"), compute_intermediate: T.Buffer((T.int64(1), T.int64(256)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(256)))
for i0, i1, k in T.grid(T.int64(1), T.int64(256), T.int64(784)):
with T.block("matmul"):
v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
T.reads(x[v_i0, v_k], permute_dims[v_k, v_i1])
T.writes(matmul_intermediate[v_i0, v_i1])
with T.init():
matmul_intermediate[v_i0, v_i1] = T.float32(0.0)
matmul_intermediate[v_i0, v_i1] = matmul_intermediate[v_i0, v_i1] + x[v_i0, v_k] * permute_dims[v_k, v_i1]
for ax0, ax1 in T.grid(T.int64(1), T.int64(256)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(matmul_intermediate[v_ax0, v_ax1], fc1_bias[v_ax1])
T.writes(T_add_intermediate[v_ax0, v_ax1])
T_add_intermediate[v_ax0, v_ax1] = matmul_intermediate[v_ax0, v_ax1] + fc1_bias[v_ax1]
for i0, i1 in T.grid(T.int64(1), T.int64(256)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
T.reads(T_add_intermediate[v_i0, v_i1])
T.writes(compute_intermediate[v_i0, v_i1])
compute_intermediate[v_i0, v_i1] = T.max(T_add_intermediate[v_i0, v_i1], T.float32(0.0))
@T.prim_func(private=True)
def transpose(fc1_weight: T.Buffer((T.int64(256), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(256)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(784), T.int64(256)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(fc1_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = fc1_weight[v_ax1, v_ax0]
@T.prim_func(private=True)
def transpose1(fc2_weight: T.Buffer((T.int64(10), T.int64(256)), "float32"), T_transpose: T.Buffer((T.int64(256), T.int64(10)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1 in T.grid(T.int64(256), T.int64(10)):
with T.block("T_transpose"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(fc2_weight[v_ax1, v_ax0])
T.writes(T_transpose[v_ax0, v_ax1])
T_transpose[v_ax0, v_ax1] = fc2_weight[v_ax1, v_ax0]
@R.function
def forward(x: R.Tensor((1, 784), dtype="float32"), fc1_weight: R.Tensor((256, 784), dtype="float32"), fc1_bias: R.Tensor((256,), dtype="float32"), fc2_weight: R.Tensor((10, 256), dtype="float32"), fc2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
permute_dims = R.call_tir(cls.transpose, (fc1_weight,), out_sinfo=R.Tensor((784, 256), dtype="float32"))
lv = R.call_tir(cls.fused_matmul_add_relu, (x, permute_dims, fc1_bias), out_sinfo=R.Tensor((1, 256), dtype="float32"))
permute_dims1 = R.call_tir(cls.transpose1, (fc2_weight,), out_sinfo=R.Tensor((256, 10), dtype="float32"))
gv = R.call_tir(cls.fused_matmul1_add1, (lv, permute_dims1, fc2_bias), out_sinfo=R.Tensor((1, 10), dtype="float32"))
R.output(gv)
return gv
构建和通用部署#
优化完成后,将模型构建为可部署模块,并在不同设备上运行。
import numpy as np
target = tvm.target.Target("llvm")
ex = relax.build(mod, target)
device = tvm.cpu()
vm = relax.VirtualMachine(ex, device)
data = np.random.rand(1, 784).astype("float32")
tvm_data = tvm.nd.array(data, device=device)
params = [np.random.rand(*param.shape).astype("float32") for _, param in param_spec]
params = [tvm.nd.array(param, device=device) for param in params]
print(vm["forward"](tvm_data, *params).numpy())
[[24355.219 23935.193 26146.168 24525.396 23935.293 24308.348 24173.424
24650.494 24086.562 23485.203]]
目标是将机器学习带入任何感兴趣语言的应用程序中,同时提供最小的运行时支持。
在 IRModule 中的每个函数成为运行时的可运行函数。例如,在 LLM 情况下,可以直接调用
prefill
和decode
函数。prefill_logits = vm["prefill"](inputs, weight, kv_cache) decoded_logits = vm["decode"](inputs, weight, kv_cache)
TVM 运行时附带原生数据结构,如 NDArray,也可以与现有生态系统(通过 DLPack 与 PyTorch 交换)进行零拷贝交换。
# 将 PyTorch 张量转换为 TVM NDArray x_tvm = tvm.nd.from_dlpack(x_torch.to_dlpack()) # 转换 TVM NDArray 为 PyTorch 张量 x_torch = torch.from_dlpack(x_tvm.to_dlpack())
TVM 运行时在非 Python 环境中工作,因此它可以在移动设备等设置上运行。
// C++ snippet runtime::Module vm = ex.GetFunction("load_executable")(); vm.GetFunction("init")(...); NDArray out = vm.GetFunction("prefill")(data, weight, kv_cache);
// Java snippet Module vm = ex.getFunction("load_executable").invoke(); vm.getFunction("init").pushArg(...).invoke; NDArray out = vm.getFunction("prefill").pushArg(data).pushArg(weight).pushArg(kv_cache).invoke();