Wrap Your TVMScript Code as a PyTorch Module

Author: Yaoda Zhou

This article is a tutorial on how to wrap TVMScript code as a PyTorch module. With the decorator as_torch, users can naturally wrap TVMScript code into a PyTorch torch.nn.Module.

To follow this tutorial, PyTorch needs to be installed. Note that as_torch additionally requires TVM to be built with the flag USE_PT_TVMDSOOP set in config.cmake, as the error messages later in this page illustrate.

%%shell
pip install torch

import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

import tvm
from tvm import testing
from tvm.contrib.torch import as_torch
from tvm.script import tir as T
/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/__init__.py:50: RuntimeWarning: The library libpt_tvmdsoop is not built successfully. /media/pc/data/lxw/ai/tvm/build/libpt_tvmdsoop.so: cannot open shared object file: No such file or directory
  warnings.warn(
/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/__init__.py:50: RuntimeWarning: The library libpt_tvmdsoop_new is not built successfully. /media/pc/data/lxw/ai/tvm/build/libpt_tvmdsoop_new.so: cannot open shared object file: No such file or directory
  warnings.warn(

Write Your Own PyTorch Operator with TVMScript

PyTorch is a very popular machine learning framework that contains optimized implementations of the most commonly used operators. Nevertheless, sometimes you may want to write your own operator in PyTorch, and in that case the performance of such a custom operator may not satisfy your needs.

For example, suppose we want to define a 1-d depthwise convolution operator whose input and output channel counts are both 70, whose width is 80, and whose kernel size is 20 (so the output width is 80 - 20 + 1 = 61). The 1-d depthwise convolution can be expressed in PyTorch in one line of code:

in_channel = 70
out_channel = 70
width = 80
kernel_size = 20


def torch_depthwise(inputs, filters):
    # Reshape filters to (out_channel, 1, kernel_size) and use a grouped conv1d,
    # one group per channel, to perform the depthwise convolution.
    return F.conv1d(inputs, filters.view(out_channel, 1, kernel_size), groups=out_channel)

We can run this function as follows:

inputs = torch.randn(in_channel, width)
filters = torch.randn(out_channel, kernel_size)
ret_torch = torch_depthwise(inputs, filters)
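
As a quick sanity check (an illustrative addition, not part of the original tutorial), we can confirm that the output has the expected shape (out_channel, width - kernel_size + 1):

# Illustrative sanity check: the depthwise conv1d output width is width - kernel_size + 1.
assert ret_torch.shape == (out_channel, width - kernel_size + 1)  # torch.Size([70, 61])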

In plain Python code, the torch_depthwise function can be written as:

def vanilla_depthwise(input, weight):
    # Each output channel j is the 1-d correlation of input channel j
    # with its own kernel weight[j, :].
    ret = torch.zeros(out_channel, width - kernel_size + 1)
    for j in range(out_channel):
        for i in range(width - kernel_size + 1):
            for k in range(kernel_size):
                ret[j, i] += weight[j, k] * input[j, i + k]
    return ret
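
To convince ourselves that the naive loop agrees with the PyTorch one-liner, we can run a small check in the same style as the verification used later in this tutorial (an illustrative addition):

# Illustrative check: the plain Python loop should match F.conv1d up to float tolerance.
ret_vanilla = vanilla_depthwise(inputs, filters)
testing.assert_allclose(ret_vanilla.numpy(), ret_torch.numpy(), atol=1e-5, rtol=1e-5)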

We then plan to leverage the power of TVM to optimize the depthwise function. The TVM community has proposed a domain-specific language embedded in Python, called TVMScript, which serves as the high-level frontend for TVM's Tensor IR.

The depthwise 1-d convolution code above can be translated into TVMScript as follows. We provide the as_torch decorator, which automatically converts TVMScript code into a PyTorch nn.Module.

@as_torch
@T.prim_func
def tvm_depthwise(
    A: T.Buffer((70, 80), "float32"),  # input: (in_channel, width)
    B: T.Buffer((70, 20), "float32"),  # filters: (out_channel, kernel_size)
    C: T.Buffer((70, 61), "float32"),  # output: (out_channel, width - kernel_size + 1)
) -> None:
    for j, i, k in T.grid(70, 61, 20):
        with T.block():
            # i and j are spatial ("S") axes; k is the reduction ("R") axis.
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vj, vi] = T.float32(0)
            C[vj, vi] += B[vj, vk] * A[vj, vi + vk]

We can build the TVMScript code by calling the tune method with default settings. Without extra information provided, the model will be tuned for the CPU.

tvm_depthwise.tune()
2024-03-20 12:10:34 [INFO] Logging directory: /tmp/tmphj33434s/logs
2024-03-20 12:10:48 [INFO] LocalBuilder: max_workers = 24
2024-03-20 12:10:50 [INFO] LocalRunner: max_workers = 1
2024-03-20 12:10:51 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"
2024-03-20 12:10:51 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2024-03-20 12:10:51 [INFO] [task_scheduler.cc:193] Sending 32 sample(s) to builder
2024-03-20 12:10:56 [INFO] [task_scheduler.cc:195] Sending 32 sample(s) to runner
2024-03-20 12:11:04 [DEBUG] XGB iter   0: tr-p-rmse: 0.394348	tr-a-peak@32: 0.999438	tr-rmse: 0.394544	tr-rmse: 0.394544
2024-03-20 12:11:04 [DEBUG] XGB iter  25: tr-p-rmse: 0.013164	tr-a-peak@32: 0.999686	tr-rmse: 0.012791	tr-rmse: 0.012791
2024-03-20 12:11:04 [DEBUG] XGB iter  50: tr-p-rmse: 0.013063	tr-a-peak@32: 0.999686	tr-rmse: 0.012660	tr-rmse: 0.012660
2024-03-20 12:11:04 [DEBUG] XGB iter  75: tr-p-rmse: 0.013063	tr-a-peak@32: 0.999686	tr-rmse: 0.012660	tr-rmse: 0.012660
2024-03-20 12:11:04 [DEBUG] XGB stopped. Best iteration: [34] tr-p-rmse:0.01306	tr-a-peak@32:0.99969	tr-rmse:0.01266	tr-rmse:0.01266 
2024-03-20 12:11:04 [INFO] [task_scheduler.cc:237] [Updated] Task #0: "main"

Total trials: 32
Total latency (us): 12.9516

2024-03-20 12:11:04 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |   FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
-----------------------------------------------------------------------------------------------------
  0 | main | 170800 |      1 |        13.1876 |      12.9516 |               12.9516 |     32 |      
-----------------------------------------------------------------------------------------------------
Total trials: 32
Total latency (us): 12.9516

2024-03-20 12:11:04 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0
2024-03-20 12:11:04 [DEBUG] [task_scheduler.cc:318] 
 ID | Name |   FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done 
-----------------------------------------------------------------------------------------------------
  0 | main | 170800 |      1 |        13.1876 |      12.9516 |               12.9516 |     32 |    Y 
-----------------------------------------------------------------------------------------------------
Total trials: 32
Total latency (us): 12.9516


---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[7], line 1
----> 1 tvm_depthwise.tune()

File /media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:107, in OperatorModuleWrapper.tune(self, target, max_trials_global, num_trials_per_iter, builder, runner, database, cost_model, measure_callbacks, task_scheduler, space, strategy, num_tuning_cores, seed)
    105 sch = ms.tir_integration.compile_tir(database, self.ir_module, target)
    106 self.ir_module = sch.mod
--> 107 self.build(target)

File /media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:117, in OperatorModuleWrapper.build(self, target)
    114 func = tvm.get_global_func("tvmtorch.save_runtime_mod", allow_missing=True)
    116 if func is None:
--> 117     raise ValueError('as_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake')
    118 func(runtime_module)
    120 self.rt_module = torch.classes.tvm_torch.OperatorModuleWrapper()

ValueError: as_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake
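
As the traceback shows, running as_torch end to end requires a TVM build with USE_PT_TVMDSOOP enabled in config.cmake. For reference, tune also accepts explicit options; judging from its signature in the traceback above, a hedged sketch might look as follows (the parameter values are illustrative, not from the original tutorial):

# Illustrative only: parameter names inferred from the tune() signature shown above;
# defaults and accepted values may differ across TVM versions.
tvm_depthwise.tune(
    target=tvm.target.Target("llvm"),  # tune for a generic CPU target
    max_trials_global=64,              # total measurement budget
)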

We can print out the optimized TVMScript code to see how the program was transformed, as follows:

print(tvm_depthwise.script())

We can verify that the two outputs are the same. Note that the wrapped module follows a destination-passing convention: the output buffer ret_tvm is allocated by the caller and passed as the last argument:

ret_tvm = torch.zeros(out_channel, width - kernel_size + 1)
tvm_depthwise(inputs, filters, ret_tvm)

testing.assert_allclose(ret_torch.cpu().numpy(), ret_tvm.cpu().numpy(), atol=1e-5, rtol=1e-5)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[8], line 2
      1 ret_tvm = torch.zeros(out_channel, width - kernel_size + 1)
----> 2 tvm_depthwise(inputs, filters, ret_tvm)
      4 testing.assert_allclose(ret_torch.cpu().numpy(), ret_tvm.cpu().numpy(), atol=1e-5, rtol=1e-5)

File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:127, in OperatorModuleWrapper.forward(self, *torch_inputs)
    125     self.build(target="cuda")
    126 elif torch_inputs[0].device.type == "cpu":
--> 127     self.build()
    128 else:
    129     raise Exception(f"the target {torch_inputs[0].device.type} is not supported yet")

File /media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:117, in OperatorModuleWrapper.build(self, target)
    114 func = tvm.get_global_func("tvmtorch.save_runtime_mod", allow_missing=True)
    116 if func is None:
--> 117     raise ValueError('as_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake')
    118 func(runtime_module)
    120 self.rt_module = torch.classes.tvm_torch.OperatorModuleWrapper()

ValueError: as_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake

Benchmark

Finally, we compare the two implementations over several random inputs using torch.utils.benchmark:

results = []
for i in range(5):
    inputs = torch.randn(out_channel, width)
    filters = torch.randn(out_channel, kernel_size)
    res = torch.zeros(out_channel, width - kernel_size + 1)
    sub_label = f"[test {i}]"
    results.append(
        benchmark.Timer(
            stmt="tvm_depthwise(inputs, filters, res)",
            setup="from __main__ import tvm_depthwise",
            globals={"inputs": inputs, "filters": filters, "res": res},
            sub_label=sub_label,
            description="TVMScript",
        ).blocked_autorange()
    )
    results.append(
        benchmark.Timer(
            stmt="torch_depthwise(inputs, filters)",
            setup="from __main__ import torch_depthwise",
            globals={
                "inputs": inputs,
                "filters": filters,
            },
            sub_label=sub_label,
            description="PyTorch",
        ).blocked_autorange()
    )
compare = benchmark.Compare(results)
compare.print()
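
Optionally, benchmark.Compare offers helpers to tidy the printed table before calling print (an illustrative addition, not part of the original tutorial):

compare.trim_significant_figures()  # round measurements to significant figures
compare.colorize()                  # highlight the fastest and slowest entries
compare.print()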

In the author's environment, the average inference time of tvm_depthwise is 120.0 microseconds, while that of torch_depthwise is 196.0 microseconds (with PyTorch version 1.11.0), a speedup of roughly 38%.