# PyTorch 模型

In [None]:
import sys

# Put the parent directory on the import path so the local `set_env` module resolves.
sys.path.append("..")
import set_env  # noqa: F401 -- imported for its side effects (presumably environment setup; confirm)

In [2]:
from tvm_book.tools.frontends import Frontend, TrainInputConfig
from tvm_book.tools import display

PyTorch 前端模型配置：

In [None]:
# Show the layout of the demo model folder; `display.Tree` is a project helper and
# "| " is presumably the per-level indent prefix.
tree_printer = display.Tree("| ")
print(tree_printer("models/pytorch_demo"))

{icon}`fa-solid fa-folder-open` `pytorch_demo/` 文件夹下存在如下内容：

- {icon}`fa-solid fa-file` `save.py` 将 PyTorch 模型保存为 `resnet18.pth`
    ```{include} models/pytorch_demo/save.py
    :code: python
    ```
- {icon}`fa-solid fa-file` `resnet18.pth` 存储 PyTorch 模型结构与参数
- {icon}`fa-solid fa-file` `config.toml` 存储 PyTorch 模型配置信息
    ```{include} models/pytorch_demo/config.toml
    :code: toml
    ```

In [4]:
from pathlib import Path
import toml

# TOML file describing the model file and its training inputs.
config_path = Path("models/pytorch_demo/config.toml")
config = toml.load(config_path)
model_type = config["model"]["model_type"]

# This notebook only handles single-input models. Fail loudly here instead of
# leaving `input_config` / `model` undefined and hitting a NameError in a later cell.
train_inputs = config["train_inputs"]
if len(train_inputs) != 1:
    raise ValueError(
        f"expected exactly one entry in 'train_inputs', got {len(train_inputs)}"
    )
input_config = train_inputs[0]

if model_type != "pytorch":
    raise ValueError(
        f"unsupported model_type {model_type!r}; this notebook expects 'pytorch'"
    )

# Map the graph input name to its shape for the frontend loader.
shape_dict = {input_config["name"]: input_config["shape"]}
# The model path in the config is resolved relative to the config file's directory.
model = Frontend(model_type).load(
    f"{config_path.parent}/{config['model']['path']}", shape_dict=shape_dict
)

In [5]:
import tvm
import tvm.relay as relay
from tvm.relay.build_module import bind_params_by_name
from tvm.ir.instrument import (
    PassTimingInstrument,
    pass_instrument,
)

In [6]:
import numpy as np
def get_calibration_dataset(mod, input_name, num_samples=5, seed=0):
    """Build a small random calibration dataset for ``relay.quantize``.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose ``main`` function carries a checked type. The shape and
        dtype of its first parameter determine the generated samples.
    input_name : str
        Key under which each sample is stored. It should match the graph input
        name used when the model was imported.
    num_samples : int, optional
        Number of samples to generate (default 5).
    seed : int or None, optional
        Seed for the random generator, so calibration is reproducible across
        runs. Pass ``None`` to draw fresh entropy.

    Returns
    -------
    list of dict
        ``num_samples`` dicts mapping ``input_name`` to an array drawn
        uniformly from [0, 1), with the first input's shape and dtype.
    """
    arg_type = mod["main"].checked_type.arg_types[0]
    # Each dimension is converted with int(), so only static shapes are supported.
    input_shape = [int(dim) for dim in arg_type.shape]
    # Generate data in the declared input dtype rather than numpy's default float64.
    dtype = arg_type.dtype
    rng = np.random.default_rng(seed)
    return [
        {input_name: rng.uniform(size=input_shape).astype(dtype)}
        for _ in range(num_samples)
    ]

In [None]:
# Key calibration samples by the model's real input name from config.toml (the same
# name passed to the frontend via `shape_dict`) instead of a hard-coded "x".
dataset = get_calibration_dataset(model.mod, input_config["name"])
BASE_CFG = {
    # Empty list: no conv layer is excluded from quantization (presumably overriding
    # TVM's default of skipping the first conv -- confirm against qconfig docs).
    "skip_conv_layers": [],
    "skip_dense_layers": False,
    "dtype_input": "int8",
    "dtype_weight": "int8",
    "dtype_activation": "int32",
}
with tvm.transform.PassContext(opt_level=3):
    # Percentile calibration derives scales from statistics collected by running `dataset`.
    with relay.quantize.qconfig(**BASE_CFG, calibrate_mode="percentile"):
        qmod = relay.quantize.quantize(model.mod, params=model.params, dataset=dataset)