VTA Demo

Load the frontend model:

import torch
import torchvision

model_name = "resnet18"
model = torchvision.models.resnet18(weights=torchvision.models.resnet.ResNet18_Weights.DEFAULT)
model = model.eval()

# 跟踪以获取 TorchScripted 模型
input_shape = [1, 3, 224, 224]
scripted_model = torch.jit.trace(model, torch.randn(input_shape)).eval()
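
As a quick sanity check (an addition, not part of the original notebook), the traced module should agree with the eager model on a random input:

# Tracing should not change the model's outputs.
x = torch.randn(input_shape)
with torch.no_grad():
    torch.testing.assert_close(scripted_model(x), model(x))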

Set up the notebook environment:

%cd ..
import set_env
/media/pc/data/lxw/ai/tvm-book/doc/tutorials

Load a test image:

from PIL import Image
from tvm.contrib.download import download_testdata
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
im = Image.open(img_path).resize((224, 224))
im
(output: the resized 224×224 cat test image)

Preprocess the image:

import numpy as np

# Standard ImageNet normalization constants
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

dtype = "float32"
data = (np.array(im) / 255).astype(dtype)  # uint8 [0, 255] -> float [0, 1]
data -= mean
data /= std
data = data.transpose((2, 0, 1))     # HWC => CHW
data = np.expand_dims(data, axis=0)  # CHW => NCHW
shape = data.shape
print(f"data shape: {shape}")
data shape: (1, 3, 224, 224)
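
For reference (an addition for comparison, not in the original notebook), the same preprocessing can be written with torchvision.transforms; small numerical differences may come from the resize interpolation:

from torchvision import transforms

tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),                     # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=mean, std=std),  # per-channel standardization
])
data_ref = tfm(Image.open(img_path)).unsqueeze(0).numpy()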

Import the model into Relay and compile the graph for the llvm target with the given input specification:

import tvm
from tvm import relay
from tvm.contrib import graph_executor

input_name = "data"
shape = 1, 3, 224, 224
shape_list = [(input_name, shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
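
Optionally, inspect the imported Relay IR before compiling (output elided here):

print(mod["main"])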

target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
m = graph_executor.GraphModule(lib["default"](dev))
# Run inference
m.run(**{input_name: data})
# Fetch the output
float_output = m.get_output(0).numpy()
with torch.no_grad():
    torch_img = torch.from_numpy(data)
    output = model(torch_img).numpy()
# Verify that the compiled float model matches PyTorch numerically
np.testing.assert_allclose(output, float_output, rtol=1e-07, atol=1e-5)
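
To make the prediction human-readable, a short sketch decodes the top-1 ImageNet class (an addition; the label-file URL is an assumption, not part of the original notebook):

# Hypothetical label file; any 1000-line ImageNet class list works.
synset_url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
synset_path = download_testdata(synset_url, "imagenet_classes.txt", module="data")
with open(synset_path) as f:
    labels = [line.strip() for line in f]
print("top-1 class:", labels[int(float_output.argmax())])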

Quantize the resnet18 model

with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(calibrate_mode="kl_divergence",
                                skip_conv_layers=[]):
        # Note: kl_divergence calibration is normally driven by a calibration
        # dataset passed as quantize(..., dataset=...); see the sketch after
        # the qconfig docstring below.
        mod = relay.quantize.quantize(mod, params=params)
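
A minimal sketch (an addition, not from the original notebook) builds and runs the quantized module the same way as the float one. Quantization perturbs the output values, so it compares the top-1 prediction instead of exact numbers:

with tvm.transform.PassContext(opt_level=3):
    qlib = relay.build(mod, target=target)  # parameters were bound during quantization
qm = graph_executor.GraphModule(qlib["default"](dev))
qm.run(**{input_name: data})
quant_output = qm.get_output(0).numpy()
assert quant_output.argmax() == float_output.argmax()

The full set of quantization options can be inspected interactively: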
relay.quantize.qconfig?
Signature: relay.quantize.qconfig(**kwargs)
Docstring:
Configure the quantization behavior by setting config variables.

Parameters
----------
nbit_dict: dict of QAnnotateKind -> int
    Number of bits for each kind of annotated field.

calibrate_mode: str
    The calibration mode. 'global_scale' or 'kl_divergence'.
    global_scale: use global scale
    kl_divergence: find scales by kl divergence on the dataset.

global_scale: float
    The global scale for calibration.

weight_scale: str
    The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
    power2: Find the maximum of the absolute value of the tensor, and then round up to power
    of two.
    max: Find the maximum of the absolute value of the tensor

skip_dense_layer: boolean
    Whether to skip all nn.dense layers. They are skipped by default.

skip_conv_layers: list
    Specifying which layers to be skipped. Provide a list of indices
    that indicate which conv2d layers to leave untouched. Start from 0.

do_simulation: boolean
    Whether to do simulation with float operation only.

round_for_shift: boolean
    Whether to add bias for rounding during shift.

debug_enabled_ops: None or list of str
    Partially quantize specified operators for debugging. The default value
    is None, which means it will try to call all operators' annotate rewrite
    functions.

rounding: "UPWARD" or "TONEAREST"
    Rounding direction for fixed point multiplications.

partition_conversions: 'disabled', 'enabled', or 'fully_integral'
    If set to 'enabled' or 'fully_integral', partitions a quantized
    result into a module containing
    a prefix function (consisting of input conversion into the quantized data space),
    a middle function (consisting of the core quantized network),
    a suffix function (consisting of output dequantization),
    and a main function (that calls the prefix, middle, and suffix functions in succession).
    If set to 'fully_integral' and there are unquantized operators in the result,
    an exception is raised.
    The default value is 'disabled'.

Returns
-------
config: QConfig
    The quantization configuration
File:      /media/pc/data/lxw/ai/tvm/python/tvm/relay/quantize/quantize.py
Type:      function
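
As noted above, kl_divergence calibration is meant to be fed a calibration dataset. A hedged sketch, assuming the dataset is an iterable of {input_name: ndarray} feed dicts and that a few random batches suffice for demonstration (real calibration should use representative preprocessed images):

# Re-import the float module, since `mod` above now holds the quantized one.
fmod, fparams = relay.frontend.from_pytorch(scripted_model, shape_list)
calib_dataset = [{input_name: np.random.uniform(-1, 1, size=shape).astype(dtype)}
                 for _ in range(8)]
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(calibrate_mode="kl_divergence",
                                skip_conv_layers=[]):
        qmod = relay.quantize.quantize(fmod, params=fparams, dataset=calib_dataset)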