YOLO11 Relax Optimization

This notebook runs YOLO11n detection with Ultralytics, reproduces the pre- and post-processing steps by hand, verifies the exported ONNX model with ONNX Runtime, and imports it through TVM's Relax ONNX frontend for operator-pattern annotation and fusion.

import torch
torch.cuda.empty_cache()
from PIL import Image
import numpy as np
from ultralytics import YOLO

input_path = "images/vehicle-jaguar-f-type-car-red-cars-wallpaper.jpg"
im = Image.open(input_path) #.resize((384, 640))
yolo = YOLO("yolo11n.pt")
results = yolo(np.array(im), conf=0.25)
Image.fromarray(results[0].plot()).resize((320, 208))
0: 416x640 1 car, 15.4ms
Speed: 2.7ms preprocess, 15.4ms inference, 2.2ms postprocess per image at shape (1, 3, 416, 640)
[output image: detection result rendered by results[0].plot()]
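
For reference, the detections can also be read off numerically through the standard Ultralytics Results API (a minimal sketch; boxes, conf and cls are the usual attribute names in current releases):

boxes = results[0].boxes
print(boxes.xyxy)   # boxes in original-image pixel coordinates (x1, y1, x2, y2)
print(boxes.conf)   # confidence scores
print(boxes.cls)    # class indices; map to labels via yolo.names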

Preprocessing: letterbox the image to 640×640, subtract the mean, convert BHWC to BCHW (restoring memory contiguity), and scale pixel values into [0.0, 1.0]:

from PIL import Image
import numpy as np
import torch
from ultralytics.data.augment import LetterBox

imgsz = 640, 640
strides = yolo.model.stride
mean = (0,)
std = (255,)

letterbox = LetterBox(new_shape=imgsz, auto=False, scaleFill=False, scaleup=True, stride=32)
origin_image = np.asanyarray(Image.open(input_path))
letterbox_image = letterbox(image=origin_image)
xs = np.stack([letterbox_image - mean])
print(f"数据内存的连续性:{xs.flags["C_CONTIGUOUS"]}")
xs = xs.transpose((0, 3, 1, 2))  # BHWC to BCHW, (n, 3, h, w)
print(f"数据内存的连续性(transpose):{xs.flags["C_CONTIGUOUS"]}")
xs = np.ascontiguousarray(xs)  # contiguous
print(f"数据内存的连续性:{xs.flags["C_CONTIGUOUS"]}")
xs = (xs / std).astype("float32") # 归一化值域范围为 0.0 - 1.0
Image.fromarray(
    np.concatenate([letterbox_image, (xs[0]*std).astype("uint8").transpose((1, 2, 0))], axis=1)
).resize((640, 320,))
Array is C-contiguous: True
Array is C-contiguous (after transpose): False
Array is C-contiguous: True
[output image: letterboxed input (left) next to the reconstructed preprocessed tensor (right)]
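
As a quick sanity check before handing the tensor to any backend (a minimal sketch, plain NumPy only):

print(xs.shape, xs.dtype)   # expected: (1, 3, 640, 640) float32
print(xs.min(), xs.max())   # values scaled into [0.0, 1.0]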

Postprocessing: run non-maximum suppression on the raw (1, 84, 8400) predictions and rescale the boxes back to the original image resolution:

from ultralytics.utils import ops
from ultralytics.engine.results import Results

def postprocess(preds, img, orig_imgs, names, input_path, conf_thres=0.25, iou_thres=0.45,):
    """Post-processes predictions and returns a list of Results objects."""
    preds = ops.non_max_suppression(
        preds,
        conf_thres=conf_thres,
        iou_thres=iou_thres,
        # agnostic=self.args.agnostic_nms,
        # max_det=self.args.max_det,
        # classes=80,
    )

    results = []
    for i, pred in enumerate(preds):
        orig_img = orig_imgs[i]
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        img_path = input_path
        results.append(Results(orig_img, path=img_path, names=names, boxes=pred))
    return results

ONNX Inference
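
The yolo11n.onnx file loaded below is assumed to have been exported beforehand; a minimal sketch using the Ultralytics export API (the imgsz value is an assumption chosen to match the 640×640 preprocessing above):

# Export the PyTorch checkpoint to ONNX; this writes yolo11n.onnx next to the .pt file.
yolo.export(format="onnx", imgsz=640)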

import onnxruntime
import onnx
onnx_model = onnx.load('yolo11n.onnx')
# Run the model through ONNX Runtime to obtain the expected result
ort_session = onnxruntime.InferenceSession(
    onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
inputs = {"images": xs}
ort_output = ort_session.run([], inputs)
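
The raw ONNX Runtime output can be pushed through the postprocess helper defined earlier to recover boxes in original-image coordinates. A minimal usage sketch, assuming the export keeps the single (1, 84, 8400) output of the detection head:

preds = torch.from_numpy(ort_output[0])  # the NMS helper expects a torch tensor
onnx_results = postprocess(preds, xs, [origin_image],
                           names=yolo.names, input_path=input_path)
Image.fromarray(onnx_results[0].plot()).resize((320, 208))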

Testing the YOLO ONNX Relax Frontend

import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

# Import the ONNX graph into a Relax IRModule; with keep_params_in_input=False
# the weights are embedded as constants instead of extra function inputs.
tvm_model = from_onnx(onnx_model, keep_params_in_input=False)
# Annotate each TIR PrimFunc with its operator-pattern kind, then fuse
# operators according to those patterns.
mod_actual = relax.transform.AnnotateTIROpPattern()(tvm_model)
mod_actual = relax.transform.FuseOps()(mod_actual)
mod_actual.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def split(A: T.Buffer((T.int64(1), T.int64(32), T.int64(160), T.int64(160)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(16), T.int64(160), T.int64(160)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(16), T.int64(160), T.int64(160)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(160), T.int64(160)):
            with T.block("T_split"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(T_split[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(160), T.int64(160)):
            with T.block("T_split_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1 + T.int64(16), v_ax2, v_ax3])
                T.writes(T_split_1[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1 + T.int64(16), v_ax2, v_ax3]

    @T.prim_func(private=True)
    def split1(A: T.Buffer((T.int64(1), T.int64(64), T.int64(80), T.int64(80)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(32), T.int64(80), T.int64(80)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(32), T.int64(80), T.int64(80)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(80), T.int64(80)):
            with T.block("T_split"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(T_split[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(80), T.int64(80)):
            with T.block("T_split_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1 + T.int64(32), v_ax2, v_ax3])
                T.writes(T_split_1[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1 + T.int64(32), v_ax2, v_ax3]

    @T.prim_func(private=True)
    def split2(A: T.Buffer((T.int64(1), T.int64(128), T.int64(40), T.int64(40)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(64), T.int64(40), T.int64(40)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(64), T.int64(40), T.int64(40)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(64), T.int64(40), T.int64(40)):
            with T.block("T_split"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(T_split[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(64), T.int64(40), T.int64(40)):
            with T.block("T_split_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1 + T.int64(64), v_ax2, v_ax3])
                T.writes(T_split_1[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1 + T.int64(64), v_ax2, v_ax3]

    @T.prim_func(private=True)
    def split3(A: T.Buffer((T.int64(1), T.int64(256), T.int64(20), T.int64(20)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(128), T.int64(20), T.int64(20)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(128), T.int64(20), T.int64(20)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(128), T.int64(20), T.int64(20)):
            with T.block("T_split"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(T_split[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(128), T.int64(20), T.int64(20)):
            with T.block("T_split_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1 + T.int64(128), v_ax2, v_ax3])
                T.writes(T_split_1[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1 + T.int64(128), v_ax2, v_ax3]

    @T.prim_func(private=True)
    def split4(A: T.Buffer((T.int64(1), T.int64(2), T.int64(128), T.int64(400)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(2), T.int64(32), T.int64(400)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(2), T.int64(32), T.int64(400)), "float32"), T_split_2: T.Buffer((T.int64(1), T.int64(2), T.int64(64), T.int64(400)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(2), T.int64(32), T.int64(400)):
            with T.block("T_split"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(T_split[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(2), T.int64(32), T.int64(400)):
            with T.block("T_split_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2 + T.int64(32), v_ax3])
                T.writes(T_split_1[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2 + T.int64(32), v_ax3]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(2), T.int64(64), T.int64(400)):
            with T.block("T_split_2"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2 + T.int64(64), v_ax3])
                T.writes(T_split_2[v_ax0, v_ax1, v_ax2, v_ax3])
                T_split_2[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2 + T.int64(64), v_ax3]

    @T.prim_func(private=True)
    def split5(A: T.Buffer((T.int64(1), T.int64(144), T.int64(8400)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(64), T.int64(8400)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(80), T.int64(8400)), "float32")):
        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(64), T.int64(8400)):
            with T.block("T_split"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0, v_ax1, v_ax2])
                T.writes(T_split[v_ax0, v_ax1, v_ax2])
                T_split[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2]
        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(80), T.int64(8400)):
            with T.block("T_split_1"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0, v_ax1 + T.int64(64), v_ax2])
                T.writes(T_split_1[v_ax0, v_ax1, v_ax2])
                T_split_1[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1 + T.int64(64), v_ax2]

    @R.function
    def main(images: R.Tensor((1, 3, 640, 640), dtype="float32")) -> R.Tensor((1, 84, 8400), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv: R.Tensor((1, 16, 320, 320), dtype="float32") = R.nn.conv2d(images, metadata["relax.expr.Constant"][0], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv1: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][1], R.shape([1, 16, 1, 1]))
            lv2: R.Tensor((1, 16, 320, 320), dtype="float32") = R.add(lv, lv1)
            lv3: R.Tensor((1, 16, 320, 320), dtype="float32") = R.sigmoid(lv2)
            lv4: R.Tensor((1, 16, 320, 320), dtype="float32") = R.multiply(lv2, lv3)
            lv5: R.Tensor((1, 32, 160, 160), dtype="float32") = R.nn.conv2d(lv4, metadata["relax.expr.Constant"][2], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv6: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][3], R.shape([1, 32, 1, 1]))
            lv7: R.Tensor((1, 32, 160, 160), dtype="float32") = R.add(lv5, lv6)
            lv8: R.Tensor((1, 32, 160, 160), dtype="float32") = R.sigmoid(lv7)
            lv9: R.Tensor((1, 32, 160, 160), dtype="float32") = R.multiply(lv7, lv8)
            lv10: R.Tensor((1, 32, 160, 160), dtype="float32") = R.nn.conv2d(lv9, metadata["relax.expr.Constant"][4], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv11: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][5], R.shape([1, 32, 1, 1]))
            lv12: R.Tensor((1, 32, 160, 160), dtype="float32") = R.add(lv10, lv11)
            lv13: R.Tensor((1, 32, 160, 160), dtype="float32") = R.sigmoid(lv12)
            lv14: R.Tensor((1, 32, 160, 160), dtype="float32") = R.multiply(lv12, lv13)
            lv15 = R.call_tir(cls.split, (lv14,), out_sinfo=[R.Tensor((1, 16, 160, 160), dtype="float32"), R.Tensor((1, 16, 160, 160), dtype="float32")])
            lv16: R.Tensor((1, 16, 160, 160), dtype="float32") = lv15[0]
            lv17: R.Tensor((1, 16, 160, 160), dtype="float32") = lv15[1]
            lv18: R.Tensor((1, 8, 160, 160), dtype="float32") = R.nn.conv2d(lv17, metadata["relax.expr.Constant"][6], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv19: R.Tensor((1, 8, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][7], R.shape([1, 8, 1, 1]))
            lv20: R.Tensor((1, 8, 160, 160), dtype="float32") = R.add(lv18, lv19)
            lv21: R.Tensor((1, 8, 160, 160), dtype="float32") = R.sigmoid(lv20)
            lv22: R.Tensor((1, 8, 160, 160), dtype="float32") = R.multiply(lv20, lv21)
            lv23: R.Tensor((1, 16, 160, 160), dtype="float32") = R.nn.conv2d(lv22, metadata["relax.expr.Constant"][8], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv24: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][9], R.shape([1, 16, 1, 1]))
            lv25: R.Tensor((1, 16, 160, 160), dtype="float32") = R.add(lv23, lv24)
            lv26: R.Tensor((1, 16, 160, 160), dtype="float32") = R.sigmoid(lv25)
            lv27: R.Tensor((1, 16, 160, 160), dtype="float32") = R.multiply(lv25, lv26)
            lv28: R.Tensor((1, 16, 160, 160), dtype="float32") = R.add(lv17, lv27)
            lv29: R.Tensor((1, 48, 160, 160), dtype="float32") = R.concat((lv16, lv17, lv28), axis=1)
            lv30: R.Tensor((1, 64, 160, 160), dtype="float32") = R.nn.conv2d(lv29, metadata["relax.expr.Constant"][10], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv31: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][11], R.shape([1, 64, 1, 1]))
            lv32: R.Tensor((1, 64, 160, 160), dtype="float32") = R.add(lv30, lv31)
            lv33: R.Tensor((1, 64, 160, 160), dtype="float32") = R.sigmoid(lv32)
            lv34: R.Tensor((1, 64, 160, 160), dtype="float32") = R.multiply(lv32, lv33)
            lv35: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv34, metadata["relax.expr.Constant"][12], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv36: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][13], R.shape([1, 64, 1, 1]))
            lv37: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv35, lv36)
            lv38: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv37)
            lv39: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv37, lv38)
            lv40: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv39, metadata["relax.expr.Constant"][14], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv41: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][15], R.shape([1, 64, 1, 1]))
            lv42: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv40, lv41)
            lv43: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv42)
            lv44: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv42, lv43)
            lv45 = R.call_tir(cls.split1, (lv44,), out_sinfo=[R.Tensor((1, 32, 80, 80), dtype="float32"), R.Tensor((1, 32, 80, 80), dtype="float32")])
            lv46: R.Tensor((1, 32, 80, 80), dtype="float32") = lv45[0]
            lv47: R.Tensor((1, 32, 80, 80), dtype="float32") = lv45[1]
            lv48: R.Tensor((1, 16, 80, 80), dtype="float32") = R.nn.conv2d(lv47, metadata["relax.expr.Constant"][16], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv49: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][17], R.shape([1, 16, 1, 1]))
            lv50: R.Tensor((1, 16, 80, 80), dtype="float32") = R.add(lv48, lv49)
            lv51: R.Tensor((1, 16, 80, 80), dtype="float32") = R.sigmoid(lv50)
            lv52: R.Tensor((1, 16, 80, 80), dtype="float32") = R.multiply(lv50, lv51)
            lv53: R.Tensor((1, 32, 80, 80), dtype="float32") = R.nn.conv2d(lv52, metadata["relax.expr.Constant"][18], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv54: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][19], R.shape([1, 32, 1, 1]))
            lv55: R.Tensor((1, 32, 80, 80), dtype="float32") = R.add(lv53, lv54)
            lv56: R.Tensor((1, 32, 80, 80), dtype="float32") = R.sigmoid(lv55)
            lv57: R.Tensor((1, 32, 80, 80), dtype="float32") = R.multiply(lv55, lv56)
            lv58: R.Tensor((1, 32, 80, 80), dtype="float32") = R.add(lv47, lv57)
            lv59: R.Tensor((1, 96, 80, 80), dtype="float32") = R.concat((lv46, lv47, lv58), axis=1)
            lv60: R.Tensor((1, 128, 80, 80), dtype="float32") = R.nn.conv2d(lv59, metadata["relax.expr.Constant"][20], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv61: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][21], R.shape([1, 128, 1, 1]))
            lv62: R.Tensor((1, 128, 80, 80), dtype="float32") = R.add(lv60, lv61)
            lv63: R.Tensor((1, 128, 80, 80), dtype="float32") = R.sigmoid(lv62)
            lv64: R.Tensor((1, 128, 80, 80), dtype="float32") = R.multiply(lv62, lv63)
            lv65: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv64, metadata["relax.expr.Constant"][22], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv66: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][23], R.shape([1, 128, 1, 1]))
            lv67: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv65, lv66)
            lv68: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv67)
            lv69: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv67, lv68)
            lv70: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv69, metadata["relax.expr.Constant"][24], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv71: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][25], R.shape([1, 128, 1, 1]))
            lv72: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv70, lv71)
            lv73: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv72)
            lv74: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv72, lv73)
            lv75 = R.call_tir(cls.split2, (lv74,), out_sinfo=[R.Tensor((1, 64, 40, 40), dtype="float32"), R.Tensor((1, 64, 40, 40), dtype="float32")])
            lv76: R.Tensor((1, 64, 40, 40), dtype="float32") = lv75[0]
            lv77: R.Tensor((1, 64, 40, 40), dtype="float32") = lv75[1]
            lv78: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv77, metadata["relax.expr.Constant"][26], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv79: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][27], R.shape([1, 32, 1, 1]))
            lv80: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv77, metadata["relax.expr.Constant"][28], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv81: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][29], R.shape([1, 32, 1, 1]))
            lv82: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv78, lv79)
            lv83: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv80, lv81)
            lv84: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv82)
            lv85: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv83)
            lv86: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv82, lv84)
            lv87: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv86, metadata["relax.expr.Constant"][30], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv88: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][31], R.shape([1, 32, 1, 1]))
            lv89: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv87, lv88)
            lv90: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv89)
            lv91: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv89, lv90)
            lv92: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv91, metadata["relax.expr.Constant"][32], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv93: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][33], R.shape([1, 32, 1, 1]))
            lv94: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv92, lv93)
            lv95: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv94)
            lv96: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv94, lv95)
            lv97: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv86, lv96)
            lv98: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv97, metadata["relax.expr.Constant"][34], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv99: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][35], R.shape([1, 32, 1, 1]))
            lv100: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv98, lv99)
            lv101: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv100)
            lv102: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv100, lv101)
            lv103: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv102, metadata["relax.expr.Constant"][36], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv104: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][37], R.shape([1, 32, 1, 1]))
            lv105: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv103, lv104)
            lv106: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv105)
            lv107: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv105, lv106)
            lv108: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv97, lv107)
            lv109: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv83, lv85)
            lv110: R.Tensor((1, 64, 40, 40), dtype="float32") = R.concat((lv108, lv109), axis=1)
            lv111: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv110, metadata["relax.expr.Constant"][38], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv112: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][39], R.shape([1, 64, 1, 1]))
            lv113: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv111, lv112)
            lv114: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv113)
            lv115: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv113, lv114)
            lv116: R.Tensor((1, 192, 40, 40), dtype="float32") = R.concat((lv76, lv77, lv115), axis=1)
            lv117: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv116, metadata["relax.expr.Constant"][40], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv118: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][41], R.shape([1, 128, 1, 1]))
            lv119: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv117, lv118)
            lv120: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv119)
            lv121: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv119, lv120)
            lv122: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv121, metadata["relax.expr.Constant"][42], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv123: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][43], R.shape([1, 256, 1, 1]))
            lv124: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv122, lv123)
            lv125: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv124)
            lv126: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv124, lv125)
            lv127: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv126, metadata["relax.expr.Constant"][44], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv128: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][45], R.shape([1, 256, 1, 1]))
            lv129: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv127, lv128)
            lv130: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv129)
            lv131: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv129, lv130)
            lv132 = R.call_tir(cls.split3, (lv131,), out_sinfo=[R.Tensor((1, 128, 20, 20), dtype="float32"), R.Tensor((1, 128, 20, 20), dtype="float32")])
            lv133: R.Tensor((1, 128, 20, 20), dtype="float32") = lv132[0]
            lv134: R.Tensor((1, 128, 20, 20), dtype="float32") = lv132[1]
            lv135: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv134, metadata["relax.expr.Constant"][46], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv136: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][47], R.shape([1, 64, 1, 1]))
            lv137: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv134, metadata["relax.expr.Constant"][48], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv138: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][49], R.shape([1, 64, 1, 1]))
            lv139: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv135, lv136)
            lv140: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv137, lv138)
            lv141: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv139)
            lv142: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv140)
            lv143: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv139, lv141)
            lv144: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv143, metadata["relax.expr.Constant"][50], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv145: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][51], R.shape([1, 64, 1, 1]))
            lv146: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv144, lv145)
            lv147: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv146)
            lv148: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv146, lv147)
            lv149: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv148, metadata["relax.expr.Constant"][52], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv150: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][53], R.shape([1, 64, 1, 1]))
            lv151: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv149, lv150)
            lv152: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv151)
            lv153: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv151, lv152)
            lv154: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv143, lv153)
            lv155: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv154, metadata["relax.expr.Constant"][54], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv156: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][55], R.shape([1, 64, 1, 1]))
            lv157: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv155, lv156)
            lv158: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv157)
            lv159: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv157, lv158)
            lv160: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv159, metadata["relax.expr.Constant"][56], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv161: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][57], R.shape([1, 64, 1, 1]))
            lv162: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv160, lv161)
            lv163: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv162)
            lv164: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv162, lv163)
            lv165: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv154, lv164)
            lv166: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv140, lv142)
            lv167: R.Tensor((1, 128, 20, 20), dtype="float32") = R.concat((lv165, lv166), axis=1)
            lv168: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv167, metadata["relax.expr.Constant"][58], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv169: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][59], R.shape([1, 128, 1, 1]))
            lv170: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv168, lv169)
            lv171: R.Tensor((1, 128, 20, 20), dtype="float32") = R.sigmoid(lv170)
            lv172: R.Tensor((1, 128, 20, 20), dtype="float32") = R.multiply(lv170, lv171)
            lv173: R.Tensor((1, 384, 20, 20), dtype="float32") = R.concat((lv133, lv134, lv172), axis=1)
            lv174: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv173, metadata["relax.expr.Constant"][60], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv175: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][61], R.shape([1, 256, 1, 1]))
            lv176: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv174, lv175)
            lv177: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv176)
            lv178: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv176, lv177)
            lv179: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv178, metadata["relax.expr.Constant"][62], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv180: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][63], R.shape([1, 128, 1, 1]))
            lv181: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv179, lv180)
            lv182: R.Tensor((1, 128, 20, 20), dtype="float32") = R.sigmoid(lv181)
            lv183: R.Tensor((1, 128, 20, 20), dtype="float32") = R.multiply(lv181, lv182)
            lv184: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.max_pool2d(lv183, pool_size=[5, 5], strides=[1, 1], dilation=[1, 1], padding=[2, 2, 2, 2], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
            lv185: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.max_pool2d(lv184, pool_size=[5, 5], strides=[1, 1], dilation=[1, 1], padding=[2, 2, 2, 2], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
            lv186: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.max_pool2d(lv185, pool_size=[5, 5], strides=[1, 1], dilation=[1, 1], padding=[2, 2, 2, 2], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
            lv187: R.Tensor((1, 512, 20, 20), dtype="float32") = R.concat((lv183, lv184, lv185, lv186), axis=1)
            lv188: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv187, metadata["relax.expr.Constant"][64], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv189: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][65], R.shape([1, 256, 1, 1]))
            lv190: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv188, lv189)
            lv191: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv190)
            lv192: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv190, lv191)
            lv193: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv192, metadata["relax.expr.Constant"][66], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv194: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][67], R.shape([1, 256, 1, 1]))
            lv195: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv193, lv194)
            lv196: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv195)
            lv197: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv195, lv196)
            lv198 = R.call_tir(cls.split3, (lv197,), out_sinfo=[R.Tensor((1, 128, 20, 20), dtype="float32"), R.Tensor((1, 128, 20, 20), dtype="float32")])
            lv199: R.Tensor((1, 128, 20, 20), dtype="float32") = lv198[0]
            lv200: R.Tensor((1, 128, 20, 20), dtype="float32") = lv198[1]
            lv201: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv200, metadata["relax.expr.Constant"][68], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv202: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][69], R.shape([1, 256, 1, 1]))
            lv203: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv201, lv202)
            lv204: R.Tensor((1, 2, 128, 400), dtype="float32") = R.reshape(lv203, R.shape([1, 2, 128, 400]))
            lv205 = R.call_tir(cls.split4, (lv204,), out_sinfo=[R.Tensor((1, 2, 32, 400), dtype="float32"), R.Tensor((1, 2, 32, 400), dtype="float32"), R.Tensor((1, 2, 64, 400), dtype="float32")])
            lv206: R.Tensor((1, 2, 32, 400), dtype="float32") = lv205[0]
            lv207: R.Tensor((1, 2, 32, 400), dtype="float32") = lv205[1]
            lv208: R.Tensor((1, 2, 64, 400), dtype="float32") = lv205[2]
            lv209: R.Tensor((1, 2, 400, 32), dtype="float32") = R.permute_dims(lv206, axes=[0, 1, 3, 2])
            lv210: R.Tensor((1, 128, 20, 20), dtype="float32") = R.reshape(lv208, R.shape([1, 128, 20, 20]))
            lv211: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv210, metadata["relax.expr.Constant"][70], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=128, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv212: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][71], R.shape([1, 128, 1, 1]))
            lv213: R.Tensor((1, 2, 400, 400), dtype="float32") = R.matmul(lv209, lv207, out_dtype="void")
            lv214: R.Tensor((1, 2, 400, 400), dtype="float32") = R.multiply(lv213, R.const(0.1767766922712326, "float32"))
            lv215: R.Tensor((1, 2, 400, 400), dtype="float32") = R.nn.softmax(lv214, axis=-1)
            lv216: R.Tensor((1, 2, 400, 400), dtype="float32") = R.permute_dims(lv215, axes=[0, 1, 3, 2])
            lv217: R.Tensor((1, 2, 64, 400), dtype="float32") = R.matmul(lv208, lv216, out_dtype="void")
            lv218: R.Tensor((1, 128, 20, 20), dtype="float32") = R.reshape(lv217, R.shape([1, 128, 20, 20]))
            lv219: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv211, lv212)
            lv220: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv218, lv219)
            lv221: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv220, metadata["relax.expr.Constant"][72], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv222: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][73], R.shape([1, 128, 1, 1]))
            lv223: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv221, lv222)
            lv224: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv200, lv223)
            lv225: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv224, metadata["relax.expr.Constant"][74], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv226: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][75], R.shape([1, 256, 1, 1]))
            lv227: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv225, lv226)
            lv228: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv227)
            lv229: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv227, lv228)
            lv230: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv229, metadata["relax.expr.Constant"][76], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv231: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][77], R.shape([1, 128, 1, 1]))
            lv232: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv230, lv231)
            lv233: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv224, lv232)
            lv234: R.Tensor((1, 256, 20, 20), dtype="float32") = R.concat((lv199, lv233), axis=1)
            lv235: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv234, metadata["relax.expr.Constant"][78], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv236: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][79], R.shape([1, 256, 1, 1]))
            lv237: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv235, lv236)
            lv238: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv237)
            lv239: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv237, lv238)
            lv240: R.Tensor((1, 256, 40, 40), dtype="float32") = R.image.resize2d(lv239, R.shape([40, 40]), roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], layout="NCHW", method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="floor", cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0.0, out_dtype="void")
            lv241: R.Tensor((1, 384, 40, 40), dtype="float32") = R.concat((lv240, lv121), axis=1)
            lv242: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv241, metadata["relax.expr.Constant"][80], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv243: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][81], R.shape([1, 128, 1, 1]))
            lv244: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv242, lv243)
            lv245: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv244)
            lv246: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv244, lv245)
            lv247 = R.call_tir(cls.split2, (lv246,), out_sinfo=[R.Tensor((1, 64, 40, 40), dtype="float32"), R.Tensor((1, 64, 40, 40), dtype="float32")])
            lv248: R.Tensor((1, 64, 40, 40), dtype="float32") = lv247[0]
            lv249: R.Tensor((1, 64, 40, 40), dtype="float32") = lv247[1]
            lv250: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv249, metadata["relax.expr.Constant"][82], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv251: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][83], R.shape([1, 32, 1, 1]))
            lv252: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv250, lv251)
            lv253: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv252)
            lv254: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv252, lv253)
            lv255: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv254, metadata["relax.expr.Constant"][84], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv256: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][85], R.shape([1, 64, 1, 1]))
            lv257: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv255, lv256)
            lv258: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv257)
            lv259: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv257, lv258)
            lv260: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv249, lv259)
            lv261: R.Tensor((1, 192, 40, 40), dtype="float32") = R.concat((lv248, lv249, lv260), axis=1)
            lv262: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv261, metadata["relax.expr.Constant"][86], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv263: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][87], R.shape([1, 128, 1, 1]))
            lv264: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv262, lv263)
            lv265: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv264)
            lv266: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv264, lv265)
            lv267: R.Tensor((1, 128, 80, 80), dtype="float32") = R.image.resize2d(lv266, R.shape([80, 80]), roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], layout="NCHW", method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="floor", cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0.0, out_dtype="void")
            lv268: R.Tensor((1, 256, 80, 80), dtype="float32") = R.concat((lv267, lv64), axis=1)
            lv269: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv268, metadata["relax.expr.Constant"][88], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv270: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][89], R.shape([1, 64, 1, 1]))
            lv271: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv269, lv270)
            lv272: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv271)
            lv273: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv271, lv272)
            lv274 = R.call_tir(cls.split1, (lv273,), out_sinfo=[R.Tensor((1, 32, 80, 80), dtype="float32"), R.Tensor((1, 32, 80, 80), dtype="float32")])
            lv275: R.Tensor((1, 32, 80, 80), dtype="float32") = lv274[0]
            lv276: R.Tensor((1, 32, 80, 80), dtype="float32") = lv274[1]
            lv277: R.Tensor((1, 16, 80, 80), dtype="float32") = R.nn.conv2d(lv276, metadata["relax.expr.Constant"][90], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv278: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][91], R.shape([1, 16, 1, 1]))
            lv279: R.Tensor((1, 16, 80, 80), dtype="float32") = R.add(lv277, lv278)
            lv280: R.Tensor((1, 16, 80, 80), dtype="float32") = R.sigmoid(lv279)
            lv281: R.Tensor((1, 16, 80, 80), dtype="float32") = R.multiply(lv279, lv280)
            lv282: R.Tensor((1, 32, 80, 80), dtype="float32") = R.nn.conv2d(lv281, metadata["relax.expr.Constant"][92], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv283: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][93], R.shape([1, 32, 1, 1]))
            lv284: R.Tensor((1, 32, 80, 80), dtype="float32") = R.add(lv282, lv283)
            lv285: R.Tensor((1, 32, 80, 80), dtype="float32") = R.sigmoid(lv284)
            lv286: R.Tensor((1, 32, 80, 80), dtype="float32") = R.multiply(lv284, lv285)
            lv287: R.Tensor((1, 32, 80, 80), dtype="float32") = R.add(lv276, lv286)
            lv288: R.Tensor((1, 96, 80, 80), dtype="float32") = R.concat((lv275, lv276, lv287), axis=1)
            lv289: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv288, metadata["relax.expr.Constant"][94], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv290: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][95], R.shape([1, 64, 1, 1]))
            lv291: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv289, lv290)
            lv292: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv291)
            lv293: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv291, lv292)
            lv294: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv293, metadata["relax.expr.Constant"][96], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv295: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][97], R.shape([1, 64, 1, 1]))
            lv296: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv293, metadata["relax.expr.Constant"][98], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv297: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][99], R.shape([1, 64, 1, 1]))
            lv298: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv293, metadata["relax.expr.Constant"][100], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=64, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv299: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][101], R.shape([1, 64, 1, 1]))
            lv300: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv294, lv295)
            lv301: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv296, lv297)
            lv302: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv298, lv299)
            lv303: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv300)
            lv304: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv301)
            lv305: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv302)
            lv306: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv300, lv303)
            lv307: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv301, lv304)
            lv308: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv307, metadata["relax.expr.Constant"][102], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv309: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][103], R.shape([1, 64, 1, 1]))
            lv310: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv302, lv305)
            lv311: R.Tensor((1, 80, 80, 80), dtype="float32") = R.nn.conv2d(lv310, metadata["relax.expr.Constant"][104], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv312: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][105], R.shape([1, 80, 1, 1]))
            lv313: R.Tensor((1, 192, 40, 40), dtype="float32") = R.concat((lv306, lv266), axis=1)
            lv314: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv313, metadata["relax.expr.Constant"][106], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv315: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][107], R.shape([1, 128, 1, 1]))
            lv316: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv308, lv309)
            lv317: R.Tensor((1, 80, 80, 80), dtype="float32") = R.add(lv311, lv312)
            lv318: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv314, lv315)
            lv319: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv316)
            lv320: R.Tensor((1, 80, 80, 80), dtype="float32") = R.sigmoid(lv317)
            lv321: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv318)
            lv322: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv316, lv319)
            lv323: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv322, metadata["relax.expr.Constant"][108], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv324: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][109], R.shape([1, 64, 1, 1]))
            lv325: R.Tensor((1, 80, 80, 80), dtype="float32") = R.multiply(lv317, lv320)
            lv326: R.Tensor((1, 80, 80, 80), dtype="float32") = R.nn.conv2d(lv325, metadata["relax.expr.Constant"][110], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=80, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv327: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][111], R.shape([1, 80, 1, 1]))
            lv328: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv318, lv321)
            lv329 = R.call_tir(cls.split2, (lv328,), out_sinfo=[R.Tensor((1, 64, 40, 40), dtype="float32"), R.Tensor((1, 64, 40, 40), dtype="float32")])
            lv330: R.Tensor((1, 64, 40, 40), dtype="float32") = lv329[0]
            lv331: R.Tensor((1, 64, 40, 40), dtype="float32") = lv329[1]
            lv332: R.Tensor((1, 80, 80, 80), dtype="float32") = R.add(lv326, lv327)
            lv333: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv331, metadata["relax.expr.Constant"][112], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv334: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][113], R.shape([1, 32, 1, 1]))
            lv335: R.Tensor((1, 80, 80, 80), dtype="float32") = R.sigmoid(lv332)
            lv336: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv333, lv334)
            lv337: R.Tensor((1, 80, 80, 80), dtype="float32") = R.multiply(lv332, lv335)
            lv338: R.Tensor((1, 80, 80, 80), dtype="float32") = R.nn.conv2d(lv337, metadata["relax.expr.Constant"][114], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv339: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][115], R.shape([1, 80, 1, 1]))
            lv340: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv336)
            lv341: R.Tensor((1, 80, 80, 80), dtype="float32") = R.add(lv338, lv339)
            lv342: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv336, lv340)
            lv343: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv342, metadata["relax.expr.Constant"][116], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv344: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][117], R.shape([1, 64, 1, 1]))
            lv345: R.Tensor((1, 80, 80, 80), dtype="float32") = R.sigmoid(lv341)
            lv346: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv343, lv344)
            lv347: R.Tensor((1, 80, 80, 80), dtype="float32") = R.multiply(lv341, lv345)
            lv348: R.Tensor((1, 80, 80, 80), dtype="float32") = R.nn.conv2d(lv347, metadata["relax.expr.Constant"][118], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv349: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][119], R.shape([1, 80, 1, 1]))
            lv350: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv346)
            lv351: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv323, lv324)
            lv352: R.Tensor((1, 80, 80, 80), dtype="float32") = R.add(lv348, lv349)
            lv353: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv346, lv350)
            lv354: R.Tensor((1, 144, 80, 80), dtype="float32") = R.concat((lv351, lv352), axis=1)
            lv355: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv331, lv353)
            lv356: R.Tensor((1, 192, 40, 40), dtype="float32") = R.concat((lv330, lv331, lv355), axis=1)
            lv357: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv356, metadata["relax.expr.Constant"][120], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv358: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][121], R.shape([1, 128, 1, 1]))
            lv359: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv357, lv358)
            lv360: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv359)
            lv361: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv359, lv360)
            lv362: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv361, metadata["relax.expr.Constant"][122], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv363: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][123], R.shape([1, 128, 1, 1]))
            lv364: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv361, metadata["relax.expr.Constant"][124], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv365: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][125], R.shape([1, 64, 1, 1]))
            lv366: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv361, metadata["relax.expr.Constant"][126], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=128, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv367: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][127], R.shape([1, 128, 1, 1]))
            lv368: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv362, lv363)
            lv369: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv364, lv365)
            lv370: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv366, lv367)
            lv371: R.Tensor((1, 128, 20, 20), dtype="float32") = R.sigmoid(lv368)
            lv372: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv369)
            lv373: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv370)
            lv374: R.Tensor((1, 128, 20, 20), dtype="float32") = R.multiply(lv368, lv371)
            lv375: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv369, lv372)
            lv376: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv375, metadata["relax.expr.Constant"][128], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv377: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][129], R.shape([1, 64, 1, 1]))
            lv378: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv370, lv373)
            lv379: R.Tensor((1, 80, 40, 40), dtype="float32") = R.nn.conv2d(lv378, metadata["relax.expr.Constant"][130], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv380: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][131], R.shape([1, 80, 1, 1]))
            lv381: R.Tensor((1, 384, 20, 20), dtype="float32") = R.concat((lv374, lv239), axis=1)
            lv382: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv381, metadata["relax.expr.Constant"][132], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv383: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][133], R.shape([1, 256, 1, 1]))
            lv384: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv376, lv377)
            lv385: R.Tensor((1, 80, 40, 40), dtype="float32") = R.add(lv379, lv380)
            lv386: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv382, lv383)
            lv387: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv384)
            lv388: R.Tensor((1, 80, 40, 40), dtype="float32") = R.sigmoid(lv385)
            lv389: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv386)
            lv390: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv384, lv387)
            lv391: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv390, metadata["relax.expr.Constant"][134], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv392: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][135], R.shape([1, 64, 1, 1]))
            lv393: R.Tensor((1, 80, 40, 40), dtype="float32") = R.multiply(lv385, lv388)
            lv394: R.Tensor((1, 80, 40, 40), dtype="float32") = R.nn.conv2d(lv393, metadata["relax.expr.Constant"][136], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=80, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv395: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][137], R.shape([1, 80, 1, 1]))
            lv396: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv386, lv389)
            lv397 = R.call_tir(cls.split3, (lv396,), out_sinfo=[R.Tensor((1, 128, 20, 20), dtype="float32"), R.Tensor((1, 128, 20, 20), dtype="float32")])
            lv398: R.Tensor((1, 128, 20, 20), dtype="float32") = lv397[0]
            lv399: R.Tensor((1, 128, 20, 20), dtype="float32") = lv397[1]
            lv400: R.Tensor((1, 80, 40, 40), dtype="float32") = R.add(lv394, lv395)
            lv401: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv399, metadata["relax.expr.Constant"][138], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv402: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][139], R.shape([1, 64, 1, 1]))
            lv403: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv399, metadata["relax.expr.Constant"][140], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv404: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][141], R.shape([1, 64, 1, 1]))
            lv405: R.Tensor((1, 80, 40, 40), dtype="float32") = R.sigmoid(lv400)
            lv406: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv401, lv402)
            lv407: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv403, lv404)
            lv408: R.Tensor((1, 80, 40, 40), dtype="float32") = R.multiply(lv400, lv405)
            lv409: R.Tensor((1, 80, 40, 40), dtype="float32") = R.nn.conv2d(lv408, metadata["relax.expr.Constant"][142], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv410: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][143], R.shape([1, 80, 1, 1]))
            lv411: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv406)
            lv412: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv407)
            lv413: R.Tensor((1, 80, 40, 40), dtype="float32") = R.add(lv409, lv410)
            lv414: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv406, lv411)
            lv415: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv414, metadata["relax.expr.Constant"][144], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv416: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][145], R.shape([1, 64, 1, 1]))
            lv417: R.Tensor((1, 80, 40, 40), dtype="float32") = R.sigmoid(lv413)
            lv418: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv415, lv416)
            lv419: R.Tensor((1, 80, 40, 40), dtype="float32") = R.multiply(lv413, lv417)
            lv420: R.Tensor((1, 80, 40, 40), dtype="float32") = R.nn.conv2d(lv419, metadata["relax.expr.Constant"][146], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv421: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][147], R.shape([1, 80, 1, 1]))
            lv422: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv418)
            lv423: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv391, lv392)
            lv424: R.Tensor((1, 80, 40, 40), dtype="float32") = R.add(lv420, lv421)
            lv425: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv418, lv422)
            lv426: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv425, metadata["relax.expr.Constant"][148], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv427: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][149], R.shape([1, 64, 1, 1]))
            lv428: R.Tensor((1, 144, 40, 40), dtype="float32") = R.concat((lv423, lv424), axis=1)
            lv429: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv426, lv427)
            lv430: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv429)
            lv431: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv429, lv430)
            lv432: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv414, lv431)
            lv433: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv432, metadata["relax.expr.Constant"][150], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv434: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][151], R.shape([1, 64, 1, 1]))
            lv435: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv433, lv434)
            lv436: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv435)
            lv437: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv435, lv436)
            lv438: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv437, metadata["relax.expr.Constant"][152], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv439: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][153], R.shape([1, 64, 1, 1]))
            lv440: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv438, lv439)
            lv441: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv440)
            lv442: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv440, lv441)
            lv443: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv432, lv442)
            lv444: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv407, lv412)
            lv445: R.Tensor((1, 128, 20, 20), dtype="float32") = R.concat((lv443, lv444), axis=1)
            lv446: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv445, metadata["relax.expr.Constant"][154], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv447: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][155], R.shape([1, 128, 1, 1]))
            lv448: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv446, lv447)
            lv449: R.Tensor((1, 128, 20, 20), dtype="float32") = R.sigmoid(lv448)
            lv450: R.Tensor((1, 128, 20, 20), dtype="float32") = R.multiply(lv448, lv449)
            lv451: R.Tensor((1, 384, 20, 20), dtype="float32") = R.concat((lv398, lv399, lv450), axis=1)
            lv452: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv451, metadata["relax.expr.Constant"][156], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv453: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][157], R.shape([1, 256, 1, 1]))
            lv454: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv452, lv453)
            lv455: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv454)
            lv456: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv454, lv455)
            lv457: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv456, metadata["relax.expr.Constant"][158], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv458: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][159], R.shape([1, 64, 1, 1]))
            lv459: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv456, metadata["relax.expr.Constant"][160], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=256, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv460: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][161], R.shape([1, 256, 1, 1]))
            lv461: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv457, lv458)
            lv462: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv459, lv460)
            lv463: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv461)
            lv464: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv462)
            lv465: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv461, lv463)
            lv466: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv465, metadata["relax.expr.Constant"][162], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv467: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][163], R.shape([1, 64, 1, 1]))
            lv468: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv462, lv464)
            lv469: R.Tensor((1, 80, 20, 20), dtype="float32") = R.nn.conv2d(lv468, metadata["relax.expr.Constant"][164], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv470: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][165], R.shape([1, 80, 1, 1]))
            lv471: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv466, lv467)
            lv472: R.Tensor((1, 80, 20, 20), dtype="float32") = R.add(lv469, lv470)
            lv473: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv471)
            lv474: R.Tensor((1, 80, 20, 20), dtype="float32") = R.sigmoid(lv472)
            lv475: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv471, lv473)
            lv476: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv475, metadata["relax.expr.Constant"][166], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv477: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][167], R.shape([1, 64, 1, 1]))
            lv478: R.Tensor((1, 80, 20, 20), dtype="float32") = R.multiply(lv472, lv474)
            lv479: R.Tensor((1, 80, 20, 20), dtype="float32") = R.nn.conv2d(lv478, metadata["relax.expr.Constant"][168], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=80, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv480: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][169], R.shape([1, 80, 1, 1]))
            lv481: R.Tensor((1, 80, 20, 20), dtype="float32") = R.add(lv479, lv480)
            lv482: R.Tensor((1, 80, 20, 20), dtype="float32") = R.sigmoid(lv481)
            lv483: R.Tensor((1, 80, 20, 20), dtype="float32") = R.multiply(lv481, lv482)
            lv484: R.Tensor((1, 80, 20, 20), dtype="float32") = R.nn.conv2d(lv483, metadata["relax.expr.Constant"][170], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv485: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][171], R.shape([1, 80, 1, 1]))
            lv486: R.Tensor((1, 80, 20, 20), dtype="float32") = R.add(lv484, lv485)
            lv487: R.Tensor((1, 80, 20, 20), dtype="float32") = R.sigmoid(lv486)
            lv488: R.Tensor((1, 80, 20, 20), dtype="float32") = R.multiply(lv486, lv487)
            lv489: R.Tensor((1, 80, 20, 20), dtype="float32") = R.nn.conv2d(lv488, metadata["relax.expr.Constant"][172], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv490: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][173], R.shape([1, 80, 1, 1]))
            lv491: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv476, lv477)
            lv492: R.Tensor((1, 80, 20, 20), dtype="float32") = R.add(lv489, lv490)
            lv493: R.Tensor((1, 144, 20, 20), dtype="float32") = R.concat((lv491, lv492), axis=1)
            lv494: R.Tensor((1, 144, 6400), dtype="float32") = R.reshape(lv354, R.shape([1, 144, 6400]))
            lv495: R.Tensor((1, 144, 1600), dtype="float32") = R.reshape(lv428, R.shape([1, 144, 1600]))
            lv496: R.Tensor((1, 144, 400), dtype="float32") = R.reshape(lv493, R.shape([1, 144, 400]))
            lv497: R.Tensor((1, 144, 8400), dtype="float32") = R.concat((lv494, lv495, lv496), axis=2)
            lv498 = R.call_tir(cls.split5, (lv497,), out_sinfo=[R.Tensor((1, 64, 8400), dtype="float32"), R.Tensor((1, 80, 8400), dtype="float32")])
            lv499: R.Tensor((1, 64, 8400), dtype="float32") = lv498[0]
            lv500: R.Tensor((1, 80, 8400), dtype="float32") = lv498[1]
            lv501: R.Tensor((1, 4, 16, 8400), dtype="float32") = R.reshape(lv499, R.shape([1, 4, 16, 8400]))
            lv502: R.Tensor((1, 16, 4, 8400), dtype="float32") = R.permute_dims(lv501, axes=[0, 2, 1, 3])
            lv503: R.Tensor((1, 16, 4, 8400), dtype="float32") = R.nn.softmax(lv502, axis=1)
            lv504: R.Tensor((1, 1, 4, 8400), dtype="float32") = R.nn.conv2d(lv503, metadata["relax.expr.Constant"][174], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
            lv505: R.Tensor((1, 4, 8400), dtype="float32") = R.reshape(lv504, R.shape([1, 4, 8400]))
            lv506: R.Tensor((1, 2, 8400), dtype="float32") = R.strided_slice(lv505, (R.prim_value(1),), (R.prim_value(0),), (R.prim_value(2),), (R.prim_value(1),), assume_inbound=False)
            lv507: R.Tensor((1, 2, 8400), dtype="float32") = R.strided_slice(lv505, (R.prim_value(1),), (R.prim_value(2),), (R.prim_value(4),), (R.prim_value(1),), assume_inbound=False)
            lv508: R.Tensor((1, 2, 8400), dtype="float32") = R.subtract(metadata["relax.expr.Constant"][175], lv506)
            lv509: R.Tensor((1, 2, 8400), dtype="float32") = R.add(metadata["relax.expr.Constant"][176], lv507)
            lv510: R.Tensor((1, 2, 8400), dtype="float32") = R.add(lv508, lv509)
            lv511: R.Tensor((1, 2, 8400), dtype="float32") = R.divide(lv510, R.const(2.0, "float32"))
            lv512: R.Tensor((1, 2, 8400), dtype="float32") = R.subtract(lv509, lv508)
            lv513: R.Tensor((1, 4, 8400), dtype="float32") = R.concat((lv511, lv512), axis=1)
            lv514: R.Tensor((1, 4, 8400), dtype="float32") = R.multiply(lv513, metadata["relax.expr.Constant"][177])
            lv515: R.Tensor((1, 80, 8400), dtype="float32") = R.sigmoid(lv500)
            gv: R.Tensor((1, 84, 8400), dtype="float32") = R.concat((lv514, lv515), axis=1)
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.
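The note at the end of the printed module says the constant metadata (the weights and biases referenced as metadata["relax.expr.Constant"][...]) is omitted. If needed, it can be rendered explicitly with the show_meta flag mentioned in that note; the full string is extremely long for YOLO11, so the sketch below only prints a prefix:

# Optionally render the fused module together with its metadata section.
meta_script = mod_actual.script(show_meta=True)
print(meta_script[:2000])  # the full text is huge, so only show the beginning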
# Inspect the imported Relax main function and the full module
relax_func = tvm_model["main"]
relax_func.show()
tvm_model.show()
# Decompose composite operators into their inference-mode forms
tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
tvm_model.show()
# Legalize the remaining Relax operators into TensorIR functions
tvm_model = relax.transform.LegalizeOps()(tvm_model)

# Detach the parameters from the module
tvm_model, params = relax.frontend.detach_params(tvm_model)
# Compile the Relax graph into a virtual machine (VM) executable, then run it
with tvm.transform.PassContext(opt_level=3):
    ex = relax.build(tvm_model, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
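The build step is relatively slow, so it can be convenient to cache the compiled artifact on disk. A minimal sketch, assuming Executable.export_library is available in this TVM version; yolo11n_relax.so is a hypothetical file name:

# Save the compiled executable and reload it into a fresh VM (sketch).
ex.export_library("yolo11n_relax.so")                 # hypothetical output path
loaded = tvm.runtime.load_module("yolo11n_relax.so")
vm_reloaded = relax.VirtualMachine(loaded, tvm.cpu())  # behaves like the VM built above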

Prepare the inputs:

input_list = [
    inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
]
if params:
    input_list += params["main"]
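Before binding the arrays to the VM, it can help to confirm that they match the signature of the compiled main function. The short sketch below prints each expected Relax parameter next to the shape of the array that will be supplied for it:

# Sanity check (sketch): expected Relax parameters vs. the arrays we actually have.
for var in tvm_model["main"].params:
    arr = inputs.get(var.name_hint)
    print(var.name_hint, var.struct_info, None if arr is None else arr.shape)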

Run the model and check the outputs:

vm.set_input("main", *input_list)
vm.invoke_stateful("main")
tvm_output = vm.get_outputs("main")
# If ONNX Runtime reports a single output, wrap the TVM result in a list:
if len(ort_output) == 1:
    # TVM does not report an output count: for a single tensor the VM returns the
    # NDArray directly (and a tuple for multiple outputs), whereas ONNX Runtime
    # always returns its outputs as a Python list.
    tvm_output = [tvm_output]
def _check_output(tvm_out: list, ort_out: list, rtol: float = 1e-7, atol: float = 1e-5,):
    if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)):
        assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
        for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
            _check_output(tvm_out_i, ort_out_i)
    elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
        np.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
    elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
        shape_out = tvm.nd.array([int(i) for i in tvm_out])
        np.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
    elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
        np.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
    else:
        raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
# Check that number of outputs match.
assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
for tvm_out, ort_out in zip(tvm_output, ort_output):
    # TODO Allow configurable tolerance.
    if ort_out is not None:
        _check_output(tvm_out, ort_out, rtol=1e-4, atol=1e-5)
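Since the numerical check passes, the TVM output can also be pushed through the same post-processing used for the original model to confirm that the detections themselves agree. A sketch, reusing the postprocess helper, xs, origin_image, input_path, and yolo.names defined in the earlier cells:

import torch
from PIL import Image

# Run NMS and box rescaling on the Relax VM output, then draw the boxes (sketch).
preds = torch.from_numpy(tvm_output[0].numpy())
tvm_results = postprocess(preds, torch.from_numpy(xs), [origin_image], yolo.names, input_path)
Image.fromarray(tvm_results[0].plot()).resize((320, 208))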