End-to-End Model Optimization Example
This tutorial demonstrates how to optimize a machine learning model with Apache TVM. We take a pre-trained ResNet-18 model from PyTorch and optimize it end to end with TVM's Relax API. Note that the default end-to-end optimization may not be suitable for complex models.
Preparation
First, prepare the model and its input information. We use the pre-trained ResNet-18 model from PyTorch.
import set_env  # notebook-local helper that sets up the TVM environment for this docs build

from pathlib import Path

# Create a scratch directory for artifacts such as tuning logs
temp_dir = Path(".temp")
temp_dir.mkdir(exist_ok=True)
import os
import numpy as np
import torch
from torch import fx
from torchvision.models.resnet import ResNet18_Weights, resnet18
torch_model = resnet18(weights=ResNet18_Weights.DEFAULT)
Overview of the Overall Flow
The overall flow consists of the following steps (a condensed sketch of the whole flow follows this list):
Construct or import a model: build a neural network model, or import a pre-trained model from another framework (such as PyTorch or ONNX), and create the TVM IRModule, which contains all the information needed for compilation, including high-level Relax functions for the computational graph and low-level TensorIR functions for the tensor programs.
Perform composable optimizations: apply a series of optimization transformations, such as graph optimizations, tensor program optimizations, and library dispatching.
Build and universal deployment: build the optimized model into a module deployable on the universal runtime, and execute it on different devices, such as CPUs, GPUs, or other accelerators.
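As a condensed preview, the three steps map onto the APIs used later in this tutorial roughly as follows (a sketch only, not runnable on its own; each step appears in full, runnable form in the sections below):
# 1. Construct or import: convert a framework model into a Relax IRModule.
#        mod = from_fx(fx.symbolic_trace(torch_model), input_info)
# 2. Composable optimization: compose passes into one pipeline and apply it.
#        mod = tvm.ir.transform.Sequential([...])(mod)
# 3. Build and universal deployment: compile into an executable and run it on a device.
#        ex = relax.build(mod, target)
#        vm = relax.VirtualMachine(ex, tvm.device("cuda", 0))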
Convert the Model to an IRModule
We use the Relax frontend for PyTorch to convert the model to an IRModule for further optimization. Besides the model itself, we also need to provide the input shapes and data types.
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx
# Give the input shape and data type
input_info = [((1, 3, 224, 224), "float32")]
# Convert the model to IRModule
with torch.no_grad():
    torch_fx_model = fx.symbolic_trace(torch_model)
    mod = from_fx(torch_fx_model, input_info, keep_params_as_input=True)

mod, params = relax.frontend.detach_params(mod)
mod.show()
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 3, 224, 224), dtype="float32"), bn1_bias: R.Tensor((64,), dtype="float32"), bn1_weight: R.Tensor((64,), dtype="float32"), conv1_weight: R.Tensor((64, 3, 7, 7), dtype="float32"), fc_bias: R.Tensor((1000,), dtype="float32"), fc_weight: R.Tensor((1000, 512), dtype="float32"), layer1_0_bn1_bias: R.Tensor((64,), dtype="float32"), layer1_0_bn1_weight: R.Tensor((64,), dtype="float32"), layer1_0_bn2_bias: R.Tensor((64,), dtype="float32"), layer1_0_bn2_weight: R.Tensor((64,), dtype="float32"), layer1_0_conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer1_0_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer1_1_bn1_bias: R.Tensor((64,), dtype="float32"), layer1_1_bn1_weight: R.Tensor((64,), dtype="float32"), layer1_1_bn2_bias: R.Tensor((64,), dtype="float32"), layer1_1_bn2_weight: R.Tensor((64,), dtype="float32"), layer1_1_conv1_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer1_1_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"), layer2_0_bn1_bias: R.Tensor((128,), dtype="float32"), layer2_0_bn1_weight: R.Tensor((128,), dtype="float32"), layer2_0_bn2_bias: R.Tensor((128,), dtype="float32"), layer2_0_bn2_weight: R.Tensor((128,), dtype="float32"), layer2_0_conv1_weight: R.Tensor((128, 64, 3, 3), dtype="float32"), layer2_0_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), layer2_0_downsample_0_weight: R.Tensor((128, 64, 1, 1), dtype="float32"), layer2_0_downsample_1_bias: R.Tensor((128,), dtype="float32"), layer2_0_downsample_1_weight: R.Tensor((128,), dtype="float32"), layer2_1_bn1_bias: R.Tensor((128,), dtype="float32"), layer2_1_bn1_weight: R.Tensor((128,), dtype="float32"), layer2_1_bn2_bias: R.Tensor((128,), dtype="float32"), layer2_1_bn2_weight: R.Tensor((128,), dtype="float32"), layer2_1_conv1_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), layer2_1_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"), layer3_0_bn1_bias: R.Tensor((256,), dtype="float32"), layer3_0_bn1_weight: R.Tensor((256,), dtype="float32"), layer3_0_bn2_bias: R.Tensor((256,), dtype="float32"), layer3_0_bn2_weight: R.Tensor((256,), dtype="float32"), layer3_0_conv1_weight: R.Tensor((256, 128, 3, 3), dtype="float32"), layer3_0_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), layer3_0_downsample_0_weight: R.Tensor((256, 128, 1, 1), dtype="float32"), layer3_0_downsample_1_bias: R.Tensor((256,), dtype="float32"), layer3_0_downsample_1_weight: R.Tensor((256,), dtype="float32"), layer3_1_bn1_bias: R.Tensor((256,), dtype="float32"), layer3_1_bn1_weight: R.Tensor((256,), dtype="float32"), layer3_1_bn2_bias: R.Tensor((256,), dtype="float32"), layer3_1_bn2_weight: R.Tensor((256,), dtype="float32"), layer3_1_conv1_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), layer3_1_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"), layer4_0_bn1_bias: R.Tensor((512,), dtype="float32"), layer4_0_bn1_weight: R.Tensor((512,), dtype="float32"), layer4_0_bn2_bias: R.Tensor((512,), dtype="float32"), layer4_0_bn2_weight: R.Tensor((512,), dtype="float32"), layer4_0_conv1_weight: R.Tensor((512, 256, 3, 3), dtype="float32"), layer4_0_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), layer4_0_downsample_0_weight: R.Tensor((512, 256, 1, 1), dtype="float32"), layer4_0_downsample_1_bias: R.Tensor((512,), dtype="float32"), layer4_0_downsample_1_weight: R.Tensor((512,), dtype="float32"), layer4_1_bn1_bias: R.Tensor((512,), dtype="float32"), layer4_1_bn1_weight: R.Tensor((512,), dtype="float32"), layer4_1_bn2_bias: R.Tensor((512,), dtype="float32"), layer4_1_bn2_weight: R.Tensor((512,), dtype="float32"), layer4_1_conv1_weight: R.Tensor((512, 512, 3, 3), dtype="float32"), layer4_1_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32")) -> R.Tensor((1, 1000), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.conv2d(inp_0, conv1_weight, strides=[2, 2], padding=[3, 3, 3, 3], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv1: R.Tuple(R.Tensor((1, 64, 112, 112), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv, bn1_weight, bn1_bias, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = lv1[0]
            lv3: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.max_pool2d(lv3, pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
            lv5: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv4, layer1_0_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv6: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv5, layer1_0_bn1_weight, layer1_0_bn1_bias, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv7: R.Tensor((1, 64, 56, 56), dtype="float32") = lv6[0]
            lv8: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv7)
            lv9: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv8, layer1_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv10: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv9, layer1_0_bn2_weight, layer1_0_bn2_bias, metadata["relax.expr.Constant"][4], metadata["relax.expr.Constant"][5], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv11: R.Tensor((1, 64, 56, 56), dtype="float32") = lv10[0]
            lv12: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv11, lv4)
            lv13: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv12)
            lv14: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv13, layer1_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv15: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv14, layer1_1_bn1_weight, layer1_1_bn1_bias, metadata["relax.expr.Constant"][6], metadata["relax.expr.Constant"][7], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv16: R.Tensor((1, 64, 56, 56), dtype="float32") = lv15[0]
            lv17: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv16)
            lv18: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv17, layer1_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv19: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv18, layer1_1_bn2_weight, layer1_1_bn2_bias, metadata["relax.expr.Constant"][8], metadata["relax.expr.Constant"][9], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv20: R.Tensor((1, 64, 56, 56), dtype="float32") = lv19[0]
            lv21: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv20, lv13)
            lv22: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv21)
            lv23: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv22, layer2_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv24: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv23, layer2_0_bn1_weight, layer2_0_bn1_bias, metadata["relax.expr.Constant"][10], metadata["relax.expr.Constant"][11], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv25: R.Tensor((1, 128, 28, 28), dtype="float32") = lv24[0]
            lv26: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv25)
            lv27: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv26, layer2_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv28: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv27, layer2_0_bn2_weight, layer2_0_bn2_bias, metadata["relax.expr.Constant"][12], metadata["relax.expr.Constant"][13], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv29: R.Tensor((1, 128, 28, 28), dtype="float32") = lv28[0]
            lv30: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv22, layer2_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv31: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv30, layer2_0_downsample_1_weight, layer2_0_downsample_1_bias, metadata["relax.expr.Constant"][14], metadata["relax.expr.Constant"][15], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv32: R.Tensor((1, 128, 28, 28), dtype="float32") = lv31[0]
            lv33: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv29, lv32)
            lv34: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv33)
            lv35: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv34, layer2_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv36: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv35, layer2_1_bn1_weight, layer2_1_bn1_bias, metadata["relax.expr.Constant"][16], metadata["relax.expr.Constant"][17], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv37: R.Tensor((1, 128, 28, 28), dtype="float32") = lv36[0]
            lv38: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv37)
            lv39: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv38, layer2_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv40: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv39, layer2_1_bn2_weight, layer2_1_bn2_bias, metadata["relax.expr.Constant"][18], metadata["relax.expr.Constant"][19], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv41: R.Tensor((1, 128, 28, 28), dtype="float32") = lv40[0]
            lv42: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv41, lv34)
            lv43: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv42)
            lv44: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv43, layer3_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv45: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv44, layer3_0_bn1_weight, layer3_0_bn1_bias, metadata["relax.expr.Constant"][20], metadata["relax.expr.Constant"][21], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv46: R.Tensor((1, 256, 14, 14), dtype="float32") = lv45[0]
            lv47: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv46)
            lv48: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv47, layer3_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv49: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv48, layer3_0_bn2_weight, layer3_0_bn2_bias, metadata["relax.expr.Constant"][22], metadata["relax.expr.Constant"][23], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv50: R.Tensor((1, 256, 14, 14), dtype="float32") = lv49[0]
            lv51: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv43, layer3_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv52: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv51, layer3_0_downsample_1_weight, layer3_0_downsample_1_bias, metadata["relax.expr.Constant"][24], metadata["relax.expr.Constant"][25], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv53: R.Tensor((1, 256, 14, 14), dtype="float32") = lv52[0]
            lv54: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv50, lv53)
            lv55: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv54)
            lv56: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv55, layer3_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv57: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv56, layer3_1_bn1_weight, layer3_1_bn1_bias, metadata["relax.expr.Constant"][26], metadata["relax.expr.Constant"][27], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv58: R.Tensor((1, 256, 14, 14), dtype="float32") = lv57[0]
            lv59: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv58)
            lv60: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv59, layer3_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv61: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv60, layer3_1_bn2_weight, layer3_1_bn2_bias, metadata["relax.expr.Constant"][28], metadata["relax.expr.Constant"][29], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv62: R.Tensor((1, 256, 14, 14), dtype="float32") = lv61[0]
            lv63: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv62, lv55)
            lv64: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv63)
            lv65: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv64, layer4_0_conv1_weight, strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv66: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv65, layer4_0_bn1_weight, layer4_0_bn1_bias, metadata["relax.expr.Constant"][30], metadata["relax.expr.Constant"][31], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv67: R.Tensor((1, 512, 7, 7), dtype="float32") = lv66[0]
            lv68: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv67)
            lv69: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv68, layer4_0_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv70: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv69, layer4_0_bn2_weight, layer4_0_bn2_bias, metadata["relax.expr.Constant"][32], metadata["relax.expr.Constant"][33], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv71: R.Tensor((1, 512, 7, 7), dtype="float32") = lv70[0]
            lv72: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv64, layer4_0_downsample_0_weight, strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv73: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv72, layer4_0_downsample_1_weight, layer4_0_downsample_1_bias, metadata["relax.expr.Constant"][34], metadata["relax.expr.Constant"][35], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv74: R.Tensor((1, 512, 7, 7), dtype="float32") = lv73[0]
            lv75: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv71, lv74)
            lv76: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv75)
            lv77: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv76, layer4_1_conv1_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv78: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv77, layer4_1_bn1_weight, layer4_1_bn1_bias, metadata["relax.expr.Constant"][36], metadata["relax.expr.Constant"][37], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv79: R.Tensor((1, 512, 7, 7), dtype="float32") = lv78[0]
            lv80: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv79)
            lv81: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv80, layer4_1_conv2_weight, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv82: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv81, layer4_1_bn2_weight, layer4_1_bn2_bias, metadata["relax.expr.Constant"][38], metadata["relax.expr.Constant"][39], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv83: R.Tensor((1, 512, 7, 7), dtype="float32") = lv82[0]
            lv84: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv83, lv76)
            lv85: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv84)
            lv86: R.Tensor((1, 512, 1, 1), dtype="float32") = R.nn.adaptive_avg_pool2d(lv85, output_size=[1, 1], layout="NCHW", out_layout="NCHW")
            lv87: R.Tensor((1, 512), dtype="float32") = R.reshape(lv86, R.shape([1, 512]))
            lv88: R.Tensor((512, 1000), dtype="float32") = R.permute_dims(fc_weight, axes=None)
            lv89: R.Tensor((1, 1000), dtype="float32") = R.matmul(lv87, lv88, out_dtype="float32")
            lv90: R.Tensor((1, 1000), dtype="float32") = R.add(lv89, fc_bias)
            gv: R.Tensor((1, 1000), dtype="float32") = lv90
            R.output(gv)
        return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
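Because we converted with keep_params_as_input=True, the weights were lifted into function parameters, and relax.frontend.detach_params then split them out into the params dictionary. A quick, minimal sanity check (a sketch using only the objects created above; params maps each function name to its list of weight tensors):
# params: dict mapping function name -> list of tvm.nd.NDArray weights
print(len(params["main"]))      # number of weight tensors detached from "main"
print(params["main"][0].shape)  # shape of the first detached weight tensor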
Optimizing the IRModule
Apache TVM provides a flexible way to optimize an IRModule. All transformations around the IRModule can be composed with existing pipelines. Note that individual transformations can be combined into a single optimization pipeline with tvm.ir.transform.Sequential, as shown in the sketch below.
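For instance, two standalone Relax passes can be composed into one pipeline and applied to the module in a single call. A minimal sketch (LegalizeOps and FoldConstant are existing Relax passes, picked here purely for illustration):
# Compose two independent passes into one pipeline object
pipeline = tvm.ir.transform.Sequential(
    [
        relax.transform.FoldConstant(),  # fold constant subexpressions
        relax.transform.LegalizeOps(),   # lower high-level Relax ops to TensorIR
    ]
)
# Apply the whole pipeline to the IRModule in a single call
lowered_mod = pipeline(mod)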
In this tutorial, we focus on end-to-end optimization of the model through auto-tuning. We use MetaSchedule to tune the model and store the tuning logs in a database, then apply the database back to the model to obtain the best performance.
TOTAL_TRIALS = 8000  # Change to 20000 for better performance if needed
target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")  # Change to your target device
work_dir = f"{temp_dir}/tuning_logs"

# Skip running in CI environment
IS_IN_CI = os.getenv("CI", "") == "true"
if not IS_IN_CI:
    with target:
        mod = tvm.ir.transform.Sequential(
            [
                # Convert BatchNorm into a sequence of simpler ops for fusion
                relax.transform.DecomposeOpsForInference(),
                # Canonicalize the bindings
                relax.transform.CanonicalizeBindings(),
                # Run default optimization pipeline
                relax.get_pipeline("zero"),
                # Tune the model and store the log to database
                relax.transform.MetaScheduleTuneIRMod({}, work_dir, TOTAL_TRIALS),
                # Apply the database
                relax.transform.MetaScheduleApplyDatabase(work_dir),
            ]
        )(mod)

    # Only show the main function
    mod["main"].show()
Build and Deployment
Finally, we build the optimized model and deploy it to the target device.
if not IS_IN_CI:
    ex = relax.build(mod, target="cuda")
    dev = tvm.device("cuda", 0)
    vm = relax.VirtualMachine(ex, dev)
    # Need to allocate data and params on GPU device
    gpu_data = tvm.nd.array(np.random.rand(1, 3, 224, 224).astype("float32"), dev)
    gpu_params = [tvm.nd.array(p, dev) for p in params["main"]]

    gpu_out = vm["main"](gpu_data, *gpu_params).numpy()
    print(gpu_out.shape)
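Beyond a single forward pass, two common follow-ups are benchmarking and saving the compiled artifact. A minimal sketch (assumptions: the Relax VM's time_evaluator helper and the executable's export_library method with default settings; the library file name is arbitrary):
if not IS_IN_CI:
    # Benchmark the compiled "main" function on the GPU
    timer = vm.time_evaluator("main", dev, number=10, repeat=3)
    print(timer(gpu_data, *gpu_params))

    # Export the compiled module as a shared library for later reuse;
    # it can be reloaded with tvm.runtime.load_module and wrapped in a
    # fresh relax.VirtualMachine.
    ex.export_library(str(temp_dir / "resnet18_cuda.so"))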