End-to-End Model Optimization Example#

Reference: e2e_opt_model

This tutorial demonstrates how to optimize a machine learning model with Apache TVM. We take a pre-trained ResNet-18 model from PyTorch and optimize it end to end using TVM's Relax API. Note that the default end-to-end optimization may not be suitable for complex models.

Preparation#

First, prepare the model and the input information. We use a pre-trained ResNet-18 model from PyTorch.

from pathlib import Path
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
import os
import numpy as np
import torch
from torchvision.models.resnet import ResNet18_Weights, resnet18

torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)

Overview of the Overall Flow#

Figure: the overall flow of Apache TVM (https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg)

The overall flow consists of the following steps:

  • Construct or Import a Model: build a neural network model, or import a pre-trained model from another framework (e.g., PyTorch, ONNX), and create the TVM IRModule, which contains all the information needed for compilation, including high-level Relax functions for the computational graph and low-level TensorIR functions for the tensor programs.

  • Perform Composable Optimizations: apply a series of optimization transformations, such as graph optimizations, tensor program optimizations, and library dispatching.

  • Build and Universal Deployment: build the optimized model into a module deployable on the universal runtime, and execute it on different devices, such as CPUs, GPUs, or other accelerators.

Convert the Model to an IRModule#

Convert the model to an IRModule using the Relax frontend for PyTorch so that it can be optimized further. Besides the model itself, we also need to provide the input shape and data type.

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
from torch.export import export

# Give an example argument to torch.export
example_args = (torch.randn(1, 3, 224, 224, dtype=torch.float32),)

# Skip running in CI environment
IS_IN_CI = os.getenv("CI", "") == "true"

if not IS_IN_CI:
    # Convert the model to IRModule
    with torch.no_grad():
        exported_program = export(torch_model, example_args)
        mod = from_exported_program(exported_program, keep_params_as_input=True)

    mod, params = relax.frontend.detach_params(mod)
    mod.show()

# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 3, 224, 224), dtype="float32"), p_conv1_weight: R.Tensor((64, 3, 7, 7), dtype="float32"), p_bn1_weight: R.Tensor((64,), dtype="float32"), p_bn1_bias: R.Tensor((64,), dtype="float32"), p_layer1_0_conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_layer1_0_bn1_weight: R.Tensor((64,), dtype="float32"), p_layer1_0_bn1_bias: R.Tensor((64,), dtype="float32"), p_layer1_0_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_layer1_0_bn2_weight: R.Tensor((64,), dtype="float32"), p_layer1_0_bn2_bias: R.Tensor((64,), dtype="float32"), p_layer1_1_conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_layer1_1_bn1_weight: R.Tensor((64,), dtype="float32"), p_layer1_1_bn1_bias: R.Tensor((64,), dtype="float32"), p_layer1_1_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), p_layer1_1_bn2_weight: R.Tensor((64,), dtype="float32"), p_layer1_1_bn2_bias: R.Tensor((64,), dtype="float32"), p_layer2_0_conv1_weight: R.Tensor((128, 64, 3, 3), dtype="float32"), p_layer2_0_bn1_weight: R.Tensor((128,), dtype="float32"), p_layer2_0_bn1_bias: R.Tensor((128,), dtype="float32"), p_layer2_0_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), p_layer2_0_bn2_weight: R.Tensor((128,), dtype="float32"), p_layer2_0_bn2_bias: R.Tensor((128,), dtype="float32"), p_layer2_0_downsample_0_weight: R.Tensor((128, 64, 1, 1), dtype="float32"), p_layer2_0_downsample_1_weight: R.Tensor((128,), dtype="float32"), p_layer2_0_downsample_1_bias: R.Tensor((128,), dtype="float32"), p_layer2_1_conv1_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), p_layer2_1_bn1_weight: R.Tensor((128,), dtype="float32"), p_layer2_1_bn1_bias: R.Tensor((128,), dtype="float32"), p_layer2_1_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), p_layer2_1_bn2_weight: R.Tensor((128,), dtype="float32"), p_layer2_1_bn2_bias: R.Tensor((128,), dtype="float32"), p_layer3_0_conv1_weight: R.Tensor((256, 128, 3, 3), dtype="float32"), p_layer3_0_bn1_weight: R.Tensor((256,), dtype="float32"), p_layer3_0_bn1_bias: R.Tensor((256,), dtype="float32"), p_layer3_0_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), p_layer3_0_bn2_weight: R.Tensor((256,), dtype="float32"), p_layer3_0_bn2_bias: R.Tensor((256,), dtype="float32"), p_layer3_0_downsample_0_weight: R.Tensor((256, 128, 1, 1), dtype="float32"), p_layer3_0_downsample_1_weight: R.Tensor((256,), dtype="float32"), p_layer3_0_downsample_1_bias: R.Tensor((256,), dtype="float32"), p_layer3_1_conv1_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), p_layer3_1_bn1_weight: R.Tensor((256,), dtype="float32"), p_layer3_1_bn1_bias: R.Tensor((256,), dtype="float32"), p_layer3_1_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), p_layer3_1_bn2_weight: R.Tensor((256,), dtype="float32"), p_layer3_1_bn2_bias: R.Tensor((256,), dtype="float32"), p_layer4_0_conv1_weight: R.Tensor((512, 256, 3, 3), dtype="float32"), p_layer4_0_bn1_weight: R.Tensor((512,), dtype="float32"), p_layer4_0_bn1_bias: R.Tensor((512,), dtype="float32"), p_layer4_0_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), p_layer4_0_bn2_weight: R.Tensor((512,), dtype="float32"), p_layer4_0_bn2_bias: R.Tensor((512,), dtype="float32"), p_layer4_0_downsample_0_weight: R.Tensor((512, 256, 1, 1), dtype="float32"), p_layer4_0_downsample_1_weight: R.Tensor((512,), dtype="float32"), p_layer4_0_downsample_1_bias: R.Tensor((512,), dtype="float32"), p_layer4_1_conv1_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), p_layer4_1_bn1_weight: R.Tensor((512,), dtype="float32"), p_layer4_1_bn1_bias: 
R.Tensor((512,), dtype="float32"), p_layer4_1_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), p_layer4_1_bn2_weight: R.Tensor((512,), dtype="float32"), p_layer4_1_bn2_bias: R.Tensor((512,), dtype="float32"), p_fc_weight: R.Tensor((1000, 512), dtype="float32"), p_fc_bias: R.Tensor((1000,), dtype="float32")) -> R.Tuple(R.Tensor((1, 1000), dtype="float32")):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.conv2d(x, p_conv1_weight, strides=[2, 2], padding=[3, 3, 3, 3], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv1: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv2: R.Tuple(R.Tensor((1, 64, 112, 112), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv, p_bn1_weight, p_bn1_bias, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv3: R.Tensor((1, 64, 112, 112), dtype="float32") = lv2[0]
            lv4: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.relu(lv3)
            lv5: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.max_pool2d(lv4, pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
            lv6: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv5, p_layer1_0_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv7: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv8: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv6, p_layer1_0_bn1_weight, p_layer1_0_bn1_bias, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv9: R.Tensor((1, 64, 56, 56), dtype="float32") = lv8[0]
            lv10: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv9)
            lv11: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv10, p_layer1_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv12: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv13: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv11, p_layer1_0_bn2_weight, p_layer1_0_bn2_bias, metadata["relax.expr.Constant"][4], metadata["relax.expr.Constant"][5], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv14: R.Tensor((1, 64, 56, 56), dtype="float32") = lv13[0]
            lv15: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv14, lv5)
            lv16: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv15)
            lv17: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv16, p_layer1_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv18: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv19: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv17, p_layer1_1_bn1_weight, p_layer1_1_bn1_bias, metadata["relax.expr.Constant"][6], metadata["relax.expr.Constant"][7], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv20: R.Tensor((1, 64, 56, 56), dtype="float32") = lv19[0]
            lv21: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv20)
            lv22: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv21, p_layer1_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv23: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv24: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv22, p_layer1_1_bn2_weight, p_layer1_1_bn2_bias, metadata["relax.expr.Constant"][8], metadata["relax.expr.Constant"][9], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv25: R.Tensor((1, 64, 56, 56), dtype="float32") = lv24[0]
            lv26: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv25, lv16)
            lv27: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv26)
            lv28: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv27, p_layer2_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv29: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv30: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv28, p_layer2_0_bn1_weight, p_layer2_0_bn1_bias, metadata["relax.expr.Constant"][10], metadata["relax.expr.Constant"][11], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv31: R.Tensor((1, 128, 28, 28), dtype="float32") = lv30[0]
            lv32: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv31)
            lv33: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv32, p_layer2_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv34: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv35: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv33, p_layer2_0_bn2_weight, p_layer2_0_bn2_bias, metadata["relax.expr.Constant"][12], metadata["relax.expr.Constant"][13], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv36: R.Tensor((1, 128, 28, 28), dtype="float32") = lv35[0]
            lv37: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv27, p_layer2_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv38: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv39: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv37, p_layer2_0_downsample_1_weight, p_layer2_0_downsample_1_bias, metadata["relax.expr.Constant"][14], metadata["relax.expr.Constant"][15], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv40: R.Tensor((1, 128, 28, 28), dtype="float32") = lv39[0]
            lv41: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv36, lv40)
            lv42: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv41)
            lv43: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv42, p_layer2_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv44: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv45: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv43, p_layer2_1_bn1_weight, p_layer2_1_bn1_bias, metadata["relax.expr.Constant"][16], metadata["relax.expr.Constant"][17], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv46: R.Tensor((1, 128, 28, 28), dtype="float32") = lv45[0]
            lv47: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv46)
            lv48: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv47, p_layer2_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv49: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv50: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv48, p_layer2_1_bn2_weight, p_layer2_1_bn2_bias, metadata["relax.expr.Constant"][18], metadata["relax.expr.Constant"][19], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv51: R.Tensor((1, 128, 28, 28), dtype="float32") = lv50[0]
            lv52: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv51, lv42)
            lv53: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv52)
            lv54: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv53, p_layer3_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv55: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv56: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv54, p_layer3_0_bn1_weight, p_layer3_0_bn1_bias, metadata["relax.expr.Constant"][20], metadata["relax.expr.Constant"][21], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv57: R.Tensor((1, 256, 14, 14), dtype="float32") = lv56[0]
            lv58: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv57)
            lv59: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv58, p_layer3_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv60: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv61: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv59, p_layer3_0_bn2_weight, p_layer3_0_bn2_bias, metadata["relax.expr.Constant"][22], metadata["relax.expr.Constant"][23], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv62: R.Tensor((1, 256, 14, 14), dtype="float32") = lv61[0]
            lv63: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv53, p_layer3_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv64: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv65: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv63, p_layer3_0_downsample_1_weight, p_layer3_0_downsample_1_bias, metadata["relax.expr.Constant"][24], metadata["relax.expr.Constant"][25], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv66: R.Tensor((1, 256, 14, 14), dtype="float32") = lv65[0]
            lv67: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv62, lv66)
            lv68: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv67)
            lv69: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv68, p_layer3_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv70: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv71: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv69, p_layer3_1_bn1_weight, p_layer3_1_bn1_bias, metadata["relax.expr.Constant"][26], metadata["relax.expr.Constant"][27], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv72: R.Tensor((1, 256, 14, 14), dtype="float32") = lv71[0]
            lv73: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv72)
            lv74: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv73, p_layer3_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv75: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv76: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv74, p_layer3_1_bn2_weight, p_layer3_1_bn2_bias, metadata["relax.expr.Constant"][28], metadata["relax.expr.Constant"][29], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv77: R.Tensor((1, 256, 14, 14), dtype="float32") = lv76[0]
            lv78: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv77, lv68)
            lv79: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv78)
            lv80: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv79, p_layer4_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv81: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv82: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv80, p_layer4_0_bn1_weight, p_layer4_0_bn1_bias, metadata["relax.expr.Constant"][30], metadata["relax.expr.Constant"][31], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv83: R.Tensor((1, 512, 7, 7), dtype="float32") = lv82[0]
            lv84: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv83)
            lv85: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv84, p_layer4_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv86: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv87: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv85, p_layer4_0_bn2_weight, p_layer4_0_bn2_bias, metadata["relax.expr.Constant"][32], metadata["relax.expr.Constant"][33], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv88: R.Tensor((1, 512, 7, 7), dtype="float32") = lv87[0]
            lv89: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv79, p_layer4_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv90: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv91: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv89, p_layer4_0_downsample_1_weight, p_layer4_0_downsample_1_bias, metadata["relax.expr.Constant"][34], metadata["relax.expr.Constant"][35], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv92: R.Tensor((1, 512, 7, 7), dtype="float32") = lv91[0]
            lv93: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv88, lv92)
            lv94: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv93)
            lv95: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv94, p_layer4_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv96: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv97: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv95, p_layer4_1_bn1_weight, p_layer4_1_bn1_bias, metadata["relax.expr.Constant"][36], metadata["relax.expr.Constant"][37], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv98: R.Tensor((1, 512, 7, 7), dtype="float32") = lv97[0]
            lv99: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv98)
            lv100: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv99, p_layer4_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv101: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
            lv102: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv100, p_layer4_1_bn2_weight, p_layer4_1_bn2_bias, metadata["relax.expr.Constant"][38], metadata["relax.expr.Constant"][39], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001, training=False)
            lv103: R.Tensor((1, 512, 7, 7), dtype="float32") = lv102[0]
            lv104: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv103, lv94)
            lv105: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv104)
            lv106: R.Tensor((1, 512, 1, 1), dtype="float32") = R.nn.adaptive_avg_pool2d(lv105, output_size=[1, 1], layout="NCHW", out_layout="NCHW")
            lv107: R.Tensor((1, 512), dtype="float32") = R.reshape(lv106, R.shape([1, 512]))
            lv108: R.Tensor((512, 1000), dtype="float32") = R.permute_dims(p_fc_weight, axes=None)
            lv109: R.Tensor((1, 1000), dtype="float32") = R.matmul(lv107, lv108, out_dtype="float32")
            lv110: R.Tensor((1, 1000), dtype="float32") = R.add(lv109, p_fc_bias)
            gv: R.Tuple(R.Tensor((1, 1000), dtype="float32")) = (lv110,)
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
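
Because the model was imported with keep_params_as_input=True, relax.frontend.detach_params separates the weights from the IRModule: params maps each function name to the list of weight tensors that must be passed alongside the real input at run time (see the deployment step below). A quick way to inspect it:

if not IS_IN_CI:
    # `params` is a dict from function name to its detached weight tensors
    print(list(params.keys()))  # e.g. ['main']
    print(len(params["main"]))  # number of detached ResNet-18 weight tensors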

IRModule Optimization#

Apache TVM provides a flexible way to optimize the IRModule. All transformations centered around IRModule optimization can be composed with existing pipelines. Note that individual transformations can be combined into one optimization pipeline via tvm.ir.transform.Sequential.
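
As an illustration, the sketch below composes a few standard Relax passes into a single pipeline with tvm.ir.transform.Sequential; the pass selection here is only an example, not the pipeline this tutorial applies.

# Illustrative only: compose individual Relax passes into one pipeline.
seq = tvm.ir.transform.Sequential(
    [
        relax.transform.LegalizeOps(),   # lower high-level Relax ops to TensorIR
        relax.transform.FoldConstant(),  # fold compile-time constant expressions
        relax.transform.FuseOps(),       # fuse operators following fusion patterns
    ]
)

if not IS_IN_CI:
    demo_mod = seq(mod)  # apply the whole pipeline; `mod` itself is left unchanged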

In this tutorial, we focus on end-to-end optimization of the model via auto-tuning. We leverage MetaSchedule to tune the model and store the tuning logs in a database; the database is then applied to the model to obtain the best performance.


TOTAL_TRIALS = 8000  # Change to 20000 for better performance if needed
target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")  # Change to your target device
work_dir = "tuning_logs"

if not IS_IN_CI:
    mod = relax.get_pipeline("static_shape_tuning", target=target, total_trials=TOTAL_TRIALS)(mod)

    # Only show the main function
    mod["main"].show()

Build and Deploy#

Finally, we build the optimized model and deploy it to the target device.

if not IS_IN_CI:
    ex = tvm.compile(mod, target="cuda")
    dev = tvm.device("cuda", 0)
    vm = relax.VirtualMachine(ex, dev)
    # Need to allocate data and params on GPU device
    gpu_data = tvm.runtime.tensor(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
    gpu_params = [tvm.runtime.tensor(p, dev) for p in params["main"]]
    gpu_out = vm["main"](gpu_data, *gpu_params).numpy()

    print(gpu_out.shape)
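
To gauge the effect of tuning, you can time the compiled function. A minimal sketch, assuming the Relax VM's time_evaluator helper is available in your TVM build:

if not IS_IN_CI:
    # Measure `main` on the GPU; `number`/`repeat` control how many runs are sampled
    timer = vm.time_evaluator("main", dev, number=10, repeat=3)
    print(timer(gpu_data, *gpu_params))  # mean/median/std of the measured latencies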