
Custom VTA Operator (Python)#

from testing import viz_expr  # visualize Relay expressions
from d2py.utils.file import mkdir
root_dir = ".temp"
mkdir(f"{root_dir}/logs")

Taking a custom operator as the running example, this section walks through how to define a VTA operator.

import numpy as np
import onnx
import tvm
from tvm import relay
import torch
from torch.nn import functional as F
from torch import nn
from torch.onnx import OperatorExportTypes, utils

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 8, 1, 0, bias=False, groups=1)

    def forward(self, x):
        x = self.conv(x)          # (1, 3, 8, 8) -> (1, 8, 1, 1)
        b, c, h, w = x.shape
        x = x.view((b, h, w, c))  # reshape (not transpose) to (1, 1, 1, 8)
        x = F.softmax(x, dim=3)   # softmax over the last axis
        x = x.view((b, -1))       # flatten to (1, 8)
        return x

model = M()
model.eval()

shape = 1, 3, 8, 8
input_name = "data"
xx = torch.rand(*shape, dtype=torch.float32, requires_grad=False)
# model = torch.jit.trace(model, xx)
# Export the model
output_name = "test"
utils.export(
    model,                     # torch model
    xx,                        # model input (use a tuple for multiple inputs)
    f"{root_dir}/{output_name}.onnx",  # where to save the model (a file or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=17,          # the ONNX opset version to export with
    do_constant_folding=True,  # whether to apply constant folding as an optimization
    input_names = [input_name],    # the model's input names
    output_names = ['output'], # the model's output names
    keep_initializers_as_inputs=True,
    # export_modules_as_functions=True,
    verbose=True,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    # dynamic_axes={'data' : {0 : 'batch_size'},    # variable-length axes
    #               'output' : {0 : 'batch_size'}}
)
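
Before importing into Relay, the exported file can be sanity-checked with the standard ONNX checker (optional):

import onnx

# Raises if the exported graph violates the ONNX spec
onnx.checker.check_model(f"{root_dir}/{output_name}.onnx")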

onnx_model = onnx.load(f"{root_dir}/{output_name}.onnx")
mod, params = relay.frontend.from_onnx(onnx_model, {input_name: shape}, freeze_params=True)
mod = relay.transform.InferType()(mod)
mod.show()
Exported graph: graph(%data : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu),
      %conv.weight : Float(8, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=1, device=cpu)):
  %/conv/Conv_output_0 : Float(1, 8, 1, 1, strides=[8, 1, 1, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[8, 8], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/conv/Conv"](%data, %conv.weight), scope: __main__.M::/torch.nn.modules.conv.Conv2d::conv # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
  %/Constant_output_0 : Long(4, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value= 1  1  1  8 [ CPULongType{4} ], onnx_name="/Constant"](), scope: __main__.M:: # /tmp/ipykernel_2716226/852096207.py:14:0
  %/Reshape_output_0 : Float(1, 1, 1, 8, strides=[8, 8, 8, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape"](%/conv/Conv_output_0, %/Constant_output_0), scope: __main__.M:: # /tmp/ipykernel_2716226/852096207.py:14:0
  %/Softmax_output_0 : Float(1, 1, 1, 8, strides=[8, 8, 8, 1], requires_grad=1, device=cpu) = onnx::Softmax[axis=3, onnx_name="/Softmax"](%/Reshape_output_0), scope: __main__.M:: # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/functional.py:1885:0
  %/Constant_1_output_0 : Long(2, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value= 1 -1 [ CPULongType{2} ], onnx_name="/Constant_1"](), scope: __main__.M:: # /tmp/ipykernel_2716226/852096207.py:16:0
  %output : Float(1, 8, strides=[8, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape_1"](%/Softmax_output_0, %/Constant_1_output_0), scope: __main__.M:: # /tmp/ipykernel_2716226/852096207.py:16:0
  return (%output)
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 8, 8), float32] span=/conv/Conv.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=8, kernel_size=[8, 8]) /* ty=Tensor[(1, 8, 1, 1), float32] span=/conv/Conv:0:0 */;
  %1 = reshape(%0, newshape=[1, 1, 1, 8]) /* ty=Tensor[(1, 1, 1, 8), float32] span=/Reshape:0:0 */;
  %2 = nn.softmax(%1, axis=3) /* ty=Tensor[(1, 1, 1, 8), float32] span=/Softmax:0:0 */;
  reshape(%2, newshape=[1, -1]) /* ty=Tensor[(1, 8), float32] span=/Reshape_1:0:0 */
}

Simplify the reshape → softmax → reshape chain into a single custom op:

from tvm.relay.dataflow_pattern import rewrite
from tvm_book.special.rewriter import Reshape4dSoftmaxReshape2dRewrite

expr = mod["main"]
expr = rewrite(Reshape4dSoftmaxReshape2dRewrite(), expr)
run_mod = tvm.IRModule.from_expr(expr)
run_mod = relay.transform.InferType()(run_mod)
run_mod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 8, 8), float32] span=/conv/Conv.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=8, kernel_size=[8, 8]) /* ty=Tensor[(1, 8, 1, 1), float32] span=/conv/Conv:0:0 */;
  special.softmax_reshape(%0, __dict__={"axis"=1, "newshape"=[1, 8]}) /* ty=Tensor[(1, 8), float32] */
}
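
Reshape4dSoftmaxReshape2dRewrite comes from the tvm_book package and is not reproduced here. As a hypothetical sketch (built only on the public DFPatternCallback API; not the book's actual implementation), such a rewriter matches the reshape → softmax → reshape chain, and its callback is where the fused special.softmax_reshape call would be emitted:

from tvm import relay
from tvm.relay.dataflow_pattern import wildcard, is_op, DFPatternCallback

class Reshape4dSoftmaxReshape2dSketch(DFPatternCallback):
    """Hypothetical sketch: match reshape -> softmax -> reshape."""
    def __init__(self):
        super().__init__()
        self.x = wildcard()
        reshape4d = is_op("reshape")(self.x)
        softmax = is_op("nn.softmax")(reshape4d)
        self.pattern = is_op("reshape")(softmax)

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        # The real rewriter emits the custom op "special.softmax_reshape" here;
        # this sketch returns an equivalent graph built from stock Relay ops
        # (for this model: softmax over the channel axis, then flatten).
        return relay.reshape(relay.nn.softmax(x, axis=1), newshape=[1, -1])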

Quantize the simplified module:

from dataclasses import dataclass

@dataclass
class Dataset:
    input_name: str
    shape: tuple

    def __iter__(self):
        for _ in range(2):
            yield {self.input_name: np.random.normal(0, 1, size=self.shape).astype("float32")}

dataset = Dataset(input_name, shape)
with relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
):
    qmod = relay.quantize.quantize(run_mod, params, dataset)
WARNING:autotvm:One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
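
The quantizer simply iterates the dataset object for calibration; each element must be a dict mapping input names to NumPy arrays. A quick check of what the calibrator sees:

# Each calibration batch is {input_name: float32 array of the input shape}
for batch in Dataset(input_name, shape):
    print({k: v.shape for k, v in batch.items()})
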
qmod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = multiply(%data, 52.3141f /* ty=float32 */) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 8, 8), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 8, 8), int8] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[8, 8], out_dtype="int32") /* ty=Tensor[(1, 8, 1, 1), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 8, 1, 1), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1511924864, shift=-9) /* ty=Tensor[(1, 8, 1, 1), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8, 1, 1), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 8, 1, 1), int32] */;
  %9 = cast(%8, dtype="int8") /* ty=Tensor[(1, 8, 1, 1), int8] */;
  %10 = annotation.stop_fusion(%9) /* ty=Tensor[(1, 8, 1, 1), int8] */;
  %11 = cast(%10, dtype="float32") /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %12 = multiply(%11, 0.00783597f /* ty=float32 */) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  special.softmax_reshape(%12, __dict__={"axis"=1, "newshape"=[1, 8]}) /* ty=Tensor[(1, 8), float32] */
}
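
In the quantized graph, fixed_point_multiply rescales the int32 accumulator using integer arithmetic only. Assuming TVM's usual Q31 convention (the multiplier carries 31 fractional bits and shift is a power-of-two exponent), the effective float factor it encodes can be recovered:

# Effective requantization scale encoded by fixed_point_multiply above
multiplier, shift = 1511924864, -9
print(multiplier * 2.0 ** (shift - 31))  # ~0.001375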

Rewrite the preprocessing operator#

from tvm.relay.dataflow_pattern import rewrite
from tvm_book.special.rewriter._preprocess import PreprocessRewrite

expr = qmod["main"]
expr = rewrite(PreprocessRewrite(a_min=-127.0, a_max=127.0, dtype="int8"), expr)
run_qmod = tvm.IRModule.from_expr(expr)
run_qmod = relay.transform.InferType()(run_qmod)
run_qmod.show()
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[7], line 2
      1 from tvm.relay.dataflow_pattern import rewrite
----> 2 from tvm_book.special.rewriter._preprocess import PreprocessRewrite
      4 expr = qmod["main"]
      5 expr = rewrite(PreprocessRewrite(a_min=-127.0, a_max=127.0, dtype="int8"), expr)

File /media/pc/data/lxw/ai/tvm-book/src/tvm_book/special/rewriter/_preprocess.py:8
      1 from tvm import relay
      2 from tvm.relay.dataflow_pattern import (
      3     wildcard, is_op,
      4     # FunctionPattern,
      5     DFPatternCallback,
      6     # rewrite
      7 )
----> 8 from .op import special_preprocess
     10 class PreprocessRewrite(DFPatternCallback):
     11     def __init__(self, a_min=-127.0, a_max=127.0, dtype="int8"):

ModuleNotFoundError: No module named 'tvm_book.special.rewriter.op'
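
The import fails because tvm_book.special.rewriter.op is missing from this environment. As a hypothetical stand-in (assuming the quantize prologue multiply → round → clip → cast shown in qmod above; this is not the book's actual implementation), a preprocessing rewriter has roughly this shape:

from tvm import relay
from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant, DFPatternCallback

class PreprocessRewriteSketch(DFPatternCallback):
    """Hypothetical sketch: match the input-quantize prologue."""
    def __init__(self, a_min=-127.0, a_max=127.0, dtype="int8"):
        super().__init__()
        self.a_min, self.a_max, self.dtype = a_min, a_max, dtype
        self.x = wildcard()
        self.scale = is_constant()
        mul = is_op("multiply")(self.x, self.scale)
        rnd = is_op("round")(mul)
        clip = is_op("clip")(rnd)
        self.pattern = is_op("cast")(clip)

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        scale = node_map[self.scale][0]
        # The real rewriter would emit a fused special preprocessing op here;
        # this sketch just rebuilds the same computation with stock Relay ops.
        out = relay.round(relay.multiply(x, scale))
        out = relay.clip(out, a_min=self.a_min, a_max=self.a_max)
        return relay.cast(out, self.dtype)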
# Reference output from the fp32 Torch model
data = np.random.normal(0, 1, size=shape).astype("float32")
torch_out = model(torch.from_numpy(data)).detach().numpy()


from tvm.contrib import graph_executor

target = 'llvm'
dev = tvm.device(target, 0)

# Compile and run the quantized module (before the preprocess rewrite)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(qmod, target, params=params)
func = lib[lib.libmod_name]
module = graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
output1 = module.get_output(0).numpy()

# Compile and run the module with the rewritten preprocessing
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(run_qmod, target, params=params)
func = lib[lib.libmod_name]
module = graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
output2 = module.get_output(0).numpy()

# Stack the three outputs for side-by-side comparison
np.concatenate([torch_out, output1, output2]).T
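
The transposed concatenation lines up the three outputs column by column (fp32 Torch reference, quantized graph, quantized graph with rewritten preprocessing). A hedged numeric check (the tolerances below are illustrative, not guaranteed by the quantizer):

# output1 and output2 should agree closely if the rewrite preserves semantics;
# the int8 result only loosely tracks the fp32 reference.
np.testing.assert_allclose(output1, output2, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(torch_out, output1, rtol=0.1, atol=0.1)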

graph pack preprocessing operators#