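Wrap Your TVMScript Code as a PyTorch Module
Author: Yaoda Zhou
This tutorial shows how to wrap TVMScript code as a PyTorch module. With the as_torch decorator, users can naturally wrap TVMScript code into a PyTorch torch.nn.Module.
To follow this tutorial, PyTorch must be installed.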
%%shell
pip install torch
import set_env
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
import tvm
from tvm.contrib.torch import as_torch
from tvm.script import tir as T
/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/__init__.py:50: RuntimeWarning: The library libpt_tvmdsoop is not built successfully. /media/pc/data/lxw/ai/tvm/build/libpt_tvmdsoop.so: cannot open shared object file: No such file or directory
warnings.warn(
/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/__init__.py:50: RuntimeWarning: The library libpt_tvmdsoop_new is not built successfully. /media/pc/data/lxw/ai/tvm/build/libpt_tvmdsoop_new.so: cannot open shared object file: No such file or directory
warnings.warn(
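Write Your Own PyTorch Operator with TVMScript
PyTorch is a very popular machine learning framework that contains optimized implementations of most commonly used operators. Nevertheless, you may sometimes want to write your own operators in PyTorch, and the performance of such custom operators may not meet your needs.
For example, suppose we want to define a 1-d depthwise convolution operator whose input and output channel counts are both 70, whose width is 80, and whose kernel size is 20. This 1-d depthwise convolution can be expressed in PyTorch in one line: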
in_channel = 70
out_channel = 70
width = 80
kernel_size = 20
def torch_depthwise(inputs, filters):
    return F.conv1d(inputs, filters.view(out_channel, 1, kernel_size), groups=out_channel)
The function can be run like this:
inputs = torch.randn(in_channel, width)
filters = torch.randn(out_channel, kernel_size)
ret_torch = torch_depthwise(inputs, filters)
In plain Python code, the torch_depthwise function can be written as:
def vanilla_depthwise(input, weight):
    ret = torch.zeros(out_channel, width - kernel_size + 1)
    for j in range(out_channel):
        for i in range(width - kernel_size + 1):
            for k in range(kernel_size):
                ret[j, i] += weight[j, k] * input[j, i + k]
    return ret
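To make the index arithmetic concrete, the same sliding-window loop can be sketched on plain Python lists. This is only an illustration: `depthwise_1d` is a hypothetical helper introduced here, not part of PyTorch or TVM, and the tiny input is chosen so the result can be checked by hand.

```python
def depthwise_1d(inputs, weights):
    """Hypothetical list-based helper mirroring the triple loop above."""
    channels = len(inputs)
    width = len(inputs[0])
    kernel_size = len(weights[0])
    # Number of positions where the kernel fits entirely inside the input.
    out_width = width - kernel_size + 1
    out = [[0.0] * out_width for _ in range(channels)]
    for j in range(channels):          # each channel uses its own filter
        for i in range(out_width):     # output position
            for k in range(kernel_size):   # offset inside the window
                out[j][i] += weights[j][k] * inputs[j][i + k]
    return out


# One channel, width 4, kernel size 2: three output positions.
tiny = depthwise_1d([[1.0, 2.0, 3.0, 4.0]], [[1.0, 1.0]])

# With the tutorial's sizes (width 80, kernel size 20) the output width is 61.
tutorial_out_width = 80 - 20 + 1
```

Each output element is the sum of one kernel-sized window of its own channel, which is why the output has `width - kernel_size + 1` columns and the channels never mix.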
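Next, we plan to use the power of TVM to optimize the depthwise function. The TVM community has proposed a domain-specific language embedded in Python, called TVMScript, which serves as the high-level frontend for TVM's Tensor IR.
The 1-d depthwise convolution code above can be translated to TVMScript as follows. The as_torch decorator automatically converts the TVMScript code into a PyTorch nn.Module.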
@as_torch
@T.prim_func
def tvm_depthwise(
    A: T.Buffer((70, 80), "float32"),
    B: T.Buffer((70, 20), "float32"),
    C: T.Buffer((70, 61), "float32"),
) -> None:
    for j, i, k in T.grid(70, 61, 20):
        with T.block():
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vj, vi] = T.float32(0)
            C[vj, vi] += B[vj, vk] * A[vj, vi + vk]
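The block above marks two axes as spatial ("S") and one as a reduction ("R"), and zeroes the accumulator inside T.init(). A rough pure-Python analogue, assuming the init statement takes effect on the first step of the reduction axis, might look like the sketch below. `block_style_depthwise` is a hypothetical name used only for this illustration, and the output is deliberately pre-filled with NaN to show that the init step, not the caller, resets each accumulator.

```python
def block_style_depthwise(a, b, channels, out_width, kernel_size):
    """Hypothetical emulation of the TVMScript block with plain lists."""
    # Pre-fill with NaN: if the init step were skipped, the result would stay NaN.
    c = [[float("nan")] * out_width for _ in range(channels)]
    for j in range(channels):              # spatial axis ("S")
        for i in range(out_width):         # spatial axis ("S")
            for k in range(kernel_size):   # reduction axis ("R")
                if k == 0:
                    c[j][i] = 0.0          # plays the role of T.init()
                c[j][i] += b[j][k] * a[j][i + k]
    return c


a = [[1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 0.0, 1.0]]
b = [[1.0, 1.0], [2.0, 3.0]]
c = block_style_depthwise(a, b, channels=2, out_width=3, kernel_size=2)
```

The buffer shapes in the TVMScript signature follow the same pattern: A is (70, 80), B is (70, 20), and C is (70, 61) because 80 - 20 + 1 = 61.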
We can build the TVMScript code by calling the tune method with its default settings. If no extra information is provided, the model is tuned for CPU.
tvm_depthwise.tune()
2024-03-20 12:10:34 [INFO] Logging directory: /tmp/tmphj33434s/logs
2024-03-20 12:10:48 [INFO] LocalBuilder: max_workers = 24
2024-03-20 12:10:50 [INFO] LocalRunner: max_workers = 1
2024-03-20 12:10:51 [INFO] [task_scheduler.cc:159] Initializing Task #0: "main"
2024-03-20 12:10:51 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "main"
2024-03-20 12:10:51 [INFO] [task_scheduler.cc:193] Sending 32 sample(s) to builder
2024-03-20 12:10:56 [INFO] [task_scheduler.cc:195] Sending 32 sample(s) to runner
2024-03-20 12:11:04 [DEBUG] XGB iter 0: tr-p-rmse: 0.394348 tr-a-peak@32: 0.999438 tr-rmse: 0.394544 tr-rmse: 0.394544
2024-03-20 12:11:04 [DEBUG] XGB iter 25: tr-p-rmse: 0.013164 tr-a-peak@32: 0.999686 tr-rmse: 0.012791 tr-rmse: 0.012791
2024-03-20 12:11:04 [DEBUG] XGB iter 50: tr-p-rmse: 0.013063 tr-a-peak@32: 0.999686 tr-rmse: 0.012660 tr-rmse: 0.012660
2024-03-20 12:11:04 [DEBUG] XGB iter 75: tr-p-rmse: 0.013063 tr-a-peak@32: 0.999686 tr-rmse: 0.012660 tr-rmse: 0.012660
2024-03-20 12:11:04 [DEBUG] XGB stopped. Best iteration: [34] tr-p-rmse:0.01306 tr-a-peak@32:0.99969 tr-rmse:0.01266 tr-rmse:0.01266
2024-03-20 12:11:04 [INFO] [task_scheduler.cc:237] [Updated] Task #0: "main"
Total trials: 32
Total latency (us): 12.9516
2024-03-20 12:11:04 [DEBUG] [task_scheduler.cc:318]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
-----------------------------------------------------------------------------------------------------
0 | main | 170800 | 1 | 13.1876 | 12.9516 | 12.9516 | 32 |
-----------------------------------------------------------------------------------------------------
Total trials: 32
Total latency (us): 12.9516
2024-03-20 12:11:04 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 0
2024-03-20 12:11:04 [DEBUG] [task_scheduler.cc:318]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
-----------------------------------------------------------------------------------------------------
0 | main | 170800 | 1 | 13.1876 | 12.9516 | 12.9516 | 32 | Y
-----------------------------------------------------------------------------------------------------
Total trials: 32
Total latency (us): 12.9516
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[7], line 1
----> 1 tvm_depthwise.tune()
File /media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:107, in OperatorModuleWrapper.tune(self, target, max_trials_global, num_trials_per_iter, builder, runner, database, cost_model, measure_callbacks, task_scheduler, space, strategy, num_tuning_cores, seed)
105 sch = ms.tir_integration.compile_tir(database, self.ir_module, target)
106 self.ir_module = sch.mod
--> 107 self.build(target)
File /media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:117, in OperatorModuleWrapper.build(self, target)
114 func = tvm.get_global_func("tvmtorch.save_runtime_mod", allow_missing=True)
116 if func is None:
--> 117 raise ValueError('as_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake')
118 func(runtime_module)
120 self.rt_module = torch.classes.tvm_torch.OperatorModuleWrapper()
ValueError: as_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake
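The FLOP and Speed columns reported in the tuning log above can be cross-checked against the loop bounds of the kernel. The sketch below assumes that each innermost iteration is counted as two floating-point operations (one multiply and one add); that counting rule is an assumption, although it does reproduce the logged numbers.

```python
channels, out_width, kernel_size = 70, 61, 20

# Assumption: one multiply plus one add per innermost iteration.
flop = 2 * channels * out_width * kernel_size

latency_us = 12.9516  # "Latency (us)" value taken from the tuning log
gflops = flop / (latency_us * 1e-6) / 1e9

print(flop, round(gflops, 4))
```

Under that assumption, the computed operation count matches the 170800 shown in the FLOP column, and dividing by the logged latency gives the value in the Speed (GFLOPS) column.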
We can print the optimized TVMScript code to see how the program has been transformed:
print(tvm_depthwise.script())
We can verify that the two outputs are the same:
import tvm.testing
ret_tvm = torch.zeros(out_channel, width - kernel_size + 1)
tvm_depthwise(inputs, filters, ret_tvm)
tvm.testing.assert_allclose(ret_torch.cpu().numpy(), ret_tvm.cpu().numpy(), atol=1e-5, rtol=1e-5)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[8], line 2
1 ret_tvm = torch.zeros(out_channel, width - kernel_size + 1)
----> 2 tvm_depthwise(inputs, filters, ret_tvm)
4 testing.assert_allclose(ret_torch.cpu().numpy(), ret_tvm.cpu().numpy(), atol=1e-5, rtol=1e-5)
File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File /media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:127, in OperatorModuleWrapper.forward(self, *torch_inputs)
125 self.build(target="cuda")
126 elif torch_inputs[0].device.type == "cpu":
--> 127 self.build()
128 else:
129 raise Exception(f"the target {torch_inputs[0].device.type} is not supported yet")
File /media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/as_torch.py:117, in OperatorModuleWrapper.build(self, target)
114 func = tvm.get_global_func("tvmtorch.save_runtime_mod", allow_missing=True)
116 if func is None:
--> 117 raise ValueError('as_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake')
118 func(runtime_module)
120 self.rt_module = torch.classes.tvm_torch.OperatorModuleWrapper()
ValueError: as_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake
Benchmark
results = []
for i in range(5):
    inputs = torch.randn(out_channel, width)
    filters = torch.randn(out_channel, kernel_size)
    res = torch.zeros(out_channel, width - kernel_size + 1)
    sub_label = f"[test {i}]"
    results.append(
        benchmark.Timer(
            stmt="tvm_depthwise(inputs, filters, res)",
            setup="from __main__ import tvm_depthwise",
            globals={"inputs": inputs, "filters": filters, "res": res},
            sub_label=sub_label,
            description="TVMScript",
        ).blocked_autorange()
    )
    results.append(
        benchmark.Timer(
            stmt="torch_depthwise(inputs, filters)",
            setup="from __main__ import torch_depthwise",
            globals={
                "inputs": inputs,
                "filters": filters,
            },
            sub_label=sub_label,
            description="PyTorch",
        ).blocked_autorange()
    )
compare = benchmark.Compare(results)
compare.print()
In the author's environment, the average inference time of tvm_depthwise is 120.0 us, while that of torch_depthwise is 196.0 us (PyTorch version 1.11.0), which is roughly a 38% reduction in running time.
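The quoted percentage can be read either as a reduction in running time or as a speedup factor, and the two figures differ. A short calculation with the two timings reported above makes the distinction explicit; the variable names are illustrative only.

```python
tvm_us = 120.0    # reported average time of tvm_depthwise
torch_us = 196.0  # reported average time of torch_depthwise

# Fraction of the PyTorch running time that is saved.
time_reduction = 1.0 - tvm_us / torch_us

# How many times faster the TVMScript version runs.
speedup = torch_us / tvm_us

print(f"{time_reduction:.1%} less time, {speedup:.2f}x faster")
```

Both figures apply only to the author's environment; timings on other machines or PyTorch versions may differ.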