Compile PyTorch Models#

Author: Yaoda Zhou

This article is a tutorial on optimizing PyTorch models with the `optimize_torch` decorator. To follow it, PyTorch and TorchVision need to be installed:

%%shell
pip install torch
pip install torchvision
import set_env
import torch
import torch.nn as nn
import torch.nn.functional as F

# Import library for profiling
import torch.utils.benchmark as benchmark
from torchvision.models import resnet18

# Import `optimize_torch` function
from tvm.contrib.torch import optimize_torch

# Needed later for `tvm.testing.assert_allclose`
import tvm.testing

Build a simple model with PyTorch#

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
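
As an optional sanity check, we can run the untuned model once and confirm the output shape: each 5x5 convolution trims 4 pixels per spatial side, so a 10x10 input shrinks to 6x6 and then to 2x2. The snippet below is a small illustrative sketch; `demo_model` and `demo_out` are names introduced here for this check only.

# Optional shape check: 10x10 -> 6x6 -> 2x2 after two 5x5 convolutions,
# with 20 output channels and batch size 20.
demo_model = SimpleModel()
demo_out = demo_model(torch.randn(20, 1, 10, 10))
print(demo_out.shape)  # torch.Size([20, 20, 2, 2])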

Optimize SimpleModel with TVM MetaSchedule#

We provide the `optimize_torch` function, whose usage is similar to that of `torch.jit.trace`. The user supplies the PyTorch model to be optimized together with example inputs. The PyTorch module is then tuned by TVM for the target hardware; if no extra information is provided, the model is tuned for CPU.

simple_model = SimpleModel()
example_input = torch.randn(20, 1, 10, 10)
model_optimized_by_tvm = optimize_torch(simple_model, example_input, max_trials_global=2)
2024-03-20 12:21:19 [INFO] Logging directory: /tmp/tmpl0j3jqte/logs
2024-03-20 12:21:36 [INFO] LocalBuilder: max_workers = 24
2024-03-20 12:21:38 [INFO] LocalRunner: max_workers = 1
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #0: "fused_layout_transform"
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #1: "fused_nn_contrib_conv2d_NCHWc_add_nn_relu"
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #2: "fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1"
2024-03-20 12:21:39 [INFO] [task_scheduler.cc:159] Initializing Task #3: "fused_layout_transform_1"
2024-03-20 12:21:40 [INFO] [task_scheduler.cc:180] TaskScheduler picks Task #0: "fused_layout_transform"
2024-03-20 12:21:40 [INFO] [task_scheduler.cc:193] Sending 2 sample(s) to builder
2024-03-20 12:21:42 [INFO] [task_scheduler.cc:195] Sending 2 sample(s) to runner
2024-03-20 12:21:43 [DEBUG] XGB stopped. Best iteration: [54] tr-p-rmse:0.01022  tr-a-peak@32:1.00000  tr-rmse:0.01023  tr-rmse:0.01023
2024-03-20 12:21:43 [INFO] [task_scheduler.cc:237] [Updated] Task #0: "fused_layout_transform"
2024-03-20 12:21:43 [INFO] [task_scheduler.cc:260] Task #3 has finished. Remaining task(s): 0
2024-03-20 12:21:43 [DEBUG] [task_scheduler.cc:318]
 ID |                                        Name |    FLOP | Weight | Speed (GFLOPS) | Latency (us) | Weighted Latency (us) | Trials | Done
---------------------------------------------------------------------------------------------------------------------------------------------
  0 |                      fused_layout_transform |       1 |      1 |         0.0001 |      11.0077 |               11.0077 |      2 |    Y
  1 |   fused_nn_contrib_conv2d_NCHWc_add_nn_relu |  748800 |      1 |            N/A |          N/A |                   N/A |      0 |    Y
  2 | fused_nn_contrib_conv2d_NCHWc_add_nn_relu_1 | 1603200 |      1 |            N/A |          N/A |                   N/A |      0 |    Y
  3 |                    fused_layout_transform_1 |       1 |      1 |            N/A |          N/A |                   N/A |      0 |    Y
---------------------------------------------------------------------------------------------------------------------------------------------
Total trials: 2
Total latency (us): 11.0077
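Besides `max_trials_global`, `optimize_torch` accepts further optional parameters such as `work_dir` and `target`. Below is a minimal sketch of a variant call; the specific values (the tuning directory, trial count, and LLVM CPU target string) are illustrative assumptions, not settings from the run above.

# Hypothetical variant of the call above. `work_dir` persists tuning
# records so they can be reused; `target` pins the hardware explicitly.
model_tuned_explicitly = optimize_torch(
    simple_model,
    example_input,
    max_trials_global=64,          # more trials generally find faster schedules
    work_dir="./tuning_records",   # assumed writable directory
    target="llvm -num-cores 4",    # assumed CPU target; adjust to your machine
)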

Save/Load module#

We can save and load our optimized module just like a standard `nn.Module`.

Let's run our optimized module.

ret1 = model_optimized_by_tvm(example_input)

torch.save(model_optimized_by_tvm, "model_optimized.pt")
model_loaded = torch.load("model_optimized.pt")
# We load the module and run it again.
ret2 = model_loaded(example_input)

# We verify two things:
# (1) the module can be saved and loaded safely: the result after the
#     save/load round trip is the same as before;
# (2) the optimized model returns the same result as the original
#     PyTorch model (up to floating-point tolerance, since TVM-generated
#     kernels may accumulate rounding differently than PyTorch kernels).

ret3 = simple_model(example_input)
tvm.testing.assert_allclose(ret1.detach().numpy(), ret2.detach().numpy(), atol=1e-5, rtol=1e-5)
tvm.testing.assert_allclose(ret1.detach().numpy(), ret3.detach().numpy(), atol=1e-5, rtol=1e-5)

Optimize resnet18#

In the following, we show that our approach is able to accelerate common models, such as resnet18.

# We will tune our model for the GPU.
target_cuda = "nvidia/geforce-rtx-3070"
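
Before tuning, it can be helpful to check what this target tag expands to. The sketch below uses TVM's target API (`tvm.target.Target` accepts such tags); it is optional and assumes a TVM build with this tag registered.

import tvm

# Print the full target string the tag resolves to, e.g. the CUDA
# architecture and thread limits for this GPU.
print(tvm.target.Target(target_cuda))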

For PyTorch users, the code can be written as usual, except that the `optimize_torch` function is applied to the resnet18 model.

resnet18_tvm = optimize_torch(
    resnet18().cuda().eval(), [torch.rand(1, 3, 224, 224).cuda()], target=target_cuda
)

TorchScript also provides a built-in `optimize_for_inference` function to accelerate inference.
resnet18_torch = torch.jit.optimize_for_inference(torch.jit.script(resnet18().cuda().eval()))
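
Before timing, it is common practice to warm both models up so that one-time costs (CUDA context creation, JIT passes) are excluded from the measurement; `torch.utils.benchmark.Timer` already handles CUDA synchronization for us. The loop below is an optional sketch, assuming both models are already on the GPU.

# Optional warm-up: exclude one-time setup costs from the benchmark below.
with torch.no_grad():
    warmup_input = torch.rand(1, 3, 224, 224).cuda()
    for _ in range(10):
        resnet18_tvm(warmup_input)
        resnet18_torch(warmup_input)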


Compare the performance between two approaches#

results = []
for i in range(5):
    test_input = torch.rand(1, 3, 224, 224).cuda()
    sub_label = f"[test {i}]"
    results.append(
        benchmark.Timer(
            stmt="resnet18_tvm(test_input)",
            setup="from __main__ import resnet18_tvm",
            globals={"test_input": test_input},
            sub_label=sub_label,
            description="tuning by meta",
        ).blocked_autorange()
    )
    results.append(
        benchmark.Timer(
            stmt="resnet18_torch(test_input)",
            setup="from __main__ import resnet18_torch",
            globals={"test_input": test_input},
            sub_label=sub_label,
            description="tuning by jit",
        ).blocked_autorange()
    )

compare = benchmark.Compare(results)
compare.print()
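
The printed table summarizes the measurements. If you prefer to postprocess the numbers programmatically, the `Measurement` objects returned by `blocked_autorange` expose timing statistics directly; the small sketch below reads the median latency (stored in seconds) from each measurement.

# Report the median latency of each measurement in microseconds.
for m in results:
    print(m.task_spec.sub_label, m.task_spec.description, f"{m.median * 1e6:.1f} us")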

In the author's environment, the average inference time of `resnet18_tvm` is 620.0 us, while the average inference time of `resnet18_torch` is 980.0 us (PyTorch version 1.11.0), showing a speedup of around 38%.