# ONNX Conv + Reshape
# Notebook setup (IPython/Jupyter only: `%cd` is an IPython magic, not Python).
# Change the working directory first; presumably this is where `set_env`
# lives / expects to run from — verify.
%cd ../../..
# Imported only for its side effects (presumably sys.path / environment
# configuration — TODO confirm in set_env.py); likely must stay after `%cd`.
import set_env
from d2py.utils.file import mkdir
# Directory (relative to the new cwd) where the exported ONNX model is written.
temp_dir = ".temp"
# Create it if needed; exact behavior when it already exists is defined by
# d2py's `mkdir` helper — assumed to be a no-op, verify.
mkdir(temp_dir)
/media/pc/data/lxw/ai/tvm-book/doc/tutorials/frontend
import torch
from torch import nn
class Model(nn.Module):
    """Tiny test network: a 1x1 convolution followed by a flattening reshape.

    The reshape keeps the batch dimension and collapses everything else, so an
    input of shape (N, 3, H, W) produces an output of shape (N, 16 * H * W).
    """

    def __init__(self):
        super().__init__()
        # 3 -> 16 channels, 1x1 kernel, stride 1, no padding, no bias.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            groups=1,
        )

    def forward(self, x):
        """Apply ``conv1`` and flatten all non-batch dimensions."""
        features = self.conv1(x)
        batch = features.size(0)
        # Equivalent alternative: features.view(batch, -1)
        return torch.reshape(features, (batch, -1))
# Example input used for tracing: one 3x32x32 sample (axis 0 is exported as a
# symbolic batch dimension via `dynamic_axes` below).
shape = 1, 3, 32, 32
x = torch.rand(shape)
torch_model = Model()

# Export the traced model to ONNX.
output_name = "conv-reshape"
onnx_path = f"{temp_dir}/{output_name}.onnx"
torch.onnx.export(
    torch_model,               # model to export
    x,                         # example input (pass a tuple for multiple inputs)
    onnx_path,                 # destination: a path or a file-like object
    export_params=True,        # store the trained weights inside the model file
    opset_version=17,          # target ONNX opset version
    do_constant_folding=True,  # fold constant subexpressions as an optimization
    verbose=True,              # print the exported graph
    input_names=["data"],      # name(s) of the graph inputs
    output_names=["output"],   # name(s) of the graph outputs
    # export_modules_as_functions=True,
    dynamic_axes={             # variable-length axes: axis 0 is the batch
        "data": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)
Exported graph: graph(%data : Float(*, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu),
%conv1.weight : Float(16, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=1, device=cpu)):
%/conv1/Conv_output_0 : Float(*, 16, 32, 32, strides=[16384, 1024, 32, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/conv1/Conv"](%data, %conv1.weight), scope: __main__.Model::/torch.nn.modules.conv.Conv2d::conv1 # /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
%/Shape_output_0 : Long(4, strides=[1], device=cpu) = onnx::Shape[onnx_name="/Shape"](%/conv1/Conv_output_0), scope: __main__.Model:: # /tmp/ipykernel_1214832/3273040045.py:12:0
%/Constant_output_0 : Long(device=cpu) = onnx::Constant[value={0}, onnx_name="/Constant"](), scope: __main__.Model:: # /tmp/ipykernel_1214832/3273040045.py:12:0
%/Gather_output_0 : Long(device=cpu) = onnx::Gather[axis=0, onnx_name="/Gather"](%/Shape_output_0, %/Constant_output_0), scope: __main__.Model:: # /tmp/ipykernel_1214832/3273040045.py:12:0
%onnx::Unsqueeze_7 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={0}]()
%/Unsqueeze_output_0 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[onnx_name="/Unsqueeze"](%/Gather_output_0, %onnx::Unsqueeze_7), scope: __main__.Model::
%/Constant_1_output_0 : Long(1, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value={-1}, onnx_name="/Constant_1"](), scope: __main__.Model::
%/Concat_output_0 : Long(2, strides=[1], device=cpu) = onnx::Concat[axis=0, onnx_name="/Concat"](%/Unsqueeze_output_0, %/Constant_1_output_0), scope: __main__.Model:: # /tmp/ipykernel_1214832/3273040045.py:12:0
%output : Float(*, *, strides=[16384, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape"](%/conv1/Conv_output_0, %/Concat_output_0), scope: __main__.Model:: # /tmp/ipykernel_1214832/3273040045.py:12:0
return (%output)
# Optional: run ONNX shape inference on the exported model and save the
# annotated copy (uncomment to use). Uses the same path the export wrote to.
# from onnx import load_model, save_model
# from onnx.shape_inference import infer_shapes
# onnx_model = load_model(f"{temp_dir}/{output_name}.onnx")
# onnx_model = infer_shapes(onnx_model)
# save_model(onnx_model, f"{temp_dir}/{output_name}_inferred.onnx")
import onnx
import tvm
from tvm import relay

# Load the exported model and convert it to a Relay module.
# The shape dict pins input "data" to `shape`, so the symbolic batch axis from
# the export becomes a static batch of 1 in Relay. freeze_params=True embeds
# the weights as graph constants (meta[relay.Constant]) instead of returning
# them as separate parameters.
onnx_model = onnx.load(f"{temp_dir}/{output_name}.onnx")
mod, params = relay.frontend.from_onnx(onnx_model, {"data": shape}, freeze_params=True)
# Optional pre-quantization simplification (uncomment to use):
# with tvm.transform.PassContext(opt_level=3):
#     mod = relay.quantize.prerequisite_optimize(mod, params)
mod.show()
def @main(%data: Tensor[(1, 3, 32, 32), float32] /* ty=Tensor[(1, 3, 32, 32), float32] span=/conv1/Conv.data:0:0 */) -> Tensor[(1, 16384), float32] {
%0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 1, 1), float32] span=/conv1/Conv.conv1.weight:0:0 */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* ty=Tensor[(1, 16, 32, 32), float32] span=/conv1/Conv:0:0 */;
reshape(%0, newshape=[1, -1]) /* ty=Tensor[(1, 16384), float32] span=/Reshape:0:0 */
}
# Quantize the Relay module with TVM's automatic quantization flow, entering
# the pass context first and the quantization config second.
#   skip_conv_layers=[]    -> skip no conv layers, so the conv here is quantized
#                             (TVM's default skips the first conv — verify)
#   weight_scale="max"     -> weight scale from the max weight magnitude
#                             (per TVM qconfig docs — verify)
#   skip_dense_layer=False -> also quantize dense layers (none in this model)
with tvm.transform.PassContext(opt_level=3), relay.quantize.qconfig(
    skip_conv_layers=[],
    weight_scale="max",
    round_for_shift=True,
    skip_dense_layer=False,
    # calibrate_mode="kl_divergence",
    # rounding="TONEAREST",  # "UPWARD" or "TONEAREST"
    # calibrate_skip_layers=[],
):
    qmod = relay.quantize.quantize(mod, params)
qmod.show()
def @main(%data: Tensor[(1, 3, 32, 32), float32] /* ty=Tensor[(1, 3, 32, 32), float32] span=/conv1/Conv.data:0:0 */) -> Tensor[(1, 16384), float32] {
%0 = multiply(%data, 16f /* ty=float32 */) /* ty=Tensor[(1, 3, 32, 32), float32] */;
%1 = round(%0) /* ty=Tensor[(1, 3, 32, 32), float32] */;
%2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 32, 32), float32] */;
%3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 32, 32), int8] */;
%4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 16, 32, 32), int32] */;
%5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 16, 32, 32), int64] */;
%6 = fixed_point_multiply(%5, multiplier=1230307840, shift=-7) /* ty=Tensor[(1, 16, 32, 32), int64] */;
%7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 32, 32), int64] */;
%8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 16, 32, 32), int32] */;
%9 = cast(%8, dtype="int8") /* ty=Tensor[(1, 16, 32, 32), int8] */;
%10 = annotation.stop_fusion(%9) /* ty=Tensor[(1, 16, 32, 32), int8] */;
%11 = reshape(%10, newshape=[1, -1]) /* ty=Tensor[(1, 16384), int8] */;
%12 = cast(%11, dtype="float32") /* ty=Tensor[(1, 16384), float32] */;
multiply(%12, 0.0625f /* ty=float32 */) /* ty=Tensor[(1, 16384), float32] */
}
# from onnxscript import optimizer
# model = optimizer.optimize(onnx_model, onnx_shape_inference=False)
# onnx.save(
# model,
# f"{temp_dir}/{output_name}_opt.onnx",
# save_as_external_data=True,
# all_tensors_to_one_file=True,
# convert_attribute=True
# )