Compile PyTorch Models#
Author: Yaoda Zhou
This tutorial shows how to optimize PyTorch models with optimize_torch.
To follow along, you need PyTorch and TorchVision installed:
%%shell
pip install torch
pip install torchvision
import set_env
import torch
import torch.nn as nn
import torch.nn.functional as F
# Import library for profiling
import torch.utils.benchmark as benchmark
from torchvision.models import resnet18
# Import `optimize_torch` function
from tvm.contrib.torch import optimize_torch
/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/__init__.py:50: RuntimeWarning: The library libpt_tvmdsoop is not built successfully. /media/pc/data/lxw/ai/tvm/build/libpt_tvmdsoop.so: cannot open shared object file: No such file or directory
warnings.warn(
/media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/__init__.py:50: RuntimeWarning: The library libpt_tvmdsoop_new is not built successfully. /media/pc/data/lxw/ai/tvm/build/libpt_tvmdsoop_new.so: cannot open shared object file: No such file or directory
warnings.warn(
Define a Simple Model with PyTorch#
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
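To see what SimpleModel does to tensor shapes, the sketch below applies the usual output-size formula for a convolution to each of its two 5x5 layers, which use no padding and stride 1. The helper name `conv2d_out_size` is hypothetical, and the batch size of 20 and the 10x10 spatial size match the example input used later in this tutorial. It is an illustration only, not part of the PyTorch or TVM API.

```python
def conv2d_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


batch, out_channels = 20, 20   # batch of the example input, channels of conv2
height = 10                    # the example input is 10x10

after_conv1 = conv2d_out_size(height, kernel=5)       # 10 -> 6
after_conv2 = conv2d_out_size(after_conv1, kernel=5)  # 6 -> 2
output_shape = (batch, out_channels, after_conv2, after_conv2)

print(after_conv1, after_conv2, output_shape)  # 6 2 (20, 20, 2, 2)
```

Each 5x5 layer therefore shrinks the spatial size by 4, so the model maps a (20, 1, 10, 10) input to a (20, 20, 2, 2) output.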
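Optimize SimpleModel with TVM MetaSchedule#
We provide the optimize_torch function, whose usage is similar to torch.jit.trace.
The user supplies the PyTorch model to optimize together with an example input. TVM then tunes the PyTorch module for the target hardware. If no additional information is given, the model is tuned for the CPU.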
simple_model = SimpleModel()
example_input = torch.randn(20, 1, 10, 10)
model_optimized_by_tvm = optimize_torch(simple_model, example_input, max_trials_global=2)
2024-03-20 12:21:19 [INFO] Logging directory: /tmp/tmpl0j3jqte/logs
2024-03-20 12:21:36 [INFO] LocalBuilder: max_workers = 24
2024-03-20 12:21:38 [INFO] LocalRunner: max_workers = 1
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #0: "fused_layout_transform"
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #1: "fused_nn_contrib_conv2d_NCHWc_add_nn_relu"
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #2: "fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1"
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #3: "fused_layout_transform_1"
2024-03-20 12:21:40 [DEBUG] [task_scheduler.cc:318]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
---------------------------------------------------------------------------------------------------------------------------------------------
0 | fused_layout_transform | 1 | 1 | N/A | N/A | N/A | 0 |
1 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu | 748800 | 1 | N/A | N/A | N/A | 0 |
2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 | 1 | N/A | N/A | N/A | 0 |
3 | fused_layout_transform_1 | 1 | 1 | N/A | N/A | N/A | 0 |
---------------------------------------------------------------------------------------------------------------------------------------------
Total trials: 0
Total latency (us): 0
2024-03-20 12:21:40 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "fused_layout_transform"
2024-03-20 12:21:40 [INFO] [task_scheduler.cc:193] Sending 2 sample(s) to builder
2024-03-20 12:21:42 [INFO] [task_scheduler.cc:195] Sending 2 sample(s) to runner
2024-03-20 12:21:43 [DEBUG] XGB iter 0: tr-p-rmse: 0.424805 tr-a-peak@32: 1.000000 tr-rmse: 0.424910 tr-rmse: 0.424910
2024-03-20 12:21:43 [DEBUG] XGB iter 25: tr-p-rmse: 0.015707 tr-a-peak@32: 1.000000 tr-rmse: 0.015787 tr-rmse: 0.015787
2024-03-20 12:21:43 [DEBUG] XGB iter 50: tr-p-rmse: 0.010228 tr-a-peak@32: 1.000000 tr-rmse: 0.010230 tr-rmse: 0.010230
2024-03-20 12:21:43 [DEBUG] XGB iter 75: tr-p-rmse: 0.010225 tr-a-peak@32: 1.000000 tr-rmse: 0.010224 tr-rmse: 0.010224
2024-03-20 12:21:43 [DEBUG] XGB iter 100: tr-p-rmse: 0.010226 tr-a-peak@32: 1.000000 tr-rmse: 0.010224 tr-rmse: 0.010224
2024-03-20 12:21:43 [DEBUG] XGB stopped. Best iteration: [54] tr-p-rmse:0.01022 tr-a-peak@32:1.00000 tr-rmse:0.01023 tr-rmse:0.01023
2024-03-20 12:21:43 [INFO] [task_scheduler.cc:237] [Updated] Task #0: "fused_layout_transform"
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
---------------------------------------------------------------------------------------------------------------------------------------------
0 | fused_layout_transform | 1 | 1 | 0.0001 | 11.0077 | 11.0077 | 2 |
1 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu | 748800 | 1 | N/A | N/A | N/A | 0 |
2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 | 1 | N/A | N/A | N/A | 0 |
3 | fused_layout_transform_1 | 1 | 1 | N/A | N/A | N/A | 0 |
---------------------------------------------------------------------------------------------------------------------------------------------
Total trials: 2
Total latency (us): 11.0077
2024-03-20 12:21:43 [INFO] [task_scheduler.cc:260] Task #0 has finished. Remaining task(s): 3
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
---------------------------------------------------------------------------------------------------------------------------------------------
0 | fused_layout_transform | 1 | 1 | 0.0001 | 11.0077 | 11.0077 | 2 | Y
1 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu | 748800 | 1 | N/A | N/A | N/A | 0 |
2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 | 1 | N/A | N/A | N/A | 0 |
3 | fused_layout_transform_1 | 1 | 1 | N/A | N/A | N/A | 0 |
---------------------------------------------------------------------------------------------------------------------------------------------
Total trials: 2
Total latency (us): 11.0077
2024-03-20 12:21:43 [INFO] [task_scheduler.cc:260] Task #1 has finished. Remaining task(s): 2
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
---------------------------------------------------------------------------------------------------------------------------------------------
0 | fused_layout_transform | 1 | 1 | 0.0001 | 11.0077 | 11.0077 | 2 | Y
1 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu | 748800 | 1 | N/A | N/A | N/A | 0 | Y
2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 | 1 | N/A | N/A | N/A | 0 |
3 | fused_layout_transform_1 | 1 | 1 | N/A | N/A | N/A | 0 |
---------------------------------------------------------------------------------------------------------------------------------------------
Total trials: 2
Total latency (us): 11.0077
2024-03-20 12:21:43 [INFO] [task_scheduler.cc:260] Task #2 has finished. Remaining task(s): 1
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
---------------------------------------------------------------------------------------------------------------------------------------------
0 | fused_layout_transform | 1 | 1 | 0.0001 | 11.0077 | 11.0077 | 2 | Y
1 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu | 748800 | 1 | N/A | N/A | N/A | 0 | Y
2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 | 1 | N/A | N/A | N/A | 0 | Y
3 | fused_layout_transform_1 | 1 | 1 | N/A | N/A | N/A | 0 |
---------------------------------------------------------------------------------------------------------------------------------------------
Total trials: 2
Total latency (us): 11.0077
2024-03-20 12:21:43 [INFO] [task_scheduler.cc:260] Task #3 has finished. Remaining task(s): 0
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318]
ID | Name | FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
---------------------------------------------------------------------------------------------------------------------------------------------
0 | fused_layout_transform | 1 | 1 | 0.0001 | 11.0077 | 11.0077 | 2 | Y
1 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu | 748800 | 1 | N/A | N/A | N/A | 0 | Y
2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 | 1 | N/A | N/A | N/A | 0 | Y
3 | fused_layout_transform_1 | 1 | 1 | N/A | N/A | N/A | 0 | Y
---------------------------------------------------------------------------------------------------------------------------------------------
Total trials: 2
Total latency (us): 11.0077
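The FLOP column for the two convolution tasks in the tables above can be reproduced by hand. The sketch below assumes 2 FLOPs per multiply-accumulate, plus one FLOP per output element for the bias add and one for the ReLU. That counting convention is inferred from the numbers in the log, not taken from TVM's documentation, and the helper name `conv_bias_relu_flops` is hypothetical. The output sizes 6 and 2 follow from applying two unpadded 5x5 convolutions to a 10x10 input.

```python
def conv_bias_relu_flops(batch: int, in_ch: int, out_ch: int, kernel: int, out_size: int) -> int:
    """Estimated FLOPs of a fused conv2d + bias add + ReLU (assumed convention)."""
    out_elems = batch * out_ch * out_size * out_size
    macs_per_elem = in_ch * kernel * kernel
    # 2 FLOPs per multiply-accumulate, +1 for the bias add, +1 for the ReLU.
    return out_elems * (2 * macs_per_elem + 2)


conv1_flops = conv_bias_relu_flops(batch=20, in_ch=1, out_ch=20, kernel=5, out_size=6)
conv2_flops = conv_bias_relu_flops(batch=20, in_ch=20, out_ch=20, kernel=5, out_size=2)

print(conv1_flops, conv2_flops)  # 748800 1603200
```

Both values match the FLOP entries logged for fused_nn_contrib_conv2d_NCHWc_add_nn_relu and fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1.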
[12:21:44] /media/pc/data/lxw/ai/tvm/src/relay/backend/te_compiler_cache.cc:679: Warning: Cannot find workload: fused_nn_contrib_conv2d_NCHWc_add_nn_relu
[12:21:44] /media/pc/data/lxw/ai/tvm/src/relay/backend/te_compiler_cache.cc:679: Warning: Cannot find workload: fused_nn_contrib_conv2d_NCHWc_add_nn_relu
[12:21:44] /media/pc/data/lxw/ai/tvm/src/relay/backend/te_compiler_cache.cc:679: Warning: Cannot find workload: fused_layout_transform
[12:21:44] /media/pc/data/lxw/ai/tvm/src/relay/backend/te_compiler_cache.cc:679: Warning: Cannot find workload: tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu
[12:21:44] /media/pc/data/lxw/ai/tvm/src/relay/backend/te_compiler_cache.cc:679: Warning: Cannot find workload: tvmgen_default_fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1
[12:21:44] /media/pc/data/lxw/ai/tvm/src/relay/backend/te_compiler_cache.cc:679: Warning: Cannot find workload: tvmgen_default_fused_layout_transform_1
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[6], line 3
1 simple_model = SimpleModel()
2 example_input = torch.randn(20, 1, 10, 10)
----> 3 model_optimized_by_tvm = optimize_torch(simple_model, example_input, max_trials_global=2)
File /media/pc/data/lxw/ai/tvm/python/tvm/contrib/torch/optimize_torch.py:166, in optimize_torch(func, example_inputs, max_trials_global, work_dir, target, max_trials_per_task, num_trials_per_iter, builder, runner, database, cost_model, measure_callbacks, task_scheduler, space, strategy, seed)
164 save_runtime_mod = get_global_func("tvmtorch.save_runtime_mod", allow_missing=True)
165 if save_runtime_mod is None:
--> 166 raise ValueError('optimize_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake')
167 save_runtime_mod(executor_factory.module)
169 return GraphExecutorFactoryWrapper(torch.classes.tvm_torch.GraphExecutorFactoryWrapper())
ValueError: optimize_torch requires the flag /"USE_PT_TVMDSOOP/" set in config.cmake
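The traceback shows that optimize_torch looks up the global function `tvmtorch.save_runtime_mod` and raises a ValueError when it is missing, which only happens after tuning has already run. A minimal sketch of checking for it up front follows. The helper name `tvm_torch_bridge_available` is hypothetical, the lookup mirrors the call visible in the traceback, and the assumption that importing `tvm.contrib.torch` is what loads the bridge library is based on the import-time warnings shown earlier.

```python
def tvm_torch_bridge_available() -> bool:
    """Return True if this TVM build exposes the PyTorch bridge used by optimize_torch."""
    try:
        import tvm
        # Assumption: importing this package attempts to load the bridge
        # library and only warns (does not raise) if the library is missing.
        import tvm.contrib.torch  # noqa: F401
    except ImportError:
        return False  # TVM (or PyTorch) is not installed at all
    func = tvm.get_global_func("tvmtorch.save_runtime_mod", allow_missing=True)
    return func is not None


if not tvm_torch_bridge_available():
    print("PyTorch bridge not found: set USE_PT_TVMDSOOP in config.cmake and rebuild TVM.")
```

Running this check before calling optimize_torch avoids spending tuning time on a build that cannot return the optimized module.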
Save/Load the Module#
We can save and load our optimized module just like a standard nn.Module.
Let us run our optimized module.
ret1 = model_optimized_by_tvm(example_input)
torch.save(model_optimized_by_tvm, "model_optimized.pt")
model_loaded = torch.load("model_optimized.pt")
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
Cell In[7], line 1
----> 1 ret1 = model_optimized_by_tvm(example_input)
3 torch.save(model_optimized_by_tvm, "model_optimized.pt")
4 model_loaded = torch.load("model_optimized.pt")
NameError: name 'model_optimized_by_tvm' is not defined
# We load the module and run it again.
ret2 = model_loaded(example_input)
# We will show two results:
# (1) saving and loading are safe: the result of the model after the
#     save/load round trip is the same as the original one;
# (2) the optimized model returns the same result as the original PyTorch model.
import tvm.testing  # provides assert_allclose; not imported earlier in this tutorial

ret3 = simple_model(example_input)
tvm.testing.assert_allclose(ret1.detach().numpy(), ret2.detach().numpy(), atol=1e-5, rtol=1e-5)
tvm.testing.assert_allclose(ret1.detach().numpy(), ret3.detach().numpy(), atol=1e-5, rtol=1e-5)
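The two assertions compare tensors element-wise within a tolerance. Assuming the check follows NumPy's `assert_allclose` criterion, |actual - desired| <= atol + rtol * |desired|, the idea can be sketched in plain Python as follows. `allclose_lists` is a hypothetical helper for illustration and is not the TVM implementation.

```python
def allclose_lists(actual, desired, rtol=1e-5, atol=1e-5) -> bool:
    """True if every element of `actual` is within tolerance of `desired`."""
    return len(actual) == len(desired) and all(
        abs(a - d) <= atol + rtol * abs(d) for a, d in zip(actual, desired)
    )


reference = [0.5, 1.0, 2.0]
tiny_drift = [0.5 + 1e-6, 1.0 - 1e-6, 2.0 + 1e-6]  # within tolerance
large_drift = [0.5, 1.0, 2.1]                      # last element is off by 0.1

print(allclose_lists(tiny_drift, reference))   # True
print(allclose_lists(large_drift, reference))  # False
```

With atol and rtol both set to 1e-5, small floating-point differences introduced by a different kernel schedule pass, while a genuinely different result does not.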
######################################################################
# Optimize resnet18
# -----------------
# In the following, we will show that our approach is able to
# accelerate common models, such as resnet18.
# We will tune our model for the GPU.
target_cuda = "nvidia/geforce-rtx-3070"
# For PyTorch users, the code could be written as usual, except for
# applying "optimize_torch" function on the resnet18 model.
resnet18_tvm = optimize_torch(
    resnet18().cuda().eval(), [torch.rand(1, 3, 224, 224).cuda()], target=target_cuda
)
# TorchScript also provides a built-in "optimize_for_inference" function to accelerate the inference.
resnet18_torch = torch.jit.optimize_for_inference(torch.jit.script(resnet18().cuda().eval()))
######################################################################
# Compare the performance between two approaches
# ----------------------------------------------
results = []
for i in range(5):
    test_input = torch.rand(1, 3, 224, 224).cuda()
    sub_label = f"[test {i}]"
    results.append(
        benchmark.Timer(
            stmt="resnet18_tvm(test_input)",
            setup="from __main__ import resnet18_tvm",
            globals={"test_input": test_input},
            sub_label=sub_label,
            description="tuning by meta",
        ).blocked_autorange()
    )
    results.append(
        benchmark.Timer(
            stmt="resnet18_torch(test_input)",
            setup="from __main__ import resnet18_torch",
            globals={"test_input": test_input},
            sub_label=sub_label,
            description="tuning by jit",
        ).blocked_autorange()
    )

compare = benchmark.Compare(results)
compare.print()
# In the author's environment, the average inference time of `resnet18_tvm` is 620.0 us,
# while the average inference time of `resnet18_torch` is 980.0 us (PyTorch version 1.11.0),
# i.e. roughly 37% lower latency, or about a 1.58x speedup.
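The gap between the two reported latencies can be stated either as a latency reduction or as a speedup factor, and the two are easy to confuse. A short calculation using the figures quoted above (620.0 us and 980.0 us) gives both. These inputs are the author's rounded measurements, so the results are approximate.

```python
tvm_latency_us = 620.0    # resnet18_tvm, as reported above
torch_latency_us = 980.0  # resnet18_torch, as reported above

# Fraction of the baseline latency that was removed.
latency_reduction = 1.0 - tvm_latency_us / torch_latency_us
# How many times faster the TVM-tuned model is than the baseline.
speedup_factor = torch_latency_us / tvm_latency_us

print(f"latency reduction: {latency_reduction:.1%}")  # latency reduction: 36.7%
print(f"speedup factor: {speedup_factor:.2f}x")       # speedup factor: 1.58x
```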