Custom VTA Operator (Python)#
from testing import viz_expr  # visualize Relay expressions
from d2py.utils.file import mkdir
root_dir = ".temp"
mkdir(f"{root_dir}/logs")
This section uses a custom operator as an example to walk through defining a VTA operator.
import numpy as np
import onnx
import tvm
from tvm import relay
import torch
from torch.nn import functional as F
from torch import nn
from torch.onnx import OperatorExportTypes, utils
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 8, 1, 0, bias=False, groups=1)

    def forward(self, x):
        x = self.conv(x)
        b, c, h, w = x.shape
        x = x.view((b, h, w, c))
        x = F.softmax(x, dim=3)
        x = x.view((b, -1))
        return x
model = M()
model.eval()
shape = 1, 3, 8, 8
input_name = "data"
xx = torch.rand(*shape, dtype=torch.float32, requires_grad=False)
# model = torch.jit.trace(model, xx)
# Export the model
output_name = "test"
utils.export(
    model,                             # the torch model
    xx,                                # model input (use a tuple for multiple inputs)
    f"{root_dir}/{output_name}.onnx",  # where to save the model (a file or file-like object)
    export_params=True,                # store the trained parameter weights inside the model file
    opset_version=17,                  # the ONNX opset version to export to
    do_constant_folding=True,          # whether to apply constant folding for optimization
    input_names=[input_name],          # the model's input names
    output_names=['output'],           # the model's output names
    keep_initializers_as_inputs=True,
    # export_modules_as_functions=True,
    verbose=True,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    # dynamic_axes={'data': {0: 'batch_size'},    # variable-length axes
    #               'output': {0: 'batch_size'}}
)
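Before importing into Relay, it can be worth sanity-checking the exported ONNX file against PyTorch. A minimal sketch, assuming onnxruntime is installed (it is not used elsewhere in this chapter):

import onnxruntime as ort

sess = ort.InferenceSession(f"{root_dir}/{output_name}.onnx", providers=["CPUExecutionProvider"])
ort_out = sess.run(None, {input_name: xx.numpy()})[0]
# Compare against the eager PyTorch result
np.testing.assert_allclose(ort_out, model(xx).detach().numpy(), rtol=1e-5, atol=1e-5)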
onnx_model = onnx.load(f"{root_dir}/{output_name}.onnx")
mod, params = relay.frontend.from_onnx(onnx_model, {input_name: shape}, freeze_params=True)
mod = relay.transform.InferType()(mod)
mod.show()
Exported graph: graph(%data : Float(1, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=0, device=cpu),
%conv.weight : Float(8, 3, 8, 8, strides=[192, 64, 8, 1], requires_grad=1, device=cpu)):
%/conv/Conv_output_0 : Float(1, 8, 1, 1, strides=[8, 1, 1, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[8, 8], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/conv/Conv"](%data, %conv.weight), scope: __main__.M::/torch.nn.modules.conv.Conv2d::conv # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
%/Constant_output_0 : Long(4, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value= 1 1 1 8 [ CPULongType{4} ], onnx_name="/Constant"](), scope: __main__.M:: # /tmp/ipykernel_2716226/852096207.py:14:0
%/Reshape_output_0 : Float(1, 1, 1, 8, strides=[8, 8, 8, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape"](%/conv/Conv_output_0, %/Constant_output_0), scope: __main__.M:: # /tmp/ipykernel_2716226/852096207.py:14:0
%/Softmax_output_0 : Float(1, 1, 1, 8, strides=[8, 8, 8, 1], requires_grad=1, device=cpu) = onnx::Softmax[axis=3, onnx_name="/Softmax"](%/Reshape_output_0), scope: __main__.M:: # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/functional.py:1885:0
%/Constant_1_output_0 : Long(2, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value= 1 -1 [ CPULongType{2} ], onnx_name="/Constant_1"](), scope: __main__.M:: # /tmp/ipykernel_2716226/852096207.py:16:0
%output : Float(1, 8, strides=[8, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape_1"](%/Softmax_output_0, %/Constant_1_output_0), scope: __main__.M:: # /tmp/ipykernel_2716226/852096207.py:16:0
return (%output)
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 8, 8), float32] span=/conv/Conv.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=8, kernel_size=[8, 8]) /* ty=Tensor[(1, 8, 1, 1), float32] span=/conv/Conv:0:0 */;
  %1 = reshape(%0, newshape=[1, 1, 1, 8]) /* ty=Tensor[(1, 1, 1, 8), float32] span=/Reshape:0:0 */;
  %2 = nn.softmax(%1, axis=3) /* ty=Tensor[(1, 1, 1, 8), float32] span=/Softmax:0:0 */;
  reshape(%2, newshape=[1, -1]) /* ty=Tensor[(1, 8), float32] span=/Reshape_1:0:0 */
}
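As a quick check that the Relay import is faithful, the float module can be executed and compared with PyTorch. A sketch, assuming the default llvm target is enabled in this TVM build (params is empty here because freeze_params=True already inlined the weights):

with tvm.transform.PassContext(opt_level=3):
    float_fn = relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm").evaluate()
tvm_float = float_fn(tvm.nd.array(xx.numpy())).numpy()
np.testing.assert_allclose(tvm_float, model(xx).detach().numpy(), rtol=1e-5, atol=1e-5)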
Simplification: rewrite the reshape → softmax → reshape chain into a single custom special.softmax_reshape call:
from tvm.relay.dataflow_pattern import rewrite
from tvm_book.special.rewriter import Reshape4dSoftmaxReshape2dRewrite
expr = mod["main"]
expr = rewrite(Reshape4dSoftmaxReshape2dRewrite(), expr)
run_mod = tvm.IRModule.from_expr(expr)
run_mod = relay.transform.InferType()(run_mod)
run_mod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 8, 8), float32] span=/conv/Conv.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=8, kernel_size=[8, 8]) /* ty=Tensor[(1, 8, 1, 1), float32] span=/conv/Conv:0:0 */;
  special.softmax_reshape(%0, __dict__={"axis"=1, "newshape"=[1, 8]}) /* ty=Tensor[(1, 8), float32] */
}
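Reshape4dSoftmaxReshape2dRewrite ships with tvm_book and emits the custom special.softmax_reshape operator shown above. As a rough, runnable sketch of the same idea that uses only built-in Relay ops (the class name below, and the axis=1 equivalence that holds for the 1x1 spatial case, are our assumptions, not the book's implementation), a dataflow-pattern callback could look like this:

from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard, rewrite

class SoftmaxReshapeSimplify(DFPatternCallback):
    """Hypothetical stand-in: collapse reshape -> softmax -> reshape into softmax(axis=1) -> reshape."""
    def __init__(self):
        super().__init__()
        self.x = wildcard()
        r4d = is_op("reshape")(self.x)        # (1, 8, 1, 1) -> (1, 1, 1, 8)
        sm = is_op("nn.softmax")(r4d)         # softmax over the last axis
        self.pattern = is_op("reshape")(sm)   # (1, 1, 1, 8) -> (1, 8)

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        # With 1x1 spatial dims, softmax over the NCHW channel axis is equivalent
        # to the reshape/softmax/reshape chain, so fold it into two ops.
        return relay.reshape(relay.nn.softmax(x, axis=1), newshape=[1, -1])

# simplified = rewrite(SoftmaxReshapeSimplify(), mod["main"])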
Quantization: quantize the simplified module with TVM's relay.quantize flow, using a small random calibration dataset:
from dataclasses import dataclass

@dataclass
class Dataset:
    input_name: str
    shape: tuple

    def __iter__(self):
        for _ in range(2):
            yield {self.input_name: np.random.normal(0, 1, size=self.shape).astype("float32")}

dataset = Dataset(input_name, shape)
with relay.quantize.qconfig(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
):
    qmod = relay.quantize.quantize(run_mod, params, dataset)
WARNING:autotvm:One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
qmod.show()
def @main(%data: Tensor[(1, 3, 8, 8), float32] /* ty=Tensor[(1, 3, 8, 8), float32] span=/conv/Conv.data:0:0 */) -> Tensor[(1, 8), float32] {
  %0 = multiply(%data, 52.3141f /* ty=float32 */) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 8, 8), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 8, 8), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(8, 3, 8, 8), int8] */, padding=[0, 0, 0, 0], channels=8, kernel_size=[8, 8], out_dtype="int32") /* ty=Tensor[(1, 8, 1, 1), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 8, 1, 1), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1511924864, shift=-9) /* ty=Tensor[(1, 8, 1, 1), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 8, 1, 1), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 8, 1, 1), int32] */;
  %9 = cast(%8, dtype="int8") /* ty=Tensor[(1, 8, 1, 1), int8] */;
  %10 = annotation.stop_fusion(%9) /* ty=Tensor[(1, 8, 1, 1), int8] */;
  %11 = cast(%10, dtype="float32") /* ty=Tensor[(1, 8, 1, 1), float32] */;
  %12 = multiply(%11, 0.00783597f /* ty=float32 */) /* ty=Tensor[(1, 8, 1, 1), float32] */;
  special.softmax_reshape(%12, __dict__={"axis"=1, "newshape"=[1, 8]}) /* ty=Tensor[(1, 8), float32] */
}
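In the quantized graph the input is scaled by 52.3141, rounded, clipped to [-127, 127] and cast to int8; the convolution runs in int8 with an int32 accumulator; the accumulator is requantized by fixed_point_multiply and finally dequantized (multiply by 0.00783597) before the float special.softmax_reshape. fixed_point_multiply(x, multiplier=m, shift=s) approximately computes x * m * 2**(s - 31), i.e. the multiplier carries 31 fractional bits. A quick numpy check of that reading (approx_fixed_point_multiply is an illustrative helper, not a TVM API):

def approx_fixed_point_multiply(x, multiplier, shift):
    # Float approximation of Relay's fixed_point_multiply: Q0.31 multiplier
    # plus an extra power-of-two shift.
    return np.round(x * multiplier * 2.0 ** (shift - 31))

acc = np.array([1000, -2500, 40000], dtype=np.int64)   # example conv accumulator values
print(approx_fixed_point_multiply(acc, multiplier=1511924864, shift=-9))
# 1511924864 / 2**31 * 2**-9 ≈ 0.001375 is the requantization scale folded into the graph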
Rewrite the preprocessing operator#
from tvm.relay.dataflow_pattern import rewrite
from tvm_book.special.rewriter._preprocess import PreprocessRewrite
expr = qmod["main"]
expr = rewrite(PreprocessRewrite(a_min=-127.0, a_max=127.0, dtype="int8"), expr)
run_qmod = tvm.IRModule.from_expr(expr)
run_qmod = relay.transform.InferType()(run_qmod)
run_qmod.show()
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
Cell In[7], line 2
1 from tvm.relay.dataflow_pattern import rewrite
----> 2 from tvm_book.special.rewriter._preprocess import PreprocessRewrite
4 expr = qmod["main"]
5 expr = rewrite(PreprocessRewrite(a_min=-127.0, a_max=127.0, dtype="int8"), expr)
File /media/pc/data/lxw/ai/tvm-book/src/tvm_book/special/rewriter/_preprocess.py:8
1 from tvm import relay
2 from tvm.relay.dataflow_pattern import (
3 wildcard, is_op,
4 # FunctionPattern,
5 DFPatternCallback,
6 # rewrite
7 )
----> 8 from .op import special_preprocess
10 class PreprocessRewrite(DFPatternCallback):
11 def __init__(self, a_min=-127.0, a_max=127.0, dtype="int8"):
ModuleNotFoundError: No module named 'tvm_book.special.rewriter.op'
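The tvm_book.special.rewriter.op module is missing in this environment, so PreprocessRewrite (which, judging from the traceback, folds the multiply → round → clip → cast input-quantization chain into a special_preprocess op) cannot be imported. As an illustration of the pattern-matching half of that idea only, a hypothetical callback that merely locates the chain and records the input scale could look like the sketch below (ExtractInputScale and its behavior are our assumptions, not part of tvm_book):

from tvm.relay.dataflow_pattern import DFPatternCallback, is_constant, is_op, rewrite, wildcard

class ExtractInputScale(DFPatternCallback):
    """Hypothetical inspector: find multiply -> round -> clip -> cast and record the scale."""
    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.scale = is_constant()
        chain = is_op("clip")(is_op("round")(is_op("multiply")(self.x, self.scale)))
        self.pattern = is_op("cast")(chain)
        self.scales = []

    def callback(self, pre, post, node_map):
        self.scales.append(float(node_map[self.scale][0].data.numpy()))
        return post   # inspection only: leave the matched subgraph unchanged

extractor = ExtractInputScale()
rewrite(extractor, qmod["main"])
extractor.scales   # e.g. [52.3141...]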
from tvm.contrib import graph_executor

data = np.random.normal(0, 1, size=shape).astype("float32")
torch_out = model(torch.from_numpy(data)).detach().numpy()
target = 'llvm'
dev = tvm.device(target, 0)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(qmod, target, params=params)
func = lib[lib.libmod_name]
module = graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
output1 = module.get_output(0).numpy()
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(run_qmod, target, params=params)
func = lib[lib.libmod_name]
module = graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
output2 = module.get_output(0).numpy()
np.concatenate([torch_out, output1, output2]).T
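The transposed concatenation puts the PyTorch reference and the two TVM outputs side by side for visual comparison. A possible follow-up (the metrics below are our choice, not part of the book's flow) that quantifies the int8 error against the float reference:

abs_err = np.abs(torch_out - output1)
cos = float((torch_out * output1).sum() /
            (np.linalg.norm(torch_out) * np.linalg.norm(output1)))
print(f"max abs err = {abs_err.max():.6f}, mean abs err = {abs_err.mean():.6f}, cosine = {cos:.6f}")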