三种版本的 DFL

三种版本的 DFL

# Project-local module imported only for its side effects (presumably it
# configures sys.path / environment for this notebook -- TODO confirm).
import set_env
from d2py.utils.file import mkdir
# Scratch directory for generated artifacts; the exported .onnx file is written under it.
root_dir = ".temp"
mkdir(f"{root_dir}/logs")  # create the logs sub-directory (presumably including parents -- verify)
import torch
from torch.nn import functional as F
from torch import nn
from torch.onnx import OperatorExportTypes, utils


class DFLV1(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL), variant 1.

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391

    The input ``x`` of shape ``(batch, 4 * c1, anchors)`` is read as
    ``(batch, 4, c1, anchors)``: for each of the 4 box sides, ``c1`` logits of a
    discrete distribution over the bins ``0 .. c1-1``. The output is the expected
    bin value per side, shape ``(batch, 4, anchors)``.

    This variant gets there through a longer chain of transposes, with the
    softmax on the last axis. The Relay rewrite ``DFLV1Rewrite`` in this notebook
    pattern-matches exactly this op sequence, so keep it unchanged.
    """

    def __init__(self, c1=16):
        """Build a frozen 1x1 conv whose weights are the bin indices ``0 .. c1-1``.

        Args:
            c1: number of distribution bins per box side.
        """
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        bins = torch.arange(c1, dtype=torch.float)
        # Copy the values in place; wrapping them in nn.Parameter first is unnecessary.
        self.conv.weight.data.copy_(bins.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Return the expected bin value per box side, shape ``(b, 4, a)``."""
        b, _, a = x.shape  # batch, channels (= 4 * c1), anchors
        x = x.view(b, 4, self.c1, a)            # (b, 4, c1, a)
        x = x.transpose(3, 1).transpose(2, 3)   # (b, a, 4, c1)
        x = x.softmax(3)                        # distribution over the c1 bins
        x = x.transpose(3, 1)                   # (b, c1, 4, a): bins become conv channels
        x = self.conv(x)                        # sum_j j * p_j -> (b, 1, 4, a)
        return x.view(b, 4, a)
    
class DFLV2(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL), variant 2.

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391

    The input ``x`` of shape ``(batch, 4 * c1, anchors)`` is read as
    ``(batch, 4, c1, anchors)``. A single transpose moves the ``c1`` bins to
    axis 1, the softmax runs over them, and the frozen 1x1 conv with weights
    ``0 .. c1-1`` yields the expected bin value per side. The output has shape
    ``(batch, 4, anchors)``; numerically it matches DFLV1.
    """

    def __init__(self, c1=16):
        """Build a frozen 1x1 conv whose weights are the bin indices ``0 .. c1-1``.

        Args:
            c1: number of distribution bins per box side.
        """
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        bins = torch.arange(c1, dtype=torch.float)
        # Copy the values in place; wrapping them in nn.Parameter first is unnecessary.
        self.conv.weight.data.copy_(bins.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Return the expected bin value per box side, shape ``(b, 4, a)``."""
        b, _, a = x.shape  # batch, channels (= 4 * c1), anchors
        # (b, 4, c1, a) -> (b, c1, 4, a) -> softmax over bins -> 1x1 conv -> (b, 1, 4, a)
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)

class DFLV3(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL), variant 3.

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391

    Unlike DFLV1/DFLV2, the input ``x`` of shape ``(batch, 4 * c1, anchors)`` is
    read as ``(batch, c1, 4, anchors)``, i.e. the channels are grouped bin-major.
    No transpose is needed: the softmax runs over axis 1 and the frozen 1x1 conv
    with weights ``0 .. c1-1`` gives the expected bin value, shape
    ``(batch, 4, anchors)``. For the same input tensor this is NOT numerically
    equal to DFLV1/DFLV2, because the channel grouping differs.
    """

    def __init__(self, c1=16):
        """Build a frozen 1x1 conv whose weights are the bin indices ``0 .. c1-1``.

        Args:
            c1: number of distribution bins per box side.
        """
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        bins = torch.arange(c1, dtype=torch.float)
        # Copy the values in place; wrapping them in nn.Parameter first is unnecessary.
        self.conv.weight.data.copy_(bins.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Return the expected bin value per box side, shape ``(b, 4, a)``."""
        b, _, a = x.shape  # batch, channels (= c1 * 4), anchors
        return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)

class DFL(nn.Module):
    """Toy network feeding three conv heads into DFLV1, DFLV2 and DFLV3.

    Used to export all three DFL variants in a single ONNX graph. ``conv1`` and
    ``conv2`` produce exactly ``4 * c1`` channels. ``conv3`` deliberately
    produces 100 channels, which ``forward`` regroups into ``4 * c1`` channels
    with a reshape. DFLV3 therefore sees a different anchor count (6000 instead
    of 3840 for a 48x80 input with ``c1=16``). This requires ``100 * H * W`` to
    be divisible by ``4 * c1``.

    Args:
        c1: number of distribution bins per box side (default 16).
    """

    def __init__(self, c1=16):
        super().__init__()
        # Every DFL head expects 4 sides x c1 bins channels. Deriving this from
        # c1 (instead of hard-coding 64) keeps the c1 argument usable.
        self.dfl_channels = 4 * c1
        self.conv0 = nn.Conv2d(3, 16, 1, bias=False)
        self.conv1 = nn.Conv2d(16, self.dfl_channels, 1, bias=False)
        self.conv2 = nn.Conv2d(16, self.dfl_channels, 1, bias=False)
        self.conv3 = nn.Conv2d(16, 100, 1, bias=False)
        self.v1 = DFLV1(c1)
        self.v2 = DFLV2(c1)
        self.v3 = DFLV3(c1)

    def forward(self, x):
        """Return the outputs of DFLV1, DFLV2 and DFLV3, each ``(batch, 4, anchors)``."""
        x = self.conv0(x)
        # Flatten per sample using the real batch size. A hard-coded 1 silently
        # merged the samples of a batch > 1 into one row.
        b = x.shape[0]
        x1 = self.conv1(x).view(b, self.dfl_channels, -1)
        x2 = self.conv2(x).view(b, self.dfl_channels, -1)
        x3 = self.conv3(x).view(b, self.dfl_channels, -1)
        x1 = self.v1(x1)
        x2 = self.v2(x2)
        x3 = self.v3(x3)
        return x1, x2, x3
model = DFL().eval()

shape = 1, 3, 48, 80
xx = torch.rand(*shape, dtype=torch.float32, requires_grad=False)
# model = torch.jit.trace(model, xx)
# Export the model to ONNX.
input_name = "data"
output_name = "dfl"  # base name of the exported .onnx file
utils.export(
    model,                             # torch model
    xx,                                # example model input (use a tuple for multiple inputs)
    f"{root_dir}/{output_name}.onnx",  # where to save the model (a file or file-like object)
    export_params=True,                # store the trained parameter weights inside the model file
    opset_version=17,                  # ONNX opset version to export
    do_constant_folding=True,          # run constant folding as an optimization
    input_names=[input_name],          # model input name(s)
    # DFL returns three tensors: name every output, not only the first
    # (otherwise the remaining ones get anonymous names).
    output_names=["output", "output_1", "output_2"],
    keep_initializers_as_inputs=True,
    # export_modules_as_functions=True,
    verbose=True,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    # dynamic_axes={'data' : {0 : 'batch_size'},    # variable-length axes
    #               'output' : {0 : 'batch_size'}}
)
Exported graph: graph(%data : Float(1, 3, 48, 80, strides=[11520, 3840, 80, 1], requires_grad=0, device=cpu),
      %conv0.weight : Float(16, 3, 1, 1, strides=[3, 1, 1, 1], requires_grad=1, device=cpu),
      %conv1.weight : Float(64, 16, 1, 1, strides=[16, 1, 1, 1], requires_grad=1, device=cpu),
      %conv2.weight : Float(64, 16, 1, 1, strides=[16, 1, 1, 1], requires_grad=1, device=cpu),
      %conv3.weight : Float(100, 16, 1, 1, strides=[16, 1, 1, 1], requires_grad=1, device=cpu),
      %v1.conv.weight : Float(1, 16, 1, 1, strides=[16, 1, 1, 1], requires_grad=0, device=cpu),
      %v2.conv.weight : Float(1, 16, 1, 1, strides=[16, 1, 1, 1], requires_grad=0, device=cpu),
      %v3.conv.weight : Float(1, 16, 1, 1, strides=[16, 1, 1, 1], requires_grad=0, device=cpu)):
  %/conv0/Conv_output_0 : Float(1, 16, 48, 80, strides=[61440, 3840, 80, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/conv0/Conv"](%data, %conv0.weight), scope: __main__.DFL::/torch.nn.modules.conv.Conv2d::conv0 # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
  %/conv1/Conv_output_0 : Float(1, 64, 48, 80, strides=[245760, 3840, 80, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/conv1/Conv"](%/conv0/Conv_output_0, %conv1.weight), scope: __main__.DFL::/torch.nn.modules.conv.Conv2d::conv1 # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
  %/Constant_output_0 : Long(3, strides=[1], device=cpu) = onnx::Constant[value=  1  64  -1 [ CPULongType{3} ], onnx_name="/Constant"](), scope: __main__.DFL:: # /tmp/ipykernel_3296793/673152664.py:82:0
  %/Reshape_output_0 : Float(1, 64, 3840, strides=[245760, 3840, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape"](%/conv1/Conv_output_0, %/Constant_output_0), scope: __main__.DFL:: # /tmp/ipykernel_3296793/673152664.py:82:0
  %/conv2/Conv_output_0 : Float(1, 64, 48, 80, strides=[245760, 3840, 80, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/conv2/Conv"](%/conv0/Conv_output_0, %conv2.weight), scope: __main__.DFL::/torch.nn.modules.conv.Conv2d::conv2 # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
  %/Constant_1_output_0 : Long(3, strides=[1], device=cpu) = onnx::Constant[value=  1  64  -1 [ CPULongType{3} ], onnx_name="/Constant_1"](), scope: __main__.DFL:: # /tmp/ipykernel_3296793/673152664.py:83:0
  %/Reshape_1_output_0 : Float(1, 64, 3840, strides=[245760, 3840, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape_1"](%/conv2/Conv_output_0, %/Constant_1_output_0), scope: __main__.DFL:: # /tmp/ipykernel_3296793/673152664.py:83:0
  %/conv3/Conv_output_0 : Float(1, 100, 48, 80, strides=[384000, 3840, 80, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/conv3/Conv"](%/conv0/Conv_output_0, %conv3.weight), scope: __main__.DFL::/torch.nn.modules.conv.Conv2d::conv3 # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
  %/Constant_2_output_0 : Long(3, strides=[1], device=cpu) = onnx::Constant[value=  1  64  -1 [ CPULongType{3} ], onnx_name="/Constant_2"](), scope: __main__.DFL:: # /tmp/ipykernel_3296793/673152664.py:84:0
  %/Reshape_2_output_0 : Float(1, 64, 6000, strides=[384000, 6000, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape_2"](%/conv3/Conv_output_0, %/Constant_2_output_0), scope: __main__.DFL:: # /tmp/ipykernel_3296793/673152664.py:84:0
  %/v1/Constant_output_0 : Long(4, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value=    1     4    16  3840 [ CPULongType{4} ], onnx_name="/v1/Constant"](), scope: __main__.DFL::/__main__.DFLV1::v1 # /tmp/ipykernel_3296793/673152664.py:24:0
  %/v1/Reshape_output_0 : Float(1, 4, 16, 3840, strides=[245760, 61440, 3840, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/v1/Reshape"](%/Reshape_output_0, %/v1/Constant_output_0), scope: __main__.DFL::/__main__.DFLV1::v1 # /tmp/ipykernel_3296793/673152664.py:24:0
  %/v1/Transpose_output_0 : Float(1, 3840, 4, 16, strides=[245760, 1, 61440, 3840], requires_grad=1, device=cpu) = onnx::Transpose[perm=[0, 3, 1, 2], onnx_name="/v1/Transpose"](%/v1/Reshape_output_0), scope: __main__.DFL::/__main__.DFLV1::v1 # /tmp/ipykernel_3296793/673152664.py:25:0
  %/v1/Softmax_output_0 : Float(1, 3840, 4, 16, strides=[245760, 64, 16, 1], requires_grad=1, device=cpu) = onnx::Softmax[axis=3, onnx_name="/v1/Softmax"](%/v1/Transpose_output_0), scope: __main__.DFL::/__main__.DFLV1::v1 # /tmp/ipykernel_3296793/673152664.py:26:0
  %/v1/Transpose_1_output_0 : Float(1, 16, 4, 3840, strides=[245760, 1, 16, 64], requires_grad=1, device=cpu) = onnx::Transpose[perm=[0, 3, 2, 1], onnx_name="/v1/Transpose_1"](%/v1/Softmax_output_0), scope: __main__.DFL::/__main__.DFLV1::v1 # /tmp/ipykernel_3296793/673152664.py:27:0
  %/v1/conv/Conv_output_0 : Float(1, 1, 4, 3840, strides=[15360, 15360, 3840, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/v1/conv/Conv"](%/v1/Transpose_1_output_0, %v1.conv.weight), scope: __main__.DFL::/__main__.DFLV1::v1/torch.nn.modules.conv.Conv2d::conv # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
  %/v1/Constant_1_output_0 : Long(3, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value=    1     4  3840 [ CPULongType{3} ], onnx_name="/v1/Constant_1"](), scope: __main__.DFL::/__main__.DFLV1::v1 # /tmp/ipykernel_3296793/673152664.py:29:0
  %output : Float(1, 4, 3840, strides=[15360, 3840, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/v1/Reshape_1"](%/v1/conv/Conv_output_0, %/v1/Constant_1_output_0), scope: __main__.DFL::/__main__.DFLV1::v1 # /tmp/ipykernel_3296793/673152664.py:29:0
  %/v2/Constant_output_0 : Long(4, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value=    1     4    16  3840 [ CPULongType{4} ], onnx_name="/v2/Constant"](), scope: __main__.DFL::/__main__.DFLV2::v2 # /tmp/ipykernel_3296793/673152664.py:48:0
  %/v2/Reshape_output_0 : Float(1, 4, 16, 3840, strides=[245760, 61440, 3840, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/v2/Reshape"](%/Reshape_1_output_0, %/v2/Constant_output_0), scope: __main__.DFL::/__main__.DFLV2::v2 # /tmp/ipykernel_3296793/673152664.py:48:0
  %/v2/Transpose_output_0 : Float(1, 16, 4, 3840, strides=[245760, 3840, 61440, 1], requires_grad=1, device=cpu) = onnx::Transpose[perm=[0, 2, 1, 3], onnx_name="/v2/Transpose"](%/v2/Reshape_output_0), scope: __main__.DFL::/__main__.DFLV2::v2 # /tmp/ipykernel_3296793/673152664.py:48:0
  %/v2/Softmax_output_0 : Float(1, 16, 4, 3840, strides=[245760, 15360, 3840, 1], requires_grad=1, device=cpu) = onnx::Softmax[axis=1, onnx_name="/v2/Softmax"](%/v2/Transpose_output_0), scope: __main__.DFL::/__main__.DFLV2::v2 # /tmp/ipykernel_3296793/673152664.py:48:0
  %/v2/conv/Conv_output_0 : Float(1, 1, 4, 3840, strides=[15360, 15360, 3840, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/v2/conv/Conv"](%/v2/Softmax_output_0, %v2.conv.weight), scope: __main__.DFL::/__main__.DFLV2::v2/torch.nn.modules.conv.Conv2d::conv # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
  %/v2/Constant_1_output_0 : Long(3, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value=    1     4  3840 [ CPULongType{3} ], onnx_name="/v2/Constant_1"](), scope: __main__.DFL::/__main__.DFLV2::v2 # /tmp/ipykernel_3296793/673152664.py:48:0
  %66 : Float(1, 4, 3840, strides=[15360, 3840, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/v2/Reshape_1"](%/v2/conv/Conv_output_0, %/v2/Constant_1_output_0), scope: __main__.DFL::/__main__.DFLV2::v2 # /tmp/ipykernel_3296793/673152664.py:48:0
  %/v3/Constant_output_0 : Long(4, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value=    1    16     4  6000 [ CPULongType{4} ], onnx_name="/v3/Constant"](), scope: __main__.DFL::/__main__.DFLV3::v3 # /tmp/ipykernel_3296793/673152664.py:67:0
  %/v3/Reshape_output_0 : Float(1, 16, 4, 6000, strides=[384000, 24000, 6000, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/v3/Reshape"](%/Reshape_2_output_0, %/v3/Constant_output_0), scope: __main__.DFL::/__main__.DFLV3::v3 # /tmp/ipykernel_3296793/673152664.py:67:0
  %/v3/Softmax_output_0 : Float(1, 16, 4, 6000, strides=[384000, 24000, 6000, 1], requires_grad=1, device=cpu) = onnx::Softmax[axis=1, onnx_name="/v3/Softmax"](%/v3/Reshape_output_0), scope: __main__.DFL::/__main__.DFLV3::v3 # /tmp/ipykernel_3296793/673152664.py:67:0
  %/v3/conv/Conv_output_0 : Float(1, 1, 4, 6000, strides=[24000, 24000, 6000, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[1, 1], onnx_name="/v3/conv/Conv"](%/v3/Softmax_output_0, %v3.conv.weight), scope: __main__.DFL::/__main__.DFLV3::v3/torch.nn.modules.conv.Conv2d::conv # /media/pc/data/tmp/cache/conda/envs/py312x/lib/python3.12/site-packages/torch/nn/modules/conv.py:456:0
  %/v3/Constant_1_output_0 : Long(3, strides=[1], requires_grad=0, device=cpu) = onnx::Constant[value=    1     4  6000 [ CPULongType{3} ], onnx_name="/v3/Constant_1"](), scope: __main__.DFL::/__main__.DFLV3::v3 # /tmp/ipykernel_3296793/673152664.py:67:0
  %88 : Float(1, 4, 6000, strides=[24000, 6000, 1], requires_grad=1, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/v3/Reshape_1"](%/v3/conv/Conv_output_0, %/v3/Constant_1_output_0), scope: __main__.DFL::/__main__.DFLV3::v3 # /tmp/ipykernel_3296793/673152664.py:67:0
  return (%output, %66, %88)
from copy import deepcopy
import onnx
import tvm
from tvm import relay
onnx_path = f"{root_dir}/{output_name}.onnx"
onnx_model = onnx.load(onnx_path)
# Import into Relay with a fixed input shape; freeze_params=True turns the
# weights into Relay constants (they show up as meta[relay.Constant] below).
mod, params = relay.frontend.from_onnx(onnx_model, {"data": shape}, freeze_params=True)
mod = relay.transform.InferType()(mod)
# Independent copy of the un-rewritten module: the rewrites later assign to
# mod["main"], which modifies `mod` itself.
origin_mod = deepcopy(mod)
# Optional quantization pre-pass (currently disabled):
# with tvm.transform.PassContext(opt_level=3):
#     mod = relay.quantize.prerequisite_optimize(mod, params)
mod.show()
def @main(%data: Tensor[(1, 3, 48, 80), float32] /* ty=Tensor[(1, 3, 48, 80), float32] span=/conv0/Conv.data:0:0 */) -> (Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 1, 1), float32] span=/conv0/Conv.conv0.weight:0:0 */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* ty=Tensor[(1, 16, 48, 80), float32] span=/conv0/Conv:0:0 */;
  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 16, 1, 1), float32] span=/conv1/Conv.conv1.weight:0:0 */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 48, 80), float32] span=/conv1/Conv:0:0 */;
  %2 = reshape(%1, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 3840), float32] span=/Reshape:0:0 */;
  %3 = reshape(%2, newshape=[1, 4, 16, 3840]) /* ty=Tensor[(1, 4, 16, 3840), float32] span=/v1/Reshape:0:0 */;
  %4 = transpose(%3, axes=[0, 3, 1, 2]) /* ty=Tensor[(1, 3840, 4, 16), float32] span=/v1/Transpose:0:0 */;
  %5 = nn.softmax(%4, axis=3) /* ty=Tensor[(1, 3840, 4, 16), float32] span=/v1/Softmax:0:0 */;
  %6 = transpose(%5, axes=[0, 3, 2, 1]) /* ty=Tensor[(1, 16, 4, 3840), float32] span=/v1/Transpose_1:0:0 */;
  %7 = nn.conv2d(%6, meta[relay.Constant][2] /* ty=Tensor[(1, 16, 1, 1), float32] span=/v1/conv/Conv.v1.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=1, kernel_size=[1, 1]) /* ty=Tensor[(1, 1, 4, 3840), float32] span=/v1/conv/Conv:0:0 */;
  %8 = nn.conv2d(%0, meta[relay.Constant][3] /* ty=Tensor[(64, 16, 1, 1), float32] span=/conv2/Conv.conv2.weight:0:0 */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 48, 80), float32] span=/conv2/Conv:0:0 */;
  %9 = reshape(%8, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 3840), float32] span=/Reshape_1:0:0 */;
  %10 = reshape(%9, newshape=[1, 4, 16, 3840]) /* ty=Tensor[(1, 4, 16, 3840), float32] span=/v2/Reshape:0:0 */;
  %11 = transpose(%10, axes=[0, 2, 1, 3]) /* ty=Tensor[(1, 16, 4, 3840), float32] span=/v2/Transpose:0:0 */;
  %12 = nn.softmax(%11, axis=1) /* ty=Tensor[(1, 16, 4, 3840), float32] span=/v2/Softmax:0:0 */;
  %13 = nn.conv2d(%12, meta[relay.Constant][4] /* ty=Tensor[(1, 16, 1, 1), float32] span=/v2/conv/Conv.v2.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=1, kernel_size=[1, 1]) /* ty=Tensor[(1, 1, 4, 3840), float32] span=/v2/conv/Conv:0:0 */;
  %14 = nn.conv2d(%0, meta[relay.Constant][5] /* ty=Tensor[(100, 16, 1, 1), float32] span=/conv3/Conv.conv3.weight:0:0 */, padding=[0, 0, 0, 0], channels=100, kernel_size=[1, 1]) /* ty=Tensor[(1, 100, 48, 80), float32] span=/conv3/Conv:0:0 */;
  %15 = reshape(%14, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 6000), float32] span=/Reshape_2:0:0 */;
  %16 = reshape(%15, newshape=[1, 16, 4, 6000]) /* ty=Tensor[(1, 16, 4, 6000), float32] span=/v3/Reshape:0:0 */;
  %17 = nn.softmax(%16, axis=1) /* ty=Tensor[(1, 16, 4, 6000), float32] span=/v3/Softmax:0:0 */;
  %18 = nn.conv2d(%17, meta[relay.Constant][6] /* ty=Tensor[(1, 16, 1, 1), float32] span=/v3/conv/Conv.v3.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=1, kernel_size=[1, 1]) /* ty=Tensor[(1, 1, 4, 6000), float32] span=/v3/conv/Conv:0:0 */;
  %19 = reshape(%7, newshape=[1, 4, 3840]) /* ty=Tensor[(1, 4, 3840), float32] span=/v1/Reshape_1:0:0 */;
  %20 = reshape(%13, newshape=[1, 4, 3840]) /* ty=Tensor[(1, 4, 3840), float32] span=/v2/Reshape_1:0:0 */;
  %21 = reshape(%18, newshape=[1, 4, 6000]) /* ty=Tensor[(1, 4, 6000), float32] span=/v3/Reshape_1:0:0 */;
  (%19, %20, %21) /* ty=(Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) */
}

注册算子 vta_special.yolo_dfl

import numpy as np
from tvm.relay.testing import run_infer_type
from tvm.relay.dataflow_pattern import (
    wildcard, is_op,
    is_constant,
    DFPatternCallback,
    rewrite
)
import tvm
from tvm.ir.attrs import DictAttrs
from tvm.relay import transform as _transform
from tvm import relay, te, topi
from tvm.relay.op import op as _op
from tvm.target import generic_func

@generic_func
def schedule_special_op(attrs, outs, target):
    """Generic (default) schedule for the custom special ops.

    Creates a plain TE schedule with no loop transformations.

    Args:
        attrs: op attributes (unused).
        outs: output tensor, or list of output tensors, of the compute.
        target: compilation target; the schedule is built under this target.

    Returns:
        te.Schedule covering every output.
    """
    with target:
        outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
        # Schedule all outputs, not just the first one (as topi's default schedule does).
        return te.create_schedule([out.op for out in outs])
def custom_yolo_dfl_rel(arg_types, attrs):
    """Type relation for ``vta_special.yolo_dfl``.

    The output type is derived entirely from ``attrs.in_shape``, given as
    ``(batch, channels, anchors)``. The result is a float32 tensor of shape
    ``(batch, 4, anchors)``.

    Args:
        arg_types: input types; exactly one is expected.
        attrs: ``DictAttrs`` that must carry ``in_shape``.

    Returns:
        relay.TensorType of the op output.
    """
    assert len(arg_types) == 1, "type relation arg number mismatch!"
    # The previous `if attrs:` guard still read attrs.in_shape unconditionally,
    # so missing/empty attrs failed with an obscure AttributeError. Validate up front.
    assert isinstance(attrs, DictAttrs) and "in_shape" in attrs, \
        "yolo_dfl requires DictAttrs carrying in_shape"
    in_shape = attrs.in_shape
    assert len(in_shape) == 3, "yolo_dfl expects in_shape = (batch, channels, anchors)"
    bbox_size = 4
    assert in_shape[1] % bbox_size == 0
    out_shape = (in_shape[0], bbox_size, in_shape[2])
    return relay.TensorType(out_shape, "float32")

op_name = "vta_special.yolo_dfl"
_op.register(op_name, r"code(cal yolo_dfl.)code")
# Configure the newly registered op: one tensor input, DictAttrs attributes,
# and the Python type relation custom_yolo_dfl_rel.
_yolo_dfl_op = _op.get(op_name)
_yolo_dfl_op.set_num_inputs(1)
_yolo_dfl_op.add_argument("data", "Tensor", "The input data tensor.")
_yolo_dfl_op.set_attrs_type_key("DictAttrs")
_yolo_dfl_op.add_type_rel(op_name, custom_yolo_dfl_rel)
_yolo_dfl_op.set_support_level(1)
_op.register_pattern(op_name, _op.OpPattern.COMM_REDUCE)
_op.register_stateful(op_name, False)  # stateless operator
def yolo_dfl(x, channel, in_shape, version="v3", x_scale=-1, x_split=-1):
    """Build a Relay call to the custom ``vta_special.yolo_dfl`` op.

    Args:
        x: input Relay expression.
        channel: number of distribution bins per box side (length of the
            ``arange`` integral weights used by the compute).
        in_shape: ``(batch, channels, anchors)`` of ``x``; the type relation
            and the compute read the shape from here.
        version: which DFL layout to reproduce: "v1", "v2" or "v3".
        x_scale: multiplier applied by the compute when the input dtype is int8.
        x_split: stored as an attribute; not read by output_yolo_dfl_compute.

    Returns:
        relay.Call carrying the above values as DictAttrs.
    """
    call_attrs = tvm.ir.make_node(
        "DictAttrs",
        channel=channel,
        in_shape=in_shape,
        version=version,
        x_scale=x_scale,
        x_split=x_split,
    )
    dfl_op = _op.get(op_name)
    return relay.Call(dfl_op, [x], attrs=call_attrs, type_args=None, span=None)

@_op.register_compute(op_name)
def output_yolo_dfl_compute(attrs, inputs, out_type):
    """TE compute for ``vta_special.yolo_dfl``.

    Reproduces the DFLV1/DFLV2/DFLV3 layouts selected by ``attrs.version``:

    * softmax over the bins;
    * a 1x1 conv with weights ``0 .. channel-1``, i.e. the expected bin value;
    * a reshape to ``(batch, 4, anchors)``.

    An int8 input is first cast to float32 and multiplied by ``attrs.x_scale``.
    """
    assert len(inputs) == 1, "输入参数数量不为 1"
    data = inputs[0]
    b, c, a = attrs.in_shape  # batch, channels, anchors
    assert c % 4 == 0
    if data.dtype == "int8":
        data = topi.multiply(topi.cast(data, "float32"), attrs.x_scale)
    # Integral weights 0..channel-1 laid out as a (1, channel, 1, 1) conv kernel.
    bins = topi.reshape(topi.arange(0, attrs.channel, dtype="float32"), (1, attrs.channel, 1, 1))

    version = attrs.version
    if version == "v3":
        # Channels grouped as (c1, 4): the bins are already on axis 1.
        prob = topi.nn.softmax(topi.reshape(data, (b, c // 4, 4, a)), axis=1)
    elif version == "v2":
        # (b, 4, c1, a) -> (b, c1, 4, a), softmax over the bins.
        grouped = topi.transpose(topi.reshape(data, (b, 4, c // 4, a)), [0, 2, 1, 3])
        prob = topi.nn.softmax(grouped, axis=1)
    elif version == "v1":
        # (b, 4, c1, a) -> (b, a, 4, c1), softmax on the last axis, then -> (b, c1, 4, a).
        grouped = topi.transpose(topi.reshape(data, (b, 4, c // 4, a)), [0, 3, 1, 2])
        prob = topi.transpose(topi.nn.softmax(grouped, axis=3), [0, 3, 2, 1])
    else:
        raise TypeError(f"暂未支持 {attrs.version}")

    integral = topi.nn.conv2d(prob, bins, padding=[0, 0, 0, 0], strides=(1, 1), dilation=(1, 1))
    return [topi.reshape(integral, (b, 4, a))]

_op.register_schedule(op_name, schedule_special_op) # register the schedule used to build the op
GenericFunc(0x9c088c0)
class DFLV1Rewrite(DFPatternCallback):
    """Replace the DFLV1 subgraph with a single ``vta_special.yolo_dfl`` call (version "v1").

    Matched Relay pattern (DFLV1 after ONNX import)::

        reshape(x) -> transpose[0,3,1,2] -> nn.softmax(axis=3)
          -> transpose[0,3,2,1] -> nn.conv2d(., constant) -> reshape

    The fused op recomputes the 1x1 conv with ``arange(c1)`` weights and derives
    its layout from ``in_shape``. A match is therefore rewritten only when:

    * the constant weight really is ``arange(c1)``;
    * all shapes have the DFLV1 layout.

    Anything else is returned unchanged rather than being silently replaced by
    a different computation.
    """

    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.reshape = is_op("reshape")(self.x)
        self.transpose = is_op("transpose")(self.reshape).has_attr({"axes": [0, 3, 1, 2]})
        self.softmax = is_op("nn.softmax")(self.transpose).has_attr({"axis": 3})
        self.transpose2 = is_op("transpose")(self.softmax).has_attr({"axes": [0, 3, 2, 1]})
        self.conv_weight = is_constant()
        self.conv = is_op("nn.conv2d")(self.transpose2, self.conv_weight)
        self.reshape2 = is_op("reshape")(self.conv)
        self.pattern = self.reshape2

    @staticmethod
    def _static_shape(expr):
        """Return the inferred shape of ``expr`` as a tuple of ints, or None if dynamic/unavailable."""
        try:
            return tuple(int(dim) for dim in _transform.InferTypeLocal(expr).shape)
        except (AttributeError, TypeError, ValueError):
            return None

    def callback(self, pre, post, node_map):
        """Return the fused call, or ``post`` unchanged if the match is not a genuine DFLV1."""
        x = node_map[self.x][0]
        conv = node_map[self.conv][0]
        conv_weight = node_map[self.conv_weight][0]
        x_shape = self._static_shape(x)
        w_shape = self._static_shape(conv_weight)
        # Need a static (batch, channels, anchors) input and a 4-D kernel.
        if x_shape is None or w_shape is None or len(x_shape) != 3 or len(w_shape) != 4:
            return post
        b, c, a = x_shape
        out_ch, c1, kh, kw = w_shape
        if (out_ch, kh, kw) != (1, 1, 1) or c != 4 * c1:
            return post
        if conv.attrs.data_layout != "NCHW" or conv.attrs.kernel_layout != "OIHW":
            return post
        # DFLV1 groups the channels as (4 sides, c1 bins).
        if self._static_shape(node_map[self.reshape][0]) != (b, 4, c1, a):
            return post
        if self._static_shape(post) != (b, 4, a):
            return post
        # yolo_dfl hard-codes the integral weights 0..c1-1; any other weight must not be fused.
        bins = np.arange(c1, dtype="float32").reshape(1, c1, 1, 1)
        if not np.array_equal(conv_weight.data.numpy(), bins):
            return post
        in_shape = _transform.InferTypeLocal(x).shape  # batch, channels, anchors
        weight_shape = _transform.InferTypeLocal(conv_weight).shape
        return yolo_dfl(x, weight_shape[1], tuple(in_shape), version="v1")
    
class DFLV2Rewrite(DFPatternCallback):
    """Replace the DFLV2 subgraph with a single ``vta_special.yolo_dfl`` call (version "v2").

    Matched Relay pattern (DFLV2 after ONNX import)::

        reshape(x) -> transpose[0,2,1,3] -> nn.softmax(axis=1)
          -> nn.conv2d(., constant) -> reshape

    The fused op recomputes the 1x1 conv with ``arange(c1)`` weights and derives
    its layout from ``in_shape``. A match is therefore rewritten only when:

    * the constant weight really is ``arange(c1)``;
    * all shapes have the DFLV2 layout.

    Anything else is returned unchanged rather than being silently replaced by
    a different computation.
    """

    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.reshape = is_op("reshape")(self.x)
        self.transpose = is_op("transpose")(self.reshape).has_attr({"axes": [0, 2, 1, 3]})
        self.softmax = is_op("nn.softmax")(self.transpose).has_attr({"axis": 1})
        self.conv_weight = is_constant()
        self.conv = is_op("nn.conv2d")(self.softmax, self.conv_weight)
        self.reshape2 = is_op("reshape")(self.conv)
        self.pattern = self.reshape2

    @staticmethod
    def _static_shape(expr):
        """Return the inferred shape of ``expr`` as a tuple of ints, or None if dynamic/unavailable."""
        try:
            return tuple(int(dim) for dim in _transform.InferTypeLocal(expr).shape)
        except (AttributeError, TypeError, ValueError):
            return None

    def callback(self, pre, post, node_map):
        """Return the fused call, or ``post`` unchanged if the match is not a genuine DFLV2."""
        x = node_map[self.x][0]
        conv = node_map[self.conv][0]
        conv_weight = node_map[self.conv_weight][0]
        x_shape = self._static_shape(x)
        w_shape = self._static_shape(conv_weight)
        # Need a static (batch, channels, anchors) input and a 4-D kernel.
        if x_shape is None or w_shape is None or len(x_shape) != 3 or len(w_shape) != 4:
            return post
        b, c, a = x_shape
        out_ch, c1, kh, kw = w_shape
        if (out_ch, kh, kw) != (1, 1, 1) or c != 4 * c1:
            return post
        if conv.attrs.data_layout != "NCHW" or conv.attrs.kernel_layout != "OIHW":
            return post
        # DFLV2 groups the channels as (4 sides, c1 bins).
        if self._static_shape(node_map[self.reshape][0]) != (b, 4, c1, a):
            return post
        if self._static_shape(post) != (b, 4, a):
            return post
        # yolo_dfl hard-codes the integral weights 0..c1-1; any other weight must not be fused.
        bins = np.arange(c1, dtype="float32").reshape(1, c1, 1, 1)
        if not np.array_equal(conv_weight.data.numpy(), bins):
            return post
        in_shape = _transform.InferTypeLocal(x).shape  # batch, channels, anchors
        weight_shape = _transform.InferTypeLocal(conv_weight).shape
        return yolo_dfl(x, weight_shape[1], tuple(in_shape), version="v2")
    
    
class DFLV3Rewrite(DFPatternCallback):
    """Replace the DFLV3 subgraph with a single ``vta_special.yolo_dfl`` call (version "v3").

    Matched Relay pattern (DFLV3 after ONNX import)::

        reshape(x) -> nn.softmax(axis=1) -> nn.conv2d(., constant) -> reshape

    The fused op recomputes the 1x1 conv with ``arange(c1)`` weights and derives
    its layout from ``in_shape``. A match is therefore rewritten only when:

    * the constant weight really is ``arange(c1)``;
    * the inner reshape produces the DFLV3 ``(b, c1, 4, a)`` layout.

    Anything else, e.g. an ordinary softmax + 1x1 conv on a 4-D feature map,
    is returned unchanged.
    """

    def __init__(self):
        super().__init__()
        self.x = wildcard()
        self.reshape = is_op("reshape")(self.x)
        self.softmax = is_op("nn.softmax")(self.reshape).has_attr({"axis": 1})
        self.conv_weight = is_constant()
        self.conv = is_op("nn.conv2d")(self.softmax, self.conv_weight)
        self.reshape2 = is_op("reshape")(self.conv)
        self.pattern = self.reshape2

    @staticmethod
    def _static_shape(expr):
        """Return the inferred shape of ``expr`` as a tuple of ints, or None if dynamic/unavailable."""
        try:
            return tuple(int(dim) for dim in _transform.InferTypeLocal(expr).shape)
        except (AttributeError, TypeError, ValueError):
            return None

    def callback(self, pre, post, node_map):
        """Return the fused call, or ``post`` unchanged if the match is not a genuine DFLV3."""
        x = node_map[self.x][0]
        conv = node_map[self.conv][0]
        conv_weight = node_map[self.conv_weight][0]
        x_shape = self._static_shape(x)
        w_shape = self._static_shape(conv_weight)
        # Need a static (batch, channels, anchors) input and a 4-D kernel.
        if x_shape is None or w_shape is None or len(x_shape) != 3 or len(w_shape) != 4:
            return post
        b, c, a = x_shape
        out_ch, c1, kh, kw = w_shape
        if (out_ch, kh, kw) != (1, 1, 1) or c != 4 * c1:
            return post
        if conv.attrs.data_layout != "NCHW" or conv.attrs.kernel_layout != "OIHW":
            return post
        # DFLV3 groups the channels as (c1 bins, 4 sides).
        if self._static_shape(node_map[self.reshape][0]) != (b, c1, 4, a):
            return post
        if self._static_shape(post) != (b, 4, a):
            return post
        # yolo_dfl hard-codes the integral weights 0..c1-1; any other weight must not be fused.
        bins = np.arange(c1, dtype="float32").reshape(1, c1, 1, 1)
        if not np.array_equal(conv_weight.data.numpy(), bins):
            return post
        in_shape = _transform.InferTypeLocal(x).shape  # batch, channels, anchors
        weight_shape = _transform.InferTypeLocal(conv_weight).shape
        return yolo_dfl(x, weight_shape[1], tuple(in_shape), version="v3")
# Fuse each DFL variant into vta_special.yolo_dfl, one callback at a time (v1, v2, v3).
for dfl_callback in (DFLV1Rewrite(), DFLV2Rewrite(), DFLV3Rewrite()):
    mod["main"] = rewrite(dfl_callback, mod["main"])
mod.show()
def @main(%data: Tensor[(1, 3, 48, 80), float32] /* ty=Tensor[(1, 3, 48, 80), float32] span=/conv0/Conv.data:0:0 */) -> (Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 1, 1), float32] span=/conv0/Conv.conv0.weight:0:0 */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* ty=Tensor[(1, 16, 48, 80), float32] span=/conv0/Conv:0:0 */;
  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 16, 1, 1), float32] span=/conv1/Conv.conv1.weight:0:0 */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 48, 80), float32] span=/conv1/Conv:0:0 */;
  %2 = reshape(%1, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 3840), float32] span=/Reshape:0:0 */;
  %3 = nn.conv2d(%0, meta[relay.Constant][2] /* ty=Tensor[(64, 16, 1, 1), float32] span=/conv2/Conv.conv2.weight:0:0 */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 48, 80), float32] span=/conv2/Conv:0:0 */;
  %4 = reshape(%3, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 3840), float32] span=/Reshape_1:0:0 */;
  %5 = nn.conv2d(%0, meta[relay.Constant][3] /* ty=Tensor[(100, 16, 1, 1), float32] span=/conv3/Conv.conv3.weight:0:0 */, padding=[0, 0, 0, 0], channels=100, kernel_size=[1, 1]) /* ty=Tensor[(1, 100, 48, 80), float32] span=/conv3/Conv:0:0 */;
  %6 = reshape(%5, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 6000), float32] span=/Reshape_2:0:0 */;
  %7 = vta_special.yolo_dfl(%2, __dict__={"x_scale"=-1, "x_split"=-1, "in_shape"=[1, 64, 3840], "version"="v1", "channel"=16});
  %8 = vta_special.yolo_dfl(%4, __dict__={"x_scale"=-1, "x_split"=-1, "in_shape"=[1, 64, 3840], "version"="v2", "channel"=16});
  %9 = vta_special.yolo_dfl(%6, __dict__={"x_scale"=-1, "x_split"=-1, "in_shape"=[1, 64, 6000], "version"="v3", "channel"=16});
  (%7, %8, %9) /* ty=(Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) */
}
origin_mod.show()  # the copy taken before the rewrites, for comparison
def @main(%data: Tensor[(1, 3, 48, 80), float32] /* ty=Tensor[(1, 3, 48, 80), float32] span=/conv0/Conv.data:0:0 */) -> (Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 1, 1), float32] span=/conv0/Conv.conv0.weight:0:0 */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* ty=Tensor[(1, 16, 48, 80), float32] span=/conv0/Conv:0:0 */;
  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 16, 1, 1), float32] span=/conv1/Conv.conv1.weight:0:0 */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 48, 80), float32] span=/conv1/Conv:0:0 */;
  %2 = reshape(%1, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 3840), float32] span=/Reshape:0:0 */;
  %3 = reshape(%2, newshape=[1, 4, 16, 3840]) /* ty=Tensor[(1, 4, 16, 3840), float32] span=/v1/Reshape:0:0 */;
  %4 = transpose(%3, axes=[0, 3, 1, 2]) /* ty=Tensor[(1, 3840, 4, 16), float32] span=/v1/Transpose:0:0 */;
  %5 = nn.softmax(%4, axis=3) /* ty=Tensor[(1, 3840, 4, 16), float32] span=/v1/Softmax:0:0 */;
  %6 = transpose(%5, axes=[0, 3, 2, 1]) /* ty=Tensor[(1, 16, 4, 3840), float32] span=/v1/Transpose_1:0:0 */;
  %7 = nn.conv2d(%6, meta[relay.Constant][2] /* ty=Tensor[(1, 16, 1, 1), float32] span=/v1/conv/Conv.v1.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=1, kernel_size=[1, 1]) /* ty=Tensor[(1, 1, 4, 3840), float32] span=/v1/conv/Conv:0:0 */;
  %8 = nn.conv2d(%0, meta[relay.Constant][3] /* ty=Tensor[(64, 16, 1, 1), float32] span=/conv2/Conv.conv2.weight:0:0 */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 48, 80), float32] span=/conv2/Conv:0:0 */;
  %9 = reshape(%8, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 3840), float32] span=/Reshape_1:0:0 */;
  %10 = reshape(%9, newshape=[1, 4, 16, 3840]) /* ty=Tensor[(1, 4, 16, 3840), float32] span=/v2/Reshape:0:0 */;
  %11 = transpose(%10, axes=[0, 2, 1, 3]) /* ty=Tensor[(1, 16, 4, 3840), float32] span=/v2/Transpose:0:0 */;
  %12 = nn.softmax(%11, axis=1) /* ty=Tensor[(1, 16, 4, 3840), float32] span=/v2/Softmax:0:0 */;
  %13 = nn.conv2d(%12, meta[relay.Constant][4] /* ty=Tensor[(1, 16, 1, 1), float32] span=/v2/conv/Conv.v2.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=1, kernel_size=[1, 1]) /* ty=Tensor[(1, 1, 4, 3840), float32] span=/v2/conv/Conv:0:0 */;
  %14 = nn.conv2d(%0, meta[relay.Constant][5] /* ty=Tensor[(100, 16, 1, 1), float32] span=/conv3/Conv.conv3.weight:0:0 */, padding=[0, 0, 0, 0], channels=100, kernel_size=[1, 1]) /* ty=Tensor[(1, 100, 48, 80), float32] span=/conv3/Conv:0:0 */;
  %15 = reshape(%14, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 6000), float32] span=/Reshape_2:0:0 */;
  %16 = reshape(%15, newshape=[1, 16, 4, 6000]) /* ty=Tensor[(1, 16, 4, 6000), float32] span=/v3/Reshape:0:0 */;
  %17 = nn.softmax(%16, axis=1) /* ty=Tensor[(1, 16, 4, 6000), float32] span=/v3/Softmax:0:0 */;
  %18 = nn.conv2d(%17, meta[relay.Constant][6] /* ty=Tensor[(1, 16, 1, 1), float32] span=/v3/conv/Conv.v3.conv.weight:0:0 */, padding=[0, 0, 0, 0], channels=1, kernel_size=[1, 1]) /* ty=Tensor[(1, 1, 4, 6000), float32] span=/v3/conv/Conv:0:0 */;
  %19 = reshape(%7, newshape=[1, 4, 3840]) /* ty=Tensor[(1, 4, 3840), float32] span=/v1/Reshape_1:0:0 */;
  %20 = reshape(%13, newshape=[1, 4, 3840]) /* ty=Tensor[(1, 4, 3840), float32] span=/v2/Reshape_1:0:0 */;
  %21 = reshape(%18, newshape=[1, 4, 6000]) /* ty=Tensor[(1, 4, 6000), float32] span=/v3/Reshape_1:0:0 */;
  (%19, %20, %21) /* ty=(Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) */
}
# Compare PyTorch, the original Relay module and the rewritten Relay module
# (with fused vta_special.yolo_dfl ops) on the same random input.
data = np.random.normal(0, 1, size=shape).astype("float32")
with torch.no_grad():
    torch_outputs = [o.numpy() for o in model(torch.from_numpy(data))]

target = 'llvm'
dev = tvm.device(target, 0)

# Original model
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(origin_mod, target, params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
outputs1 = [module.get_output(k).numpy() for k in range(module.get_num_outputs())]

# Rewritten model
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target, params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
outputs2 = [module.get_output(k).numpy() for k in range(module.get_num_outputs())]

# Make sure zip() below cannot silently skip an output.
assert len(torch_outputs) == len(outputs1) == len(outputs2)
for ref, out1, out2 in zip(torch_outputs, outputs1, outputs2):
    np.testing.assert_allclose(ref, out1, rtol=1e-07, atol=1e-5)
    np.testing.assert_allclose(ref, out2, rtol=1e-07, atol=1e-5)
    # The fused op must reproduce the unfused Relay graph bit-for-bit.
    np.testing.assert_equal(out1, out2)
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
mod.show()  # rewritten module: DFL subgraphs replaced by vta_special.yolo_dfl calls
def @main(%data: Tensor[(1, 3, 48, 80), float32] /* ty=Tensor[(1, 3, 48, 80), float32] span=/conv0/Conv.data:0:0 */) -> (Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) {
  %0 = nn.conv2d(%data, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 1, 1), float32] span=/conv0/Conv.conv0.weight:0:0 */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1]) /* ty=Tensor[(1, 16, 48, 80), float32] span=/conv0/Conv:0:0 */;
  %1 = nn.conv2d(%0, meta[relay.Constant][1] /* ty=Tensor[(64, 16, 1, 1), float32] span=/conv1/Conv.conv1.weight:0:0 */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 48, 80), float32] span=/conv1/Conv:0:0 */;
  %2 = reshape(%1, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 3840), float32] span=/Reshape:0:0 */;
  %3 = nn.conv2d(%0, meta[relay.Constant][2] /* ty=Tensor[(64, 16, 1, 1), float32] span=/conv2/Conv.conv2.weight:0:0 */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1]) /* ty=Tensor[(1, 64, 48, 80), float32] span=/conv2/Conv:0:0 */;
  %4 = reshape(%3, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 3840), float32] span=/Reshape_1:0:0 */;
  %5 = nn.conv2d(%0, meta[relay.Constant][3] /* ty=Tensor[(100, 16, 1, 1), float32] span=/conv3/Conv.conv3.weight:0:0 */, padding=[0, 0, 0, 0], channels=100, kernel_size=[1, 1]) /* ty=Tensor[(1, 100, 48, 80), float32] span=/conv3/Conv:0:0 */;
  %6 = reshape(%5, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 6000), float32] span=/Reshape_2:0:0 */;
  %7 = vta_special.yolo_dfl(%2, __dict__={"x_scale"=-1, "x_split"=-1, "in_shape"=[1, 64, 3840], "version"="v1", "channel"=16});
  %8 = vta_special.yolo_dfl(%4, __dict__={"x_scale"=-1, "x_split"=-1, "in_shape"=[1, 64, 3840], "version"="v2", "channel"=16});
  %9 = vta_special.yolo_dfl(%6, __dict__={"x_scale"=-1, "x_split"=-1, "in_shape"=[1, 64, 6000], "version"="v3", "channel"=16});
  (%7, %8, %9) /* ty=(Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) */
}
from PIL import Image
import numpy as np

# Random HWC uint8 test image. Draw integers directly in [0, 255]: casting
# N(0, 1) floats to uint8 truncates values in [0, 1) to 0, and the result for
# negative values is unspecified (NumPy warns "invalid value encountered in
# cast"), so it does not produce a meaningful random image.
image = np.random.randint(0, 256, size=(48, 80, 3), dtype=np.uint8)
# Normalise pixels to [-0.5, 0.5): (pixel - 128) / 256.
mean = (128,)
std = (256,)
data = (image - mean)/std
# HWC -> CHW, then add a batch axis: float32 input of shape (1, 3, 48, 80).
data = data.transpose((2, 0, 1))
data = np.expand_dims(data, 0).astype("float32")
# Raw NHWC uint8 bytes written to disk; presumably the input for an
# out-of-notebook (e.g. on-device) run -- confirm with the consumer.
images = np.expand_dims(image, 0)
images.tofile(f"{root_dir}/input.bin")
# Upscaled preview of the test image (shown as the cell output).
Image.fromarray(image).resize((112, 112))
../../../_images/2e397cc4fbe81bea6b6315861a39e379ec0b8b0bdad1faeca382d258342c15a7.png
from dataclasses import dataclass
from typing import Optional


@dataclass
class Dataset:
    """Calibration dataset for ``relay.quantize.quantize``.

    Each pass over the dataset yields ``num_batches`` feed dicts that map
    ``input_name`` to ``sample``. When ``sample`` is None, the module-level
    ``data`` array is looked up at every yield, which is the original
    behaviour.

    Attributes:
        input_name: Name of the graph input to feed.
        shape: Input shape; kept for compatibility, not used by ``__iter__``.
        num_batches: Number of feed dicts yielded per pass (default 2).
        sample: Array to feed; None falls back to the global ``data``.
    """

    input_name: str
    shape: tuple
    num_batches: int = 2
    sample: Optional[np.ndarray] = None

    def __iter__(self):
        for _ in range(self.num_batches):
            # The global lookup happens lazily, at each yield, as it did before.
            yield {self.input_name: data if self.sample is None else self.sample}

# Calibration feeds built from the module-level input.
dataset = Dataset(input_name, shape)

# Further qconfig options, left disabled here:
#   vta_adjust_scale=True                 -- must be enabled when running on VTA
#   vta_relu6_flag=True
#   weight_per_channel_quantization=True  -- per-channel weight quantization
#   prelu_fuse=True
qconfig_options = dict(
    calibrate_mode="kl_divergence",
    weight_scale="max",
    skip_conv_layers=[],
    skip_dense_layer=False,
)
with relay.quantize.qconfig(**qconfig_options):
    # Debug variant (also returns recorded graphs):
    #   qmod, record_graphs = relay.quantize.quantize_debug(mod, params, dataset)
    qmod = relay.quantize.quantize(mod, params, dataset)
# Quantized model (relay.quantize), before the yolo_dfl rewrite.
# deepcopy presumably keeps this baseline unaffected by the later in-place
# `qmod["main"] = ...` reassignment.
origin_qmod = deepcopy(qmod)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(origin_qmod, target, params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
# Query the output count instead of hard-coding it.
origin_qoutputs = [module.get_output(k).numpy() for k in range(module.get_num_outputs())]
origin_qmod.show()
def @main(%data: Tensor[(1, 3, 48, 80), float32] /* ty=Tensor[(1, 3, 48, 80), float32] span=/conv0/Conv.data:0:0 */) -> (Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) {
  %0 = multiply(%data, 257.869f /* ty=float32 */) /* ty=Tensor[(1, 3, 48, 80), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 48, 80), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 48, 80), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 48, 80), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 16, 48, 80), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 16, 48, 80), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1712233856, shift=-8) /* ty=Tensor[(1, 16, 48, 80), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 48, 80), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 16, 48, 80), int32] */;
  %9 = cast(%8, dtype="int8") /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %10 = annotation.stop_fusion(%9) /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %11 = nn.conv2d(%10, meta[relay.Constant][1] /* ty=Tensor[(64, 16, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 64, 48, 80), int32] */;
  %12 = cast(%11, dtype="int64") /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %13 = fixed_point_multiply(%12, multiplier=1878688768, shift=-8) /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %14 = clip(%13, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %15 = cast(%14, dtype="int32") /* ty=Tensor[(1, 64, 48, 80), int32] */;
  %16 = cast(%15, dtype="int8") /* ty=Tensor[(1, 64, 48, 80), int8] */;
  %17 = annotation.stop_fusion(%16) /* ty=Tensor[(1, 64, 48, 80), int8] */;
  %18 = reshape(%17, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 3840), int8] */;
  %19 = cast(%18, dtype="float32") /* ty=Tensor[(1, 64, 3840), float32] */;
  %20 = multiply(%19, 0.00317437f /* ty=float32 */) /* ty=Tensor[(1, 64, 3840), float32] */;
  %21 = cast(%8, dtype="int8") /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %22 = annotation.stop_fusion(%21) /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %23 = nn.conv2d(%22, meta[relay.Constant][2] /* ty=Tensor[(64, 16, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 64, 48, 80), int32] */;
  %24 = cast(%23, dtype="int64") /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %25 = fixed_point_multiply(%24, multiplier=1496169600, shift=-8) /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %26 = clip(%25, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %27 = cast(%26, dtype="int32") /* ty=Tensor[(1, 64, 48, 80), int32] */;
  %28 = cast(%27, dtype="int8") /* ty=Tensor[(1, 64, 48, 80), int8] */;
  %29 = annotation.stop_fusion(%28) /* ty=Tensor[(1, 64, 48, 80), int8] */;
  %30 = reshape(%29, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 3840), int8] */;
  %31 = cast(%30, dtype="float32") /* ty=Tensor[(1, 64, 3840), float32] */;
  %32 = multiply(%31, 0.00398762f /* ty=float32 */) /* ty=Tensor[(1, 64, 3840), float32] */;
  %33 = cast(%8, dtype="int8") /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %34 = annotation.stop_fusion(%33) /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %35 = nn.conv2d(%34, meta[relay.Constant][3] /* ty=Tensor[(100, 16, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=100, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 100, 48, 80), int32] */;
  %36 = cast(%35, dtype="int64") /* ty=Tensor[(1, 100, 48, 80), int64] */;
  %37 = fixed_point_multiply(%36, multiplier=1540201216, shift=-8) /* ty=Tensor[(1, 100, 48, 80), int64] */;
  %38 = clip(%37, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 100, 48, 80), int64] */;
  %39 = cast(%38, dtype="int32") /* ty=Tensor[(1, 100, 48, 80), int32] */;
  %40 = cast(%39, dtype="int8") /* ty=Tensor[(1, 100, 48, 80), int8] */;
  %41 = annotation.stop_fusion(%40) /* ty=Tensor[(1, 100, 48, 80), int8] */;
  %42 = reshape(%41, newshape=[1, 64, -1]) /* ty=Tensor[(1, 64, 6000), int8] */;
  %43 = cast(%42, dtype="float32") /* ty=Tensor[(1, 64, 6000), float32] */;
  %44 = multiply(%43, 0.00387525f /* ty=float32 */) /* ty=Tensor[(1, 64, 6000), float32] */;
  %45 = vta_special.yolo_dfl(%20, __dict__={"x_scale"=-1, "in_shape"=[1, 64, 3840], "version"="v1", "x_split"=-1, "channel"=16}) /* ty=Tensor[(1, 4, 3840), float32] */;
  %46 = vta_special.yolo_dfl(%32, __dict__={"x_scale"=-1, "in_shape"=[1, 64, 3840], "version"="v2", "x_split"=-1, "channel"=16}) /* ty=Tensor[(1, 4, 3840), float32] */;
  %47 = vta_special.yolo_dfl(%44, __dict__={"x_scale"=-1, "in_shape"=[1, 64, 6000], "version"="v3", "x_split"=-1, "channel"=16}) /* ty=Tensor[(1, 4, 6000), float32] */;
  (%45, %46, %47) /* ty=(Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) */
}
class YoloDFLPredictRewrite(DFPatternCallback):
    """Fold the dequantize feeding ``vta_special.yolo_dfl`` into its ``x_scale``.

    Matches ``[reshape] -> cast(float32) -> multiply(<constant>) ->
    vta_special.yolo_dfl``. It replaces the chain with a new ``yolo_dfl``
    call that:

    * consumes the tensor before the cast (and before the optional reshape)
      directly;
    * receives the multiply constant as ``x_scale``;
    * keeps the original ``channel``, ``in_shape`` and ``version``
      attributes.

    Scales that are not a single value are left alone, because ``x_scale``
    holds one float.
    """

    def __init__(self):
        super().__init__()
        self.x = wildcard()
        # Optional reshape in front of the dequantize cast; when it matches,
        # `x` binds to the reshape's input, so the reshape is dropped too.
        self.reshape = is_op("reshape")(self.x)
        self.cast = is_op("cast")(self.reshape | self.x).has_attr({"dtype": "float32"})
        self.data_scale = is_constant()
        self.multiply = is_op("multiply")(self.cast, self.data_scale)
        self.dfl_predict_call = is_op("vta_special.yolo_dfl")(self.multiply)
        self.pattern = self.dfl_predict_call

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        scale = node_map[self.data_scale][0].data.numpy()
        # float() on a multi-element (e.g. per-channel) array would raise;
        # such a scale cannot be expressed as one x_scale, so keep the call.
        if scale.size != 1:
            return post
        dfl_predict_call = node_map[self.dfl_predict_call][0]
        # NOTE(review): x_split is not forwarded; the new call relies on
        # yolo_dfl's default (the printed IR shows x_split=-1).
        return yolo_dfl(
            x,
            dfl_predict_call.attrs.channel,
            dfl_predict_call.attrs.in_shape,
            dfl_predict_call.attrs.version,
            float(scale.reshape(-1)[0]),
        )
# Fold each dequantize (cast + multiply by a constant) that feeds
# vta_special.yolo_dfl into the call's x_scale attribute, then print the result.
qmod["main"] = rewrite(YoloDFLPredictRewrite(), qmod["main"])
qmod.show()
def @main(%data: Tensor[(1, 3, 48, 80), float32] /* ty=Tensor[(1, 3, 48, 80), float32] span=/conv0/Conv.data:0:0 */) -> (Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) {
  %0 = multiply(%data, 257.869f /* ty=float32 */) /* ty=Tensor[(1, 3, 48, 80), float32] */;
  %1 = round(%0) /* ty=Tensor[(1, 3, 48, 80), float32] */;
  %2 = clip(%1, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 3, 48, 80), float32] */;
  %3 = cast(%2, dtype="int8") /* ty=Tensor[(1, 3, 48, 80), int8] */;
  %4 = nn.conv2d(%3, meta[relay.Constant][0] /* ty=Tensor[(16, 3, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=16, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 16, 48, 80), int32] */;
  %5 = cast(%4, dtype="int64") /* ty=Tensor[(1, 16, 48, 80), int64] */;
  %6 = fixed_point_multiply(%5, multiplier=1712233856, shift=-8) /* ty=Tensor[(1, 16, 48, 80), int64] */;
  %7 = clip(%6, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 16, 48, 80), int64] */;
  %8 = cast(%7, dtype="int32") /* ty=Tensor[(1, 16, 48, 80), int32] */;
  %9 = cast(%8, dtype="int8") /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %10 = annotation.stop_fusion(%9) /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %11 = nn.conv2d(%10, meta[relay.Constant][1] /* ty=Tensor[(64, 16, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 64, 48, 80), int32] */;
  %12 = cast(%11, dtype="int64") /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %13 = fixed_point_multiply(%12, multiplier=1878688768, shift=-8) /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %14 = clip(%13, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %15 = cast(%14, dtype="int32") /* ty=Tensor[(1, 64, 48, 80), int32] */;
  %16 = cast(%15, dtype="int8") /* ty=Tensor[(1, 64, 48, 80), int8] */;
  %17 = annotation.stop_fusion(%16) /* ty=Tensor[(1, 64, 48, 80), int8] */;
  %18 = cast(%8, dtype="int8") /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %19 = annotation.stop_fusion(%18) /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %20 = nn.conv2d(%19, meta[relay.Constant][2] /* ty=Tensor[(64, 16, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=64, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 64, 48, 80), int32] */;
  %21 = cast(%20, dtype="int64") /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %22 = fixed_point_multiply(%21, multiplier=1496169600, shift=-8) /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %23 = clip(%22, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 64, 48, 80), int64] */;
  %24 = cast(%23, dtype="int32") /* ty=Tensor[(1, 64, 48, 80), int32] */;
  %25 = cast(%24, dtype="int8") /* ty=Tensor[(1, 64, 48, 80), int8] */;
  %26 = annotation.stop_fusion(%25) /* ty=Tensor[(1, 64, 48, 80), int8] */;
  %27 = cast(%8, dtype="int8") /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %28 = annotation.stop_fusion(%27) /* ty=Tensor[(1, 16, 48, 80), int8] */;
  %29 = nn.conv2d(%28, meta[relay.Constant][3] /* ty=Tensor[(100, 16, 1, 1), int8] */, padding=[0, 0, 0, 0], channels=100, kernel_size=[1, 1], out_dtype="int32") /* ty=Tensor[(1, 100, 48, 80), int32] */;
  %30 = cast(%29, dtype="int64") /* ty=Tensor[(1, 100, 48, 80), int64] */;
  %31 = fixed_point_multiply(%30, multiplier=1540201216, shift=-8) /* ty=Tensor[(1, 100, 48, 80), int64] */;
  %32 = clip(%31, a_min=-127f, a_max=127f) /* ty=Tensor[(1, 100, 48, 80), int64] */;
  %33 = cast(%32, dtype="int32") /* ty=Tensor[(1, 100, 48, 80), int32] */;
  %34 = cast(%33, dtype="int8") /* ty=Tensor[(1, 100, 48, 80), int8] */;
  %35 = annotation.stop_fusion(%34) /* ty=Tensor[(1, 100, 48, 80), int8] */;
  %36 = vta_special.yolo_dfl(%17, __dict__={"x_scale"=0.00317437f, "x_split"=-1, "in_shape"=[1, 64, 3840], "version"="v1", "channel"=16});
  %37 = vta_special.yolo_dfl(%26, __dict__={"x_scale"=0.00398762f, "x_split"=-1, "in_shape"=[1, 64, 3840], "version"="v2", "channel"=16});
  %38 = vta_special.yolo_dfl(%35, __dict__={"x_scale"=0.00387525f, "x_split"=-1, "in_shape"=[1, 64, 6000], "version"="v3", "channel"=16});
  (%36, %37, %38) /* ty=(Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 3840), float32], Tensor[(1, 4, 6000), float32]) */
}
# Build and run the rewritten quantized module. It must reproduce the
# outputs of the quantized module before the rewrite (origin_qoutputs).
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(qmod, target, params=params)
func = lib[lib.libmod_name]
module = tvm.contrib.graph_executor.GraphModule(func(dev))
module.run(**{input_name: data})
qoutputs = [module.get_output(k).numpy() for k in range(module.get_num_outputs())]
# zip() would silently truncate on a count mismatch, so check counts first.
assert len(qoutputs) == len(origin_qoutputs)
for expected, actual in zip(origin_qoutputs, qoutputs):
    np.testing.assert_almost_equal(expected, actual)