TVM 样例#
Apache TVM 是遵循 Python 优先开发、通用部署原则的机器学习编译框架。它接收预训练的机器学习模型,编译并生成可嵌入和在任何地方运行的部署模块。Apache TVM 还允许自定义优化过程,引入新的优化、库、代码生成等。
Apache TVM 可以帮助:
优化:ML 工作负载、组合库和代码生成的性能。
部署:ML 工作负载部署到一组不同的新环境,包括新运行时和新的硬件。
持续改进和定制:通过快速自定义在 Python 中部署 ML 管道 库调度,引入自定义算子和代码生成。
整体流程包括以下步骤:
构建或导入模型:构建神经网络模型或从其他框架(例如 PyTorch、ONNX)导入预训练模型,并创建 TVM IRModule,其中包含编译所需的所有信息,包括用于计算图的高级 Relax 函数和用于张量程序的低级 TensorIR 函数。
执行可组合优化:执行一系列优化变换,如计算图优化、张量程序优化和库调度。
构建和通用部署:将优化后的模型构建为可部署模块到通用运行时,并在不同设备上执行,如 CPU、GPU 或其他加速器。
构造或导入模型#
使用 TVM Relax 前端直接定义两层的感知器(MLP)网络,其 API 与 PyTorch 类似。
import tvm
from tvm import relax
from tvm.relax.frontend import nn
class MLPModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
return x
可以将模型导出为 TVM IRModule,这是 TVM 中的核心中间表示。
model = MLPModel()
mod, param_spec = model.export_tvm(
spec={"forward": {"x": nn.spec.Tensor((1, 784), "float32")}}
)
mod.show()
执行优化转换#
Apache TVM 利用 pipeline
变换和优化程序。该管道封装了一系列变换,实现两个目标(在同一层级):
模型优化:例如算子融合、布局重写。
张量程序优化:将算子映射到底层实现(包括库或代码生成)
备注
这两个是目标,而不是 pipeline
的阶段。这两种优化是在同一层级进行的,或者在两个阶段分别进行。
在本教程中,只演示整体流程,通过利用 zero
优化管道,而不是针对任何特定目标进行优化。
mod = relax.get_pipeline("zero")(mod)
mod.show()
构建和通用部署#
优化完成后,将模型构建为可部署模块,并在不同设备上运行。
import numpy as np
target = tvm.target.Target("llvm")
ex = tvm.compile(mod, target)
device = tvm.cpu()
vm = relax.VirtualMachine(ex, device)
data = np.random.rand(1, 784).astype("float32")
tvm_data = tvm.nd.array(data, device=device)
params = [np.random.rand(*param.shape).astype("float32") for _, param in param_spec]
params = [tvm.nd.array(param, device=device) for param in params]
print(vm["forward"](tvm_data, *params).numpy())
[[24432.27 26025.193 24283.773 25502.045 26776.664 27174.902 26169.416
25385.572 24756.78 26742.55 ]]
TVM 的目标是将机器学习带入任何感兴趣语言的应用程序中,同时提供最小的运行时支持。
在 IRModule 中的每个函数成为运行时的可运行函数。例如,在 LLM 情况下,可以直接调用
prefill
和decode
函数。prefill_logits = vm["prefill"](inputs, weight, kv_cache) decoded_logits = vm["decode"](inputs, weight, kv_cache)
TVM 运行时附带原生数据结构,如 NDArray,也可以与现有生态系统(通过 DLPack 与 PyTorch 交换)进行零拷贝交换。
# 将 PyTorch 张量转换为 TVM NDArray x_tvm = tvm.nd.from_dlpack(x_torch.to_dlpack()) # 转换 TVM NDArray 为 PyTorch 张量 x_torch = torch.from_dlpack(x_tvm.to_dlpack())
TVM 运行时在非 Python 环境中工作,因此它可以在移动设备等设置上运行。
// C++ snippet runtime::Module vm = ex.GetFunction("load_executable")(); vm.GetFunction("init")(...); NDArray out = vm.GetFunction("prefill")(data, weight, kv_cache);
// Java snippet Module vm = ex.getFunction("load_executable").invoke(); vm.getFunction("init").pushArg(...).invoke; NDArray out = vm.getFunction("prefill").pushArg(data).pushArg(weight).pushArg(kv_cache).invoke();