onnx flatten

目录

onnx flatten

%cd ../../..
import set_env
from d2py.utils.file import mkdir
# Scratch directory; the exported ONNX file is written here later in the file.
temp_dir = ".temp"
mkdir(temp_dir)  # d2py helper; presumably creates the directory if missing — confirm
/media/pc/data/lxw/ai/tvm-book/doc/tutorials/frontend
import torch
from torch import nn

class Model(nn.Module):
    """Minimal conv -> flatten -> linear network used to exercise ONNX ``Flatten``.

    The linear layer expects 16384 input features, i.e. the 16-channel
    convolution output of a 32x32 input flattened from dimension 1 onward.
    """

    def __init__(self):
        super().__init__()
        # 1x1 convolution without bias: 3 input channels -> 16 output channels.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            bias=False,
        )
        self.dense = nn.Linear(in_features=16384, out_features=1000)

    def forward(self, x):
        """Apply conv1, flatten everything after the batch dimension, then dense."""
        features = self.conv1(x)
        # Equivalent to features.flatten(1).
        flat = torch.flatten(features, 1)
        return self.dense(flat)

# Example input: a single 3-channel 32x32 tensor, matching Conv2d(3, ...) and
# the 16384 features expected by the linear layer.
shape = (1, 3, 32, 32)
x = torch.rand(shape)

torch_model = Model()
# Export the model to ONNX.
output_name = "flatten"
torch.onnx.export(
    torch_model,                       # the torch model to export
    x,                                 # model input (use a tuple for multiple inputs)
    f"{temp_dir}/{output_name}.onnx",  # destination (a path or a file-like object)
    input_names=['data'],              # names of the model inputs
    output_names=['output'],           # names of the model outputs
    opset_version=17,                  # ONNX opset version to export with
    export_params=True,                # store the parameter weights inside the model file
    do_constant_folding=True,          # run constant folding as an optimization
    verbose=True,                      # print the exported graph
    # Variable-length axes (disabled):
    # dynamic_axes={'data' : {0 : 'batch_size'},
    #               'output' : {0 : 'batch_size'}}
)
Exported graph: graph(%data : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu),
      %conv1.weight : Float(16, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=1, device=cpu),
      %dense.weight : Float(1000, 16384, strides=[16384, 1], requires_grad=1, device=cpu),
      %dense.bias : Float(1000, strides=[1], requires_grad=1, device=cpu)):
  %/conv1/Conv_output_0 : Float(1, 16, 32, 32, strides=[16384, 1024, 32, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/conv1/Conv"](%data, %conv1.weight), scope: __main__.Model::/torch.nn.modules.conv.Conv2d::conv1 # /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
  %/Flatten_output_0 : Float(1, 16384, strides=[16384, 1], requires_grad=1, device=cpu) = onnx::Flatten[axis=1, onnx_name="/Flatten"](%/conv1/Conv_output_0), scope: __main__.Model:: # /tmp/ipykernel_2988201/3855814829.py:11:0
  %output : Float(1, 1000, strides=[1000, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/dense/Gemm"](%/Flatten_output_0, %dense.weight, %dense.bias), scope: __main__.Model::/torch.nn.modules.linear.Linear::dense # /media/pc/data/tmp/cache/conda/envs/xin/lib/python3.12/site-packages/torch/nn/modules/linear.py:116:0
  return (%output)

模型结构

import onnx
import tvm
from tvm import relay
# Load the ONNX file exported earlier and import it into Relay.
onnx_model = onnx.load(f"{temp_dir}/{output_name}.onnx")
# The shape dict is keyed by the "data" input name declared at export time.
# With freeze_params=True the weights appear as relay Constants in the
# printed module rather than as function parameters.
mod, params = relay.frontend.from_onnx(onnx_model, {"data": shape}, freeze_params=True)
# with tvm.transform.PassContext(opt_level=3):
#     mod = relay.quantize.prerequisite_optimize(mod, params)
mod.show()
def @main(%data: Tensor[(1, 3, 32, 32), float32] /* ty=Tensor[(1, 3, 32, 32), float32] span=/conv1/Conv.data:0:0 */) -> Tensor[(1, 1000), float32] {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 1, 1), float32] span=/conv1/Conv.conv1.weight:0:0 */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* ty=Tensor[(1, 16, 32, 32), float32] span=/conv1/Conv:0:0 */;
  %1 = nn.batch_flatten(%0) /* ty=Tensor[(1, 16384), float32] span=/Flatten:0:0 */;
  %2 = nn.dense(%1, meta[relay.Constant][1] /* ty=Tensor[(1000, 16384), float32] span=/dense/Gemm.dense.weight:0:0 */, units=1000) /* ty=Tensor[(1, 1000), float32] span=/dense/Gemm:0:0 */;
  add(%2, meta[relay.Constant][2] /* ty=Tensor[(1000), float32] span=/dense/Gemm.dense.bias:0:0 */) /* ty=Tensor[(1, 1000), float32] span=/dense/Gemm:0:0 */
}
# Quantize the float Relay module; the result is kept in `qmod` so that the
# original `mod` stays available.
with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(
        skip_conv_layers=[],  # empty list: no conv2d layer is excluded from quantization
        # calibrate_mode="kl_divergence", 
        weight_scale="max",  # weight scale mode; see the TVM qconfig docs for the exact semantics
        # round_for_shift=True,
        # rounding="TONEAREST", # "UPWARD" or "TONEAREST"
        # calibrate_skip_layers=[],
        skip_dense_layer=False,  # quantize nn.dense as well instead of skipping it
    ):
        qmod = relay.quantize.quantize(mod, params)
qmod.show()
def @main(%data: Tensor[(1, 3, 32, 32), float32] /* ty=Tensor[(1, 3, 32, 32), float32] span=/conv1/Conv.data:0:0 */) -> Tensor[(1, 1000), float32] {
  %0 = multiply(%data, 16f /* ty=float32 */) /* ty=Tensor[(1, 3, 32, 32), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 32, 32), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 32, 32), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 32, 32), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 16, 32, 32), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 16, 32, 32), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1203700736, shift=-7) /* ty=Tensor[(1, 16, 32, 32), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 32, 32), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 16, 32, 32), int32] */;
  %9 = cast(%8, dtype="int8") /* ty=Tensor[(1, 16, 32, 32), int8] */;
  %10 = annotation.stop_fusion(%9) /* ty=Tensor[(1, 16, 32, 32), int8] */;
  %11 = nn.batch_flatten(%10) /* ty=Tensor[(1, 16384), int8] */;
  %12 = nn.dense(%11, meta[relay.Constant][1] /* ty=Tensor[(1000, 16384), int8] */, units=1000, out_dtype="int32") /* ty=Tensor[(1, 1000), int32] */;
  %13 = add(%12, meta[relay.Constant][2] /* ty=Tensor[(1000), int32] */) /* ty=Tensor[(1, 1000), int32] */;
  %14 = cast(%13, dtype="float32") /* ty=Tensor[(1, 1000), float32] */;
  multiply(%14, 3.8147e-06f /* ty=float32 */) /* ty=Tensor[(1, 1000), float32] */
}
from tvm.relay.testing import run_infer_type
# Run type inference on the quantized main function so that checked_type is
# populated, then read the shape of the function body's result.
func = run_infer_type(qmod["main"])
func.body.checked_type.shape
[1, 1000]
from tvm.relay import Call
from tvm.relay.op import op as _op
class FlattenMutator(tvm.relay.ExprMutator):
    """Expression mutator that prints the op of every ``nn.batch_flatten`` call.

    Each visited call is rebuilt from its visited operator and arguments;
    apart from the print, nothing about the call is changed.
    """

    def __init__(self):
        super().__init__()
        # Op handles looked up once; only batch_flatten is used by visit_call.
        self.batch_flatten = _op.get("nn.batch_flatten")
        self.bitpack_end = _op.get("annotation.bitpack_end")

    def visit_call(self, call):
        """Rebuild ``call`` from its visited op and args, reporting batch_flatten."""
        op = self.visit(call.op)
        args = list(map(self.visit, call.args))
        if op == self.batch_flatten:
            print(op)
        return Call(op, args, call.attrs, call.type_args, call.span)
# Walk the quantized main function with the mutator. The rebuilt expression
# returned by visit() is not assigned to anything here.
func = qmod["main"]
transform = FlattenMutator()
transform.visit(func)
# print(func)
1
Op(multiply)
Op(round)
Op(clip)
Op(cast)
Op(nn.conv2d)
Op(cast)
Op(fixed_point_multiply)
Op(clip)
Op(cast)
Op(cast)
Op(annotation.stop_fusion)
Op(nn.batch_flatten)
Op(nn.dense)
Op(add)
Op(cast)
Op(multiply)
1
from tvm.relay.dataflow_pattern import (
    is_constant, is_op, is_tuple, wildcard, 
    is_tuple_get_item
)
def make_flatten_dense_pattern():
    """Build the dataflow pattern for ``nn.batch_flatten`` feeding ``nn.dense``.

    Returns:
        The result of ``make_activate`` applied to the dense-call pattern.
    """
    # The flatten pattern must be the data argument of nn.dense so the two
    # ops are matched as a chain; the dense weight may be any expression.
    flatten = is_op("nn.batch_flatten")(wildcard())
    dense = is_op("nn.dense")(flatten, wildcard())
    # Activation function.
    # NOTE(review): make_activate is not defined in the visible part of this
    # file; presumably it extends the pattern with an activation. Confirm it
    # is in scope before this function is called.
    return make_activate(dense)
# Prefix used to build the composite pattern names in the table below.
compiler_name = "ccompiler"
# Each entry pairs a composite name with the pattern to merge under that name.
pattern_table = [
    (f"{compiler_name}.flatten_dense", make_flatten_dense_pattern()),
]
# Pass pipeline; type inference is re-run after each rewriting pass.
merge_passes = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.MergeComposite(pattern_table),
    relay.transform.InferType(),
    FuseTransform(),  # NOTE(review): not defined in the visible part of this file — confirm it is in scope
    relay.transform.InferType(),
])
# Apply the pipeline to `mod` (the module imported from ONNX, not `qmod`)
# and rebind `mod` to the result.
with tvm.transform.PassContext(opt_level=3):
    mod = merge_passes(mod)