PyTorch 模型
import sys
# Add the parent directory to the module search path before the imports below
# run, so modules located one level up can be found.
sys.path.extend([".."])
# NOTE(review): presumably a local helper module that configures the
# environment as an import side effect — confirm.
import set_env
from tvm_book.tools.frontends import Frontend, TrainInputConfig
from tvm_book.tools import display
PyTorch 前端模型配置:
# Print whatever display.Tree renders for the demo model directory.
# NOTE(review): presumably a directory tree drawn with "| " as the indent
# prefix — confirm against tvm_book.tools.display.
print(display.Tree("| ")("models/pytorch_demo"))
pytorch_demo/ 文件夹下存在如下内容:
save.py
存储 PyTorch 模型为 resnet18.pt:
# save.py: trace a pretrained ResNet-18 and serialize it as TorchScript.
import torch
from torchvision.models import resnet18, ResNet18_Weights

# Pretrained ImageNet weights.
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
# NOTE(review): input_name is not used by the statements below — confirm
# whether it is still needed.
input_name = "data"
shape = 1, 3, 224, 224
# Trace the model in eval mode with a random float tensor of the given shape.
trace = torch.jit.trace(model.eval(), torch.rand(shape).float())
torch.jit.save(trace, "resnet18.pt")
resnet18.pt
存储 PyTorch 模型结构与参数
config.toml
存储 PyTorch 模型配置信息:
[input] # 训练或者量化校准阶段输入数据信息
name = "x"
shape = [ 1, 3, 224, 224,]
dtype = "float32"
layout = "nchw"

[model] # 前端模型配置信息
model_type = "pytorch" # 前端模型框架或者类型
path = "resnet18.pt" # PyTorch TorchScript 模型
from pathlib import Path
import toml
# Load the demo configuration. The path is kept as a Path object so the model
# file can be resolved relative to the directory that holds the config.
config_path = Path("models/pytorch_demo/config.toml")
config = toml.load(config_path)
model_type = config["model"]["model_type"]

# The rest of this script needs `input_config`, `shape_dict` and `model`.
# Fail here with a clear message instead of a NameError further down when the
# configuration does not match what this demo supports.
train_inputs = config["train_inputs"]
if len(train_inputs) != 1:
    raise ValueError(
        f"expected exactly one entry in 'train_inputs', got {len(train_inputs)}"
    )
input_config = train_inputs[0]

if model_type != "pytorch":
    raise ValueError(f"unsupported model_type for this demo: {model_type!r}")

# Map the single input's name to its shape and hand it to the frontend loader.
shape_dict = {input_config["name"]: input_config["shape"]}
model = Frontend(model_type).load(
    f"{config_path.parent}/{config['model']['path']}",
    shape_dict=shape_dict,
)
import tvm
import tvm.relay as relay
from tvm.relay.build_module import bind_params_by_name
from tvm.ir.instrument import (
PassTimingInstrument,
pass_instrument,
)
import numpy as np
def get_calibration_dataset(mod, input_name, *, num_samples=5):
    """Build a list of random calibration samples for ``mod``'s first input.

    The sample shape is read from the first argument type of ``mod["main"]``.
    Each sample is drawn with ``np.random.uniform`` using its default range.

    Args:
        mod: Object whose ``"main"`` entry exposes
            ``checked_type.arg_types[0].shape``.
        input_name: Key under which every generated array is stored.
        num_samples: Number of samples to generate. Defaults to 5, the count
            that used to be hard-coded.

    Returns:
        A list of ``num_samples`` dicts, each mapping ``input_name`` to a
        random array of the input's shape.
    """
    # Convert each dimension to a plain int so NumPy accepts it as a size.
    input_shape = [int(dim) for dim in mod["main"].checked_type.arg_types[0].shape]
    return [
        {input_name: np.random.uniform(size=input_shape)}
        for _ in range(num_samples)
    ]
# Key the calibration samples by the input name from the loaded config (the
# same name the model was loaded with via shape_dict) instead of a hard-coded
# "x", so the two cannot drift apart.
dataset = get_calibration_dataset(model.mod, input_config["name"])

# Options forwarded as keyword arguments to relay.quantize.qconfig below.
BASE_CFG = {
    "skip_conv_layers": [],
    "skip_dense_layers": False,
    "dtype_input": "int8",
    "dtype_weight": "int8",
    "dtype_activation": "int32",
}

# Quantize the Relay module, calibrating on `dataset` in "percentile" mode.
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(**BASE_CFG, calibrate_mode="percentile"):
        qmod = relay.quantize.quantize(model.mod, params=model.params, dataset=dataset)