PyTorch Model

import sys
sys.path.extend([".."])
import set_env
from tvm_book.tools.frontends import Frontend, TrainInputConfig
from tvm_book.tools import display

PyTorch frontend model configuration:

print(display.Tree("| ")("models/pytorch_demo"))

The pytorch_demo/ folder contains the following:

  • save.py traces the PyTorch model and saves it as resnet18.pt

    import torch
    from torchvision.models import resnet18, ResNet18_Weights
    model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    
    input_name = "x"  # input name, matching config.toml
    shape = 1, 3, 224, 224
    # Trace the model with a dummy input and save it as TorchScript
    trace = torch.jit.trace(model.eval(), torch.rand(shape).float())
    torch.jit.save(trace, "resnet18.pt")
    
  • resnet18.pt stores the traced PyTorch (TorchScript) model structure and parameters; a sketch of importing such a file directly into Relay follows this list

  • config.toml stores the PyTorch model configuration

    [[train_inputs]] # input info for the training / quantization-calibration stage
    name = "x"
    shape = [ 1, 3, 224, 224,]
    dtype = "float32"
    layout = "nchw"

    [model] # frontend model settings
    model_type = "pytorch" # frontend framework / model type
    path = "resnet18.pt" # the traced PyTorch (TorchScript) model
    
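The Frontend helper used in the next cell comes from this book's tooling (tvm_book.tools.frontends); presumably it wraps TVM's standard PyTorch importer. As a rough sketch, the equivalent direct call for the traced resnet18.pt produced by save.py would look like this (paths and shapes are taken from the config above):

import torch
import tvm
from tvm import relay

# Load the TorchScript module saved by save.py and convert it to a Relay module.
scripted = torch.jit.load("models/pytorch_demo/resnet18.pt").eval()
shape_list = [("x", (1, 3, 224, 224))]  # (input name, shape) pairs from config.toml
mod, params = relay.frontend.from_pytorch(scripted, shape_list)
print(mod["main"])  # inspect the imported Relay function
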
from pathlib import Path
import toml

config_path = Path("models/pytorch_demo/config.toml")
config = toml.load(config_path)
model_type = config["model"]["model_type"]
# This demo defines a single training/calibration input.
if len(config["train_inputs"]) == 1:
    input_config = config["train_inputs"][0]
if model_type == "pytorch":
    # Map the input name to its shape for the frontend importer.
    shape_dict = {input_config["name"]: input_config["shape"]}
    model = Frontend(model_type).load(
        f"{config_path.parent}/{config['model']['path']}",
        shape_dict=shape_dict,
    )
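
The loaded object exposes the converted Relay module and its parameters as model.mod and model.params (both are used below). A quick sanity check on the import:

# Inspect the imported Relay function's inputs and the number of bound weights.
print(model.mod["main"].params)
print(len(model.params), "parameter tensors")
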
import tvm
import tvm.relay as relay
from tvm.relay.build_module import bind_params_by_name
from tvm.ir.instrument import (
    PassTimingInstrument,
    pass_instrument,
)
import numpy as np
def get_calibration_dataset(mod, input_name):
    """Build a small random calibration set matching the model's input shape."""
    dataset = []
    input_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape]
    for i in range(5):
        data = np.random.uniform(size=input_shape)
        dataset.append({input_name: data})
    return dataset
dataset = get_calibration_dataset(model.mod, "x")
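
Random tensors are enough to exercise the calibration API; for a meaningful percentile calibration the dataset would normally hold preprocessed real images in the same {input name: NCHW array} format. A hypothetical sketch (the ImageNet-style preprocessing and the image paths are illustrative assumptions, not part of the demo):

from PIL import Image

def load_calibration_image(path, input_name="x"):
    """Preprocess one image into the {name: NCHW float32 array} format used above."""
    img = Image.open(path).convert("RGB").resize((224, 224))
    data = np.asarray(img, dtype="float32") / 255.0
    mean = np.array([0.485, 0.456, 0.406], dtype="float32")
    std = np.array([0.229, 0.224, 0.225], dtype="float32")
    data = ((data - mean) / std).transpose(2, 0, 1)[None]  # HWC -> NCHW, batch dim 1
    return {input_name: data}

# e.g. dataset = [load_calibration_image(p) for p in calibration_image_paths]
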
# Quantization settings: int8 inputs/weights with int32 accumulation
BASE_CFG = {
    "skip_conv_layers": [],
    "skip_dense_layers": False,
    "dtype_input": "int8",
    "dtype_weight": "int8",
    "dtype_activation": "int32",
}
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(**BASE_CFG, calibrate_mode="percentile"):
        qmod = relay.quantize.quantize(model.mod, params=model.params, dataset=dataset)
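
With the quantized module in hand, it can be compiled and executed like any other Relay module. A minimal sketch, assuming a local llvm target and a random input (neither is part of the original demo):

from tvm.contrib import graph_executor

target = "llvm"
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(qmod, target=target)  # params were already bound during quantization

dev = tvm.device(target, 0)
runtime = graph_executor.GraphModule(lib["default"](dev))
runtime.set_input("x", np.random.uniform(size=(1, 3, 224, 224)).astype("float32"))
runtime.run()
print(runtime.get_output(0).numpy().shape)  # (1, 1000) logits for ResNet-18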