End-to-End Model Optimization Example

Reference: e2e_opt_model

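This tutorial shows how to optimize a machine learning model with Apache TVM. It takes a pretrained ResNet-18 model from PyTorch and optimizes it end to end with TVM's Relax API. Note that the default end-to-end optimization may not be suitable for complex models.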

Preparation

First, prepare the model and the input information. We use a pretrained ResNet-18 model from PyTorch.

import set_env
from pathlib import Path

temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
import os
import numpy as np
import torch
from torch import fx
from torchvision.models.resnet import ResNet18_Weights, resnet18

torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)

Overview of the Overall Flow

https://raw.githubusercontent.com/tlc-pack/web-data/main/images/design/tvm_overall_flow.svg

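The overall flow consists of the following steps:

  • Construct or import a model: build a neural network model, or import a pretrained model from another framework (such as PyTorch or ONNX), and create a TVM IRModule. The IRModule contains all the information needed for compilation, including high-level Relax functions for the computational graph and low-level TensorIR functions for tensor programs.

  • Perform composable optimizations: apply a series of optimization transformations, such as graph optimizations, tensor program optimizations, and library dispatching.

  • Build and universal deployment: build the optimized model into a module that can be deployed on the universal runtime and executed on different devices, such as CPUs, GPUs, or other accelerators.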

Converting the Model to an IRModule

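Use the Relax frontend for PyTorch to convert the model into an IRModule for further optimization. In addition to the model, we also need to provide the shape and data type of the input.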
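The shape and data type together determine how large the input buffer is. As a small sanity check for the input used in this tutorial (shape (1, 3, 224, 224), dtype float32), the sketch below computes the element count and the size in bytes. It is plain Python, not part of the TVM API; `DTYPE_BYTES` and `tensor_nbytes` are helper names introduced here for illustration only.

```python
import math

# Bytes per element for a few common dtypes (illustrative helper table).
DTYPE_BYTES = {"float16": 2, "float32": 4, "float64": 8, "int8": 1, "int32": 4}


def tensor_nbytes(shape, dtype):
    """Return (number of elements, size in bytes) for a dense tensor."""
    numel = math.prod(shape)
    return numel, numel * DTYPE_BYTES[dtype]


# One RGB image of 224x224 pixels in NCHW layout, as used in this tutorial.
numel, nbytes = tensor_nbytes((1, 3, 224, 224), "float32")
print(numel, nbytes)  # 150528 602112
```

If a different batch size or precision is used, the same helper shows how the buffer that must be allocated on the device scales.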

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)

# Give the input shape and data type
input_info = [((1, 3, 224, 224), "float32")]

# Convert the model to IRModule
with torch.no_grad():
    torch_fx_model = fx.symbolic_trace(torch_model)
    mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)

mod, params = relax.frontend.detach_params(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 3, 224, 224), dtype="float32"), bn1_bias: R.Tensor((64,), dtype="float32"), bn1_weight: R.Tensor((64,), dtype="float32"), conv1_weight: R.Tensor((64, 3, 7, 7), dtype="float32"), fc_bias: R.Tensor((1000,), dtype="float32"), fc_weight: R.Tensor((1000, 512), dtype="float32"), layer1_0_bn1_bias: R.Tensor((64,), dtype="float32"), layer1_0_bn1_weight: R.Tensor((64,), dtype="float32"), layer1_0_bn2_bias: R.Tensor((64,), dtype="float32"), layer1_0_bn2_weight: R.Tensor((64,), dtype="float32"), layer1_0_conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer1_0_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer1_1_bn1_bias: R.Tensor((64,), dtype="float32"), layer1_1_bn1_weight: R.Tensor((64,), dtype="float32"), layer1_1_bn2_bias: R.Tensor((64,), dtype="float32"), layer1_1_bn2_weight: R.Tensor((64,), dtype="float32"), layer1_1_conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer1_1_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer2_0_bn1_bias: R.Tensor((128,), dtype="float32"), layer2_0_bn1_weight: R.Tensor((128,), dtype="float32"), layer2_0_bn2_bias: R.Tensor((128,), dtype="float32"), layer2_0_bn2_weight: R.Tensor((128,), dtype="float32"), layer2_0_conv1_weight: R.Tensor((128, 64, 3, 3), dtype="float32"), layer2_0_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), layer2_0_downsample_0_weight: R.Tensor((128, 64, 1, 1), dtype="float32"), layer2_0_downsample_1_bias: R.Tensor((128,), dtype="float32"), layer2_0_downsample_1_weight: R.Tensor((128,), dtype="float32"), layer2_1_bn1_bias: R.Tensor((128,), dtype="float32"), layer2_1_bn1_weight: R.Tensor((128,), dtype="float32"), layer2_1_bn2_bias: R.Tensor((128,), dtype="float32"), layer2_1_bn2_weight: R.Tensor((128,), dtype="float32"), layer2_1_conv1_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), layer2_1_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), layer3_0_bn1_bias: R.Tensor((256,), dtype="float32"), layer3_0_bn1_weight: 
R.Tensor((256,), dtype="float32"), layer3_0_bn2_bias: R.Tensor((256,), dtype="float32"), layer3_0_bn2_weight: R.Tensor((256,), dtype="float32"), layer3_0_conv1_weight: R.Tensor((256, 128, 3, 3), dtype="float32"), layer3_0_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), layer3_0_downsample_0_weight: R.Tensor((256, 128, 1, 1), dtype="float32"), layer3_0_downsample_1_bias: R.Tensor((256,), dtype="float32"), layer3_0_downsample_1_weight: R.Tensor((256,), dtype="float32"), layer3_1_bn1_bias: R.Tensor((256,), dtype="float32"), layer3_1_bn1_weight: R.Tensor((256,), dtype="float32"), layer3_1_bn2_bias: R.Tensor((256,), dtype="float32"), layer3_1_bn2_weight: R.Tensor((256,), dtype="float32"), layer3_1_conv1_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), layer3_1_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), layer4_0_bn1_bias: R.Tensor((512,), dtype="float32"), layer4_0_bn1_weight: R.Tensor((512,), dtype="float32"), layer4_0_bn2_bias: R.Tensor((512,), dtype="float32"), layer4_0_bn2_weight: R.Tensor((512,), dtype="float32"), layer4_0_conv1_weight: R.Tensor((512, 256, 3, 3), dtype="float32"), layer4_0_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), layer4_0_downsample_0_weight: R.Tensor((512, 256, 1, 1), dtype="float32"), layer4_0_downsample_1_bias: R.Tensor((512,), dtype="float32"), layer4_0_downsample_1_weight: R.Tensor((512,), dtype="float32"), layer4_1_bn1_bias: R.Tensor((512,), dtype="float32"), layer4_1_bn1_weight: R.Tensor((512,), dtype="float32"), layer4_1_bn2_bias: R.Tensor((512,), dtype="float32"), layer4_1_bn2_weight: R.Tensor((512,), dtype="float32"), layer4_1_conv1_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), layer4_1_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32")) -> R.Tensor((1, 1000), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.conv2d(inp_0, conv1_weight, strides=[2, 2], padding=[3, 3, 3, 3], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv1: R.Tuple(R.Tensor((1, 64, 112, 112), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv, bn1_weight, bn1_bias, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = lv1[0]
            lv3: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.max_pool2d(lv3, pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
            lv5: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv4, layer1_0_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv6: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv5, layer1_0_bn1_weight, layer1_0_bn1_bias, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv7: R.Tensor((1, 64, 56, 56), dtype="float32") = lv6[0]
            lv8: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv7)
            lv9: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv8, layer1_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv10: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv9, layer1_0_bn2_weight, layer1_0_bn2_bias, metadata["relax.expr.Constant"][4], metadata["relax.expr.Constant"][5], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv11: R.Tensor((1, 64, 56, 56), dtype="float32") = lv10[0]
            lv12: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv11, lv4)
            lv13: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv12)
            lv14: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv13, layer1_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv15: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv14, layer1_1_bn1_weight, layer1_1_bn1_bias, metadata["relax.expr.Constant"][6], metadata["relax.expr.Constant"][7], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv16: R.Tensor((1, 64, 56, 56), dtype="float32") = lv15[0]
            lv17: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv16)
            lv18: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv17, layer1_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv19: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv18, layer1_1_bn2_weight, layer1_1_bn2_bias, metadata["relax.expr.Constant"][8], metadata["relax.expr.Constant"][9], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv20: R.Tensor((1, 64, 56, 56), dtype="float32") = lv19[0]
            lv21: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv20, lv13)
            lv22: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv21)
            lv23: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv22, layer2_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv24: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv23, layer2_0_bn1_weight, layer2_0_bn1_bias, metadata["relax.expr.Constant"][10], metadata["relax.expr.Constant"][11], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv25: R.Tensor((1, 128, 28, 28), dtype="float32") = lv24[0]
            lv26: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv25)
            lv27: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv26, layer2_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv28: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv27, layer2_0_bn2_weight, layer2_0_bn2_bias, metadata["relax.expr.Constant"][12], metadata["relax.expr.Constant"][13], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv29: R.Tensor((1, 128, 28, 28), dtype="float32") = lv28[0]
            lv30: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv22, layer2_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv31: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv30, layer2_0_downsample_1_weight, layer2_0_downsample_1_bias, metadata["relax.expr.Constant"][14], metadata["relax.expr.Constant"][15], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv32: R.Tensor((1, 128, 28, 28), dtype="float32") = lv31[0]
            lv33: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv29, lv32)
            lv34: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv33)
            lv35: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv34, layer2_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv36: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv35, layer2_1_bn1_weight, layer2_1_bn1_bias, metadata["relax.expr.Constant"][16], metadata["relax.expr.Constant"][17], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv37: R.Tensor((1, 128, 28, 28), dtype="float32") = lv36[0]
            lv38: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv37)
            lv39: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv38, layer2_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv40: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv39, layer2_1_bn2_weight, layer2_1_bn2_bias, metadata["relax.expr.Constant"][18], metadata["relax.expr.Constant"][19], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv41: R.Tensor((1, 128, 28, 28), dtype="float32") = lv40[0]
            lv42: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv41, lv34)
            lv43: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv42)
            lv44: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv43, layer3_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv45: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv44, layer3_0_bn1_weight, layer3_0_bn1_bias, metadata["relax.expr.Constant"][20], metadata["relax.expr.Constant"][21], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv46: R.Tensor((1, 256, 14, 14), dtype="float32") = lv45[0]
            lv47: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv46)
            lv48: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv47, layer3_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv49: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv48, layer3_0_bn2_weight, layer3_0_bn2_bias, metadata["relax.expr.Constant"][22], metadata["relax.expr.Constant"][23], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv50: R.Tensor((1, 256, 14, 14), dtype="float32") = lv49[0]
            lv51: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv43, layer3_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv52: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv51, layer3_0_downsample_1_weight, layer3_0_downsample_1_bias, metadata["relax.expr.Constant"][24], metadata["relax.expr.Constant"][25], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv53: R.Tensor((1, 256, 14, 14), dtype="float32") = lv52[0]
            lv54: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv50, lv53)
            lv55: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv54)
            lv56: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv55, layer3_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv57: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv56, layer3_1_bn1_weight, layer3_1_bn1_bias, metadata["relax.expr.Constant"][26], metadata["relax.expr.Constant"][27], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv58: R.Tensor((1, 256, 14, 14), dtype="float32") = lv57[0]
            lv59: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv58)
            lv60: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv59, layer3_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv61: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv60, layer3_1_bn2_weight, layer3_1_bn2_bias, metadata["relax.expr.Constant"][28], metadata["relax.expr.Constant"][29], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv62: R.Tensor((1, 256, 14, 14), dtype="float32") = lv61[0]
            lv63: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv62, lv55)
            lv64: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv63)
            lv65: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv64, layer4_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv66: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv65, layer4_0_bn1_weight, layer4_0_bn1_bias, metadata["relax.expr.Constant"][30], metadata["relax.expr.Constant"][31], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv67: R.Tensor((1, 512, 7, 7), dtype="float32") = lv66[0]
            lv68: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv67)
            lv69: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv68, layer4_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv70: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv69, layer4_0_bn2_weight, layer4_0_bn2_bias, metadata["relax.expr.Constant"][32], metadata["relax.expr.Constant"][33], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv71: R.Tensor((1, 512, 7, 7), dtype="float32") = lv70[0]
            lv72: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv64, layer4_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv73: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv72, layer4_0_downsample_1_weight, layer4_0_downsample_1_bias, metadata["relax.expr.Constant"][34], metadata["relax.expr.Constant"][35], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv74: R.Tensor((1, 512, 7, 7), dtype="float32") = lv73[0]
            lv75: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv71, lv74)
            lv76: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv75)
            lv77: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv76, layer4_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv78: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv77, layer4_1_bn1_weight, layer4_1_bn1_bias, metadata["relax.expr.Constant"][36], metadata["relax.expr.Constant"][37], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv79: R.Tensor((1, 512, 7, 7), dtype="float32") = lv78[0]
            lv80: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv79)
            lv81: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv80, layer4_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv82: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv81, layer4_1_bn2_weight, layer4_1_bn2_bias, metadata["relax.expr.Constant"][38], metadata["relax.expr.Constant"][39], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv83: R.Tensor((1, 512, 7, 7), dtype="float32") = lv82[0]
            lv84: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv83, lv76)
            lv85: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv84)
            lv86: R.Tensor((1, 512, 1, 1), dtype="float32") = R.nn.adaptive_avg_pool2d(lv85, output_size=[1, 1], layout="NCHW", out_layout="NCHW")
            lv87: R.Tensor((1, 512), dtype="float32") = R.reshape(lv86, R.shape([1, 512]))
            lv88: R.Tensor((512, 1000), dtype="float32") = R.permute_dims(fc_weight, axes=None)
            lv89: R.Tensor((1, 1000), dtype="float32") = R.matmul(lv87, lv88, out_dtype="float32")
            lv90: R.Tensor((1, 1000), dtype="float32") = R.add(lv89, fc_bias)
            gv: R.Tensor((1, 1000), dtype="float32") = lv90
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
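The tensor shapes annotated in the printed IRModule can be checked by hand. The sketch below applies the standard output-size formula for convolution and pooling windows (floor division, matching `ceil_mode=False`) to a few of the operators above. The function name `conv_out_size` is introduced here for illustration; it is not a TVM function, and it assumes the same padding on both sides of an axis, as in the printed module.

```python
def conv_out_size(size, kernel, stride, padding, dilation=1):
    """Output extent along one spatial axis of a convolution or pooling window.

    Assumes symmetric padding and floor rounding (ceil_mode=False).
    """
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# conv1: 7x7 kernel, stride 2, padding 3 on a 224x224 input.
stem = conv_out_size(224, kernel=7, stride=2, padding=3)
# max_pool2d: 3x3 window, stride 2, padding 1.
pooled = conv_out_size(stem, kernel=3, stride=2, padding=1)
# layer2_0_conv1: 3x3 kernel, stride 2, padding 1.
layer2 = conv_out_size(pooled, kernel=3, stride=2, padding=1)
# layer2_0_downsample_0: 1x1 kernel, stride 2, no padding.
shortcut = conv_out_size(pooled, kernel=1, stride=2, padding=0)

print(stem, pooled, layer2, shortcut)  # 112 56 28 28
```

The main branch and the downsample shortcut of `layer2_0` both produce 28x28 feature maps, which is why the `R.add` that joins them type-checks.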

IRModule Optimization

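Apache TVM provides a flexible way to optimize an IRModule. Every operation built around IRModule optimization can be composed with existing pipelines. Note that each transformation can be combined into an optimization pipeline via tvm.ir.transform.Sequential.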
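The idea of composing transformations can be illustrated without TVM. The following is a toy analogy only: a "module" is a plain dictionary, each "pass" is a function from module to module, and `sequential` chains them in order. None of these names (`ToyModule`, `sequential`, `decompose_batch_norm`, `fuse_after_conv`) are TVM APIs, and real passes operate on an IRModule rather than on lists of operator names.

```python
from typing import Callable, Dict, List

# Toy stand-in for an IRModule: function name -> list of operator names.
ToyModule = Dict[str, List[str]]
ToyPass = Callable[[ToyModule], ToyModule]


def sequential(passes: List[ToyPass]) -> ToyPass:
    """Compose passes into a single pass that applies them in order."""
    def run(mod: ToyModule) -> ToyModule:
        for p in passes:
            mod = p(mod)
        return mod
    return run


def decompose_batch_norm(mod: ToyModule) -> ToyModule:
    """Replace each 'batch_norm' with simpler elementwise operators."""
    return {
        name: [new for op in ops
               for new in (["multiply", "add"] if op == "batch_norm" else [op])]
        for name, ops in mod.items()
    }


def fuse_after_conv(mod: ToyModule) -> ToyModule:
    """Fold elementwise operators that directly follow a conv2d into it."""
    elementwise = {"multiply", "add", "relu"}
    out: ToyModule = {}
    for name, ops in mod.items():
        fused: List[str] = []
        for op in ops:
            if fused and fused[-1].startswith("conv2d") and op in elementwise:
                fused[-1] += "+" + op
            else:
                fused.append(op)
        out[name] = fused
    return out


pipeline = sequential([decompose_batch_norm, fuse_after_conv])
toy_mod = {"main": ["conv2d", "batch_norm", "relu", "max_pool2d"]}
print(pipeline(toy_mod))  # {'main': ['conv2d+multiply+add+relu', 'max_pool2d']}
```

The order matters: fusion only finds the elementwise operators because the decomposition ran first, which is the same reason a decomposition pass is placed early in a real pipeline.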

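This tutorial focuses on end-to-end optimization of the model through auto-tuning. We use MetaSchedule to tune the model and store the tuning logs in a database, then apply the database to the model to get the best performance.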

TOTAL_TRIALS = 8000  # Change to 20000 for better performance if needed
target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")  # Change to your target device
work_dir = f"{temp_dir}/tuning_logs"

# Skip running in CI environment
IS_IN_CI = os.getenv("CI", "") == "true"
if not IS_IN_CI:
    with target:
        mod = tvm.ir.transform.Sequential(
            [
                # Convert BatchNorm into a sequence of simpler ops for fusion
                relax.transform.DecomposeOpsForInference(),
                # Canonicalize the bindings
                relax.transform.CanonicalizeBindings(),
                # Run default optimization pipeline
                relax.get_pipeline("zero"),
                # Tune the model and store the log to database
                relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS),
                # Apply the database
                relax.transform.MetaScheduleApplyDatabase(work_dir),
            ]
        )(mod)

    # Only show the main function
    mod["main"].show()
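The comment on `DecomposeOpsForInference` in the pipeline above says that BatchNorm is converted into simpler operators so that it can be fused. The arithmetic behind this can be sketched in plain Python: at inference time the running mean and variance are constants, so batch normalization reduces to one multiply and one add per channel. This is a worked example of the algebra, not TVM's implementation; the function names and the numeric values are chosen here for illustration.

```python
import math


def batch_norm_inference(x, gamma, beta, mean, var, eps=1e-5):
    """Reference inference-mode batch norm for one value of one channel."""
    return (x - mean) / math.sqrt(var + eps) * gamma + beta


def fold_batch_norm(gamma, beta, mean, var, eps=1e-5):
    """Precompute per-channel scale and shift so that y = x * scale + shift."""
    scale = gamma / math.sqrt(var + eps)
    shift = beta - mean * scale
    return scale, shift


# Arbitrary illustrative values, not taken from the ResNet-18 weights.
gamma, beta, mean, var = 1.5, -0.25, 0.4, 2.0
scale, shift = fold_batch_norm(gamma, beta, mean, var)

for x in (-1.0, 0.0, 3.7):
    folded = x * scale + shift
    reference = batch_norm_inference(x, gamma, beta, mean, var)
    # Both forms agree up to floating-point rounding.
    assert math.isclose(folded, reference, rel_tol=1e-9, abs_tol=1e-12)
```

Once batch norm is expressed as elementwise operators, a compiler is free to merge them with the preceding convolution instead of launching a separate kernel.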

Build and Deploy

Finally, we build the optimized model and deploy it to the target device.

if not IS_IN_CI:
    ex = relax.build(mod, target="cuda")
    dev = tvm.device("cuda", 0)
    vm = relax.VirtualMachine(ex, dev)
    # Need to allocate data and params on GPU device
    gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
    gpu_params = [tvm.nd.array(p, dev) for p in params["main"]]
    gpu_out = vm["main"](gpu_data, *gpu_params).numpy()

    print(gpu_out.shape)
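The output has shape (1, 1000), one logit per ImageNet class. The tutorial feeds random data, so its output is not a meaningful prediction, but with a real preprocessed image the logits could be turned into class probabilities as sketched below. This is plain Python with a synthetic stand-in for the logits; `softmax` and `top_k` are helper names introduced here, and passing something like `gpu_out[0].tolist()` to `top_k` is an assumption about how one might connect it to the code above.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a flat list of logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=5):
    """Return the k (class_index, probability) pairs with highest probability."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(i, probs[i]) for i in order[:k]]


# Synthetic stand-in for one row of the (1, 1000) output.
fake_logits = [0.0] * 1000
fake_logits[281] = 9.0
fake_logits[285] = 7.5

best = top_k(fake_logits, k=2)
print([i for i, _ in best])  # [281, 285]
```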