YOLO11 Relax 优化#
# Baseline: load the test image and run stock YOLO11 inference as a reference.
import torch
torch.cuda.empty_cache()  # drop any cached CUDA allocations before loading the model
from PIL import Image
import numpy as np
from ultralytics import YOLO

input_path = "images/vehicle-jaguar-f-type-car-red-cars-wallpaper.jpg"
image = Image.open(input_path)  # optionally pre-resize, e.g. .resize((384, 640))
yolo = YOLO("yolo11n.pt")
detections = yolo(np.array(image), conf=0.25)
# Render the annotated detections, downscaled for display.
Image.fromarray(detections[0].plot()).resize((320, 208))
0: 416x640 1 car, 15.4ms
Speed: 2.7ms preprocess, 15.4ms inference, 2.2ms postprocess per image at shape (1, 3, 416, 640)
预处理:
# Reproduce Ultralytics' preprocessing by hand: letterbox -> BHWC->BCHW -> normalize.
from PIL import Image
import numpy as np
import torch
from ultralytics.data.augment import LetterBox
imgsz = 640, 640
strides = yolo.model.stride
mean = (0,)
std = (255,)
letterbox = LetterBox(new_shape=imgsz, auto=False, scaleFill=False, scaleup=True, stride=32)
origin_image = np.asanyarray(Image.open(input_path))
letterbox_image = letterbox(image=origin_image)
xs = np.stack([letterbox_image - mean])  # add the batch dimension
# FIX: the subscript must use single quotes — double quotes nested inside a
# double-quoted f-string are a SyntaxError before Python 3.12 (PEP 701).
print(f"数据内存的连续性:{xs.flags['C_CONTIGUOUS']}")
xs = xs.transpose((0, 3, 1, 2)) # BHWC to BCHW, (n, 3, h, w)
print(f"数据内存的连续性(transpose):{xs.flags['C_CONTIGUOUS']}")
xs = np.ascontiguousarray(xs) # transpose only changes strides; force a contiguous copy
print(f"数据内存的连续性:{xs.flags['C_CONTIGUOUS']}")
xs = (xs / std).astype("float32") # normalize the value range to 0.0 - 1.0
# Sanity check: show the letterboxed image next to the de-normalized tensor.
Image.fromarray(
np.concatenate([letterbox_image, (xs[0]*std).astype("uint8").transpose((1, 2, 0))], axis=1)
).resize((640, 320,))
数据内存的连续性:True
数据内存的连续性(transpose):False
数据内存的连续性:True
后处理:
from ultralytics.utils import ops
from ultralytics.engine.results import Results
def postprocess(preds, img, orig_imgs, names, input_path, conf_thres=0.25, iou_thres=0.45,):
    """Apply NMS to raw predictions and wrap each image's boxes in a Results object."""
    kept = ops.non_max_suppression(
        preds,
        conf_thres=conf_thres,
        iou_thres=iou_thres,
        # agnostic=self.args.agnostic_nms,
        # max_det=self.args.max_det,
        # classes=80,
    )
    out = []
    for pred, orig_img in zip(kept, orig_imgs):
        # Map boxes from the network input resolution back onto the original image.
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        out.append(Results(orig_img, path=input_path, names=names, boxes=pred))
    return out
ONNX 推理#
import onnxruntime
import onnx
# Load the exported model and run it through ONNX Runtime to obtain the expected results.
onnx_model = onnx.load('yolo11n.onnx')
ort_session = onnxruntime.InferenceSession(
onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
# Feed the preprocessed tensor under the model's input name.
inputs = {"images": xs}
# NOTE(review): an empty output-name list appears to request all outputs
# (same as passing None) — confirm against the onnxruntime API docs.
ort_output = ort_session.run([], inputs)
测试 YOLO ONNX Relax 前端#
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
# Import the ONNX graph as a Relax IRModule; with keep_params_in_input=False the
# weights are embedded as constants instead of extra function parameters.
tvm_model = from_onnx(onnx_model, keep_params_in_input=False)
# Tag each TIR PrimFunc with an "op_pattern" attribute (visible in the printed
# module below), then let FuseOps group compatible operators.
mod_actual = relax.transform.AnnotateTIROpPattern()(tvm_model)
mod_actual = relax.transform.FuseOps()(mod_actual)
mod_actual.show()
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func(private=True)
def split(A: T.Buffer((T.int64(1), T.int64(32), T.int64(160), T.int64(160)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(16), T.int64(160), T.int64(160)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(16), T.int64(160), T.int64(160)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(160), T.int64(160)):
with T.block("T_split"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_split[v_ax0, v_ax1, v_ax2, v_ax3])
T_split[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(160), T.int64(160)):
with T.block("T_split_1"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1 + T.int64(16), v_ax2, v_ax3])
T.writes(T_split_1[v_ax0, v_ax1, v_ax2, v_ax3])
T_split_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1 + T.int64(16), v_ax2, v_ax3]
@T.prim_func(private=True)
def split1(A: T.Buffer((T.int64(1), T.int64(64), T.int64(80), T.int64(80)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(32), T.int64(80), T.int64(80)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(32), T.int64(80), T.int64(80)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(80), T.int64(80)):
with T.block("T_split"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_split[v_ax0, v_ax1, v_ax2, v_ax3])
T_split[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(80), T.int64(80)):
with T.block("T_split_1"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1 + T.int64(32), v_ax2, v_ax3])
T.writes(T_split_1[v_ax0, v_ax1, v_ax2, v_ax3])
T_split_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1 + T.int64(32), v_ax2, v_ax3]
@T.prim_func(private=True)
def split2(A: T.Buffer((T.int64(1), T.int64(128), T.int64(40), T.int64(40)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(64), T.int64(40), T.int64(40)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(64), T.int64(40), T.int64(40)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(64), T.int64(40), T.int64(40)):
with T.block("T_split"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_split[v_ax0, v_ax1, v_ax2, v_ax3])
T_split[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(64), T.int64(40), T.int64(40)):
with T.block("T_split_1"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1 + T.int64(64), v_ax2, v_ax3])
T.writes(T_split_1[v_ax0, v_ax1, v_ax2, v_ax3])
T_split_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1 + T.int64(64), v_ax2, v_ax3]
@T.prim_func(private=True)
def split3(A: T.Buffer((T.int64(1), T.int64(256), T.int64(20), T.int64(20)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(128), T.int64(20), T.int64(20)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(128), T.int64(20), T.int64(20)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(128), T.int64(20), T.int64(20)):
with T.block("T_split"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_split[v_ax0, v_ax1, v_ax2, v_ax3])
T_split[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(128), T.int64(20), T.int64(20)):
with T.block("T_split_1"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1 + T.int64(128), v_ax2, v_ax3])
T.writes(T_split_1[v_ax0, v_ax1, v_ax2, v_ax3])
T_split_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1 + T.int64(128), v_ax2, v_ax3]
@T.prim_func(private=True)
def split4(A: T.Buffer((T.int64(1), T.int64(2), T.int64(128), T.int64(400)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(2), T.int64(32), T.int64(400)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(2), T.int64(32), T.int64(400)), "float32"), T_split_2: T.Buffer((T.int64(1), T.int64(2), T.int64(64), T.int64(400)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(2), T.int64(32), T.int64(400)):
with T.block("T_split"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_split[v_ax0, v_ax1, v_ax2, v_ax3])
T_split[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2, v_ax3]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(2), T.int64(32), T.int64(400)):
with T.block("T_split_1"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1, v_ax2 + T.int64(32), v_ax3])
T.writes(T_split_1[v_ax0, v_ax1, v_ax2, v_ax3])
T_split_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2 + T.int64(32), v_ax3]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(2), T.int64(64), T.int64(400)):
with T.block("T_split_2"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1, v_ax2 + T.int64(64), v_ax3])
T.writes(T_split_2[v_ax0, v_ax1, v_ax2, v_ax3])
T_split_2[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, v_ax2 + T.int64(64), v_ax3]
@T.prim_func(private=True)
def split5(A: T.Buffer((T.int64(1), T.int64(144), T.int64(8400)), "float32"), T_split: T.Buffer((T.int64(1), T.int64(64), T.int64(8400)), "float32"), T_split_1: T.Buffer((T.int64(1), T.int64(80), T.int64(8400)), "float32")):
T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(64), T.int64(8400)):
with T.block("T_split"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(A[v_ax0, v_ax1, v_ax2])
T.writes(T_split[v_ax0, v_ax1, v_ax2])
T_split[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2]
for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(80), T.int64(8400)):
with T.block("T_split_1"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(A[v_ax0, v_ax1 + T.int64(64), v_ax2])
T.writes(T_split_1[v_ax0, v_ax1, v_ax2])
T_split_1[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1 + T.int64(64), v_ax2]
@R.function
def main(images: R.Tensor((1, 3, 640, 640), dtype="float32")) -> R.Tensor((1, 84, 8400), dtype="float32"):
R.func_attr({"num_input": 1})
cls = Module
with R.dataflow():
lv: R.Tensor((1, 16, 320, 320), dtype="float32") = R.nn.conv2d(images, metadata["relax.expr.Constant"][0], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv1: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][1], R.shape([1, 16, 1, 1]))
lv2: R.Tensor((1, 16, 320, 320), dtype="float32") = R.add(lv, lv1)
lv3: R.Tensor((1, 16, 320, 320), dtype="float32") = R.sigmoid(lv2)
lv4: R.Tensor((1, 16, 320, 320), dtype="float32") = R.multiply(lv2, lv3)
lv5: R.Tensor((1, 32, 160, 160), dtype="float32") = R.nn.conv2d(lv4, metadata["relax.expr.Constant"][2], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv6: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][3], R.shape([1, 32, 1, 1]))
lv7: R.Tensor((1, 32, 160, 160), dtype="float32") = R.add(lv5, lv6)
lv8: R.Tensor((1, 32, 160, 160), dtype="float32") = R.sigmoid(lv7)
lv9: R.Tensor((1, 32, 160, 160), dtype="float32") = R.multiply(lv7, lv8)
lv10: R.Tensor((1, 32, 160, 160), dtype="float32") = R.nn.conv2d(lv9, metadata["relax.expr.Constant"][4], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv11: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][5], R.shape([1, 32, 1, 1]))
lv12: R.Tensor((1, 32, 160, 160), dtype="float32") = R.add(lv10, lv11)
lv13: R.Tensor((1, 32, 160, 160), dtype="float32") = R.sigmoid(lv12)
lv14: R.Tensor((1, 32, 160, 160), dtype="float32") = R.multiply(lv12, lv13)
lv15 = R.call_tir(cls.split, (lv14,), out_sinfo=[R.Tensor((1, 16, 160, 160), dtype="float32"), R.Tensor((1, 16, 160, 160), dtype="float32")])
lv16: R.Tensor((1, 16, 160, 160), dtype="float32") = lv15[0]
lv17: R.Tensor((1, 16, 160, 160), dtype="float32") = lv15[1]
lv18: R.Tensor((1, 8, 160, 160), dtype="float32") = R.nn.conv2d(lv17, metadata["relax.expr.Constant"][6], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv19: R.Tensor((1, 8, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][7], R.shape([1, 8, 1, 1]))
lv20: R.Tensor((1, 8, 160, 160), dtype="float32") = R.add(lv18, lv19)
lv21: R.Tensor((1, 8, 160, 160), dtype="float32") = R.sigmoid(lv20)
lv22: R.Tensor((1, 8, 160, 160), dtype="float32") = R.multiply(lv20, lv21)
lv23: R.Tensor((1, 16, 160, 160), dtype="float32") = R.nn.conv2d(lv22, metadata["relax.expr.Constant"][8], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv24: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][9], R.shape([1, 16, 1, 1]))
lv25: R.Tensor((1, 16, 160, 160), dtype="float32") = R.add(lv23, lv24)
lv26: R.Tensor((1, 16, 160, 160), dtype="float32") = R.sigmoid(lv25)
lv27: R.Tensor((1, 16, 160, 160), dtype="float32") = R.multiply(lv25, lv26)
lv28: R.Tensor((1, 16, 160, 160), dtype="float32") = R.add(lv17, lv27)
lv29: R.Tensor((1, 48, 160, 160), dtype="float32") = R.concat((lv16, lv17, lv28), axis=1)
lv30: R.Tensor((1, 64, 160, 160), dtype="float32") = R.nn.conv2d(lv29, metadata["relax.expr.Constant"][10], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv31: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][11], R.shape([1, 64, 1, 1]))
lv32: R.Tensor((1, 64, 160, 160), dtype="float32") = R.add(lv30, lv31)
lv33: R.Tensor((1, 64, 160, 160), dtype="float32") = R.sigmoid(lv32)
lv34: R.Tensor((1, 64, 160, 160), dtype="float32") = R.multiply(lv32, lv33)
lv35: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv34, metadata["relax.expr.Constant"][12], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv36: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][13], R.shape([1, 64, 1, 1]))
lv37: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv35, lv36)
lv38: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv37)
lv39: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv37, lv38)
lv40: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv39, metadata["relax.expr.Constant"][14], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv41: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][15], R.shape([1, 64, 1, 1]))
lv42: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv40, lv41)
lv43: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv42)
lv44: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv42, lv43)
lv45 = R.call_tir(cls.split1, (lv44,), out_sinfo=[R.Tensor((1, 32, 80, 80), dtype="float32"), R.Tensor((1, 32, 80, 80), dtype="float32")])
lv46: R.Tensor((1, 32, 80, 80), dtype="float32") = lv45[0]
lv47: R.Tensor((1, 32, 80, 80), dtype="float32") = lv45[1]
lv48: R.Tensor((1, 16, 80, 80), dtype="float32") = R.nn.conv2d(lv47, metadata["relax.expr.Constant"][16], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv49: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][17], R.shape([1, 16, 1, 1]))
lv50: R.Tensor((1, 16, 80, 80), dtype="float32") = R.add(lv48, lv49)
lv51: R.Tensor((1, 16, 80, 80), dtype="float32") = R.sigmoid(lv50)
lv52: R.Tensor((1, 16, 80, 80), dtype="float32") = R.multiply(lv50, lv51)
lv53: R.Tensor((1, 32, 80, 80), dtype="float32") = R.nn.conv2d(lv52, metadata["relax.expr.Constant"][18], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv54: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][19], R.shape([1, 32, 1, 1]))
lv55: R.Tensor((1, 32, 80, 80), dtype="float32") = R.add(lv53, lv54)
lv56: R.Tensor((1, 32, 80, 80), dtype="float32") = R.sigmoid(lv55)
lv57: R.Tensor((1, 32, 80, 80), dtype="float32") = R.multiply(lv55, lv56)
lv58: R.Tensor((1, 32, 80, 80), dtype="float32") = R.add(lv47, lv57)
lv59: R.Tensor((1, 96, 80, 80), dtype="float32") = R.concat((lv46, lv47, lv58), axis=1)
lv60: R.Tensor((1, 128, 80, 80), dtype="float32") = R.nn.conv2d(lv59, metadata["relax.expr.Constant"][20], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv61: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][21], R.shape([1, 128, 1, 1]))
lv62: R.Tensor((1, 128, 80, 80), dtype="float32") = R.add(lv60, lv61)
lv63: R.Tensor((1, 128, 80, 80), dtype="float32") = R.sigmoid(lv62)
lv64: R.Tensor((1, 128, 80, 80), dtype="float32") = R.multiply(lv62, lv63)
lv65: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv64, metadata["relax.expr.Constant"][22], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv66: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][23], R.shape([1, 128, 1, 1]))
lv67: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv65, lv66)
lv68: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv67)
lv69: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv67, lv68)
lv70: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv69, metadata["relax.expr.Constant"][24], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv71: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][25], R.shape([1, 128, 1, 1]))
lv72: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv70, lv71)
lv73: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv72)
lv74: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv72, lv73)
lv75 = R.call_tir(cls.split2, (lv74,), out_sinfo=[R.Tensor((1, 64, 40, 40), dtype="float32"), R.Tensor((1, 64, 40, 40), dtype="float32")])
lv76: R.Tensor((1, 64, 40, 40), dtype="float32") = lv75[0]
lv77: R.Tensor((1, 64, 40, 40), dtype="float32") = lv75[1]
lv78: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv77, metadata["relax.expr.Constant"][26], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv79: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][27], R.shape([1, 32, 1, 1]))
lv80: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv77, metadata["relax.expr.Constant"][28], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv81: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][29], R.shape([1, 32, 1, 1]))
lv82: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv78, lv79)
lv83: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv80, lv81)
lv84: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv82)
lv85: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv83)
lv86: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv82, lv84)
lv87: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv86, metadata["relax.expr.Constant"][30], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv88: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][31], R.shape([1, 32, 1, 1]))
lv89: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv87, lv88)
lv90: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv89)
lv91: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv89, lv90)
lv92: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv91, metadata["relax.expr.Constant"][32], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv93: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][33], R.shape([1, 32, 1, 1]))
lv94: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv92, lv93)
lv95: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv94)
lv96: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv94, lv95)
lv97: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv86, lv96)
lv98: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv97, metadata["relax.expr.Constant"][34], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv99: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][35], R.shape([1, 32, 1, 1]))
lv100: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv98, lv99)
lv101: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv100)
lv102: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv100, lv101)
lv103: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv102, metadata["relax.expr.Constant"][36], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv104: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][37], R.shape([1, 32, 1, 1]))
lv105: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv103, lv104)
lv106: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv105)
lv107: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv105, lv106)
lv108: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv97, lv107)
lv109: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv83, lv85)
lv110: R.Tensor((1, 64, 40, 40), dtype="float32") = R.concat((lv108, lv109), axis=1)
lv111: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv110, metadata["relax.expr.Constant"][38], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv112: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][39], R.shape([1, 64, 1, 1]))
lv113: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv111, lv112)
lv114: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv113)
lv115: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv113, lv114)
lv116: R.Tensor((1, 192, 40, 40), dtype="float32") = R.concat((lv76, lv77, lv115), axis=1)
lv117: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv116, metadata["relax.expr.Constant"][40], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv118: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][41], R.shape([1, 128, 1, 1]))
lv119: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv117, lv118)
lv120: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv119)
lv121: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv119, lv120)
lv122: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv121, metadata["relax.expr.Constant"][42], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv123: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][43], R.shape([1, 256, 1, 1]))
lv124: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv122, lv123)
lv125: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv124)
lv126: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv124, lv125)
lv127: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv126, metadata["relax.expr.Constant"][44], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv128: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][45], R.shape([1, 256, 1, 1]))
lv129: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv127, lv128)
lv130: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv129)
lv131: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv129, lv130)
lv132 = R.call_tir(cls.split3, (lv131,), out_sinfo=[R.Tensor((1, 128, 20, 20), dtype="float32"), R.Tensor((1, 128, 20, 20), dtype="float32")])
lv133: R.Tensor((1, 128, 20, 20), dtype="float32") = lv132[0]
lv134: R.Tensor((1, 128, 20, 20), dtype="float32") = lv132[1]
lv135: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv134, metadata["relax.expr.Constant"][46], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv136: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][47], R.shape([1, 64, 1, 1]))
lv137: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv134, metadata["relax.expr.Constant"][48], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv138: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][49], R.shape([1, 64, 1, 1]))
lv139: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv135, lv136)
lv140: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv137, lv138)
lv141: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv139)
lv142: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv140)
lv143: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv139, lv141)
lv144: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv143, metadata["relax.expr.Constant"][50], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv145: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][51], R.shape([1, 64, 1, 1]))
lv146: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv144, lv145)
lv147: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv146)
lv148: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv146, lv147)
lv149: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv148, metadata["relax.expr.Constant"][52], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv150: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][53], R.shape([1, 64, 1, 1]))
lv151: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv149, lv150)
lv152: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv151)
lv153: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv151, lv152)
lv154: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv143, lv153)
lv155: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv154, metadata["relax.expr.Constant"][54], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv156: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][55], R.shape([1, 64, 1, 1]))
lv157: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv155, lv156)
lv158: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv157)
lv159: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv157, lv158)
lv160: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv159, metadata["relax.expr.Constant"][56], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv161: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][57], R.shape([1, 64, 1, 1]))
lv162: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv160, lv161)
lv163: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv162)
lv164: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv162, lv163)
lv165: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv154, lv164)
lv166: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv140, lv142)
lv167: R.Tensor((1, 128, 20, 20), dtype="float32") = R.concat((lv165, lv166), axis=1)
lv168: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv167, metadata["relax.expr.Constant"][58], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv169: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][59], R.shape([1, 128, 1, 1]))
lv170: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv168, lv169)
lv171: R.Tensor((1, 128, 20, 20), dtype="float32") = R.sigmoid(lv170)
lv172: R.Tensor((1, 128, 20, 20), dtype="float32") = R.multiply(lv170, lv171)
lv173: R.Tensor((1, 384, 20, 20), dtype="float32") = R.concat((lv133, lv134, lv172), axis=1)
lv174: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv173, metadata["relax.expr.Constant"][60], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv175: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][61], R.shape([1, 256, 1, 1]))
lv176: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv174, lv175)
lv177: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv176)
lv178: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv176, lv177)
lv179: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv178, metadata["relax.expr.Constant"][62], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv180: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][63], R.shape([1, 128, 1, 1]))
lv181: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv179, lv180)
lv182: R.Tensor((1, 128, 20, 20), dtype="float32") = R.sigmoid(lv181)
lv183: R.Tensor((1, 128, 20, 20), dtype="float32") = R.multiply(lv181, lv182)
lv184: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.max_pool2d(lv183, pool_size=[5, 5], strides=[1, 1], dilation=[1, 1], padding=[2, 2, 2, 2], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
lv185: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.max_pool2d(lv184, pool_size=[5, 5], strides=[1, 1], dilation=[1, 1], padding=[2, 2, 2, 2], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
lv186: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.max_pool2d(lv185, pool_size=[5, 5], strides=[1, 1], dilation=[1, 1], padding=[2, 2, 2, 2], ceil_mode=False, count_include_pad=False, layout="NCHW", out_layout="NCHW")
lv187: R.Tensor((1, 512, 20, 20), dtype="float32") = R.concat((lv183, lv184, lv185, lv186), axis=1)
lv188: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv187, metadata["relax.expr.Constant"][64], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv189: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][65], R.shape([1, 256, 1, 1]))
lv190: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv188, lv189)
lv191: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv190)
lv192: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv190, lv191)
lv193: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv192, metadata["relax.expr.Constant"][66], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv194: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][67], R.shape([1, 256, 1, 1]))
lv195: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv193, lv194)
lv196: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv195)
lv197: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv195, lv196)
lv198 = R.call_tir(cls.split3, (lv197,), out_sinfo=[R.Tensor((1, 128, 20, 20), dtype="float32"), R.Tensor((1, 128, 20, 20), dtype="float32")])
lv199: R.Tensor((1, 128, 20, 20), dtype="float32") = lv198[0]
lv200: R.Tensor((1, 128, 20, 20), dtype="float32") = lv198[1]
lv201: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv200, metadata["relax.expr.Constant"][68], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv202: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][69], R.shape([1, 256, 1, 1]))
lv203: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv201, lv202)
lv204: R.Tensor((1, 2, 128, 400), dtype="float32") = R.reshape(lv203, R.shape([1, 2, 128, 400]))
lv205 = R.call_tir(cls.split4, (lv204,), out_sinfo=[R.Tensor((1, 2, 32, 400), dtype="float32"), R.Tensor((1, 2, 32, 400), dtype="float32"), R.Tensor((1, 2, 64, 400), dtype="float32")])
lv206: R.Tensor((1, 2, 32, 400), dtype="float32") = lv205[0]
lv207: R.Tensor((1, 2, 32, 400), dtype="float32") = lv205[1]
lv208: R.Tensor((1, 2, 64, 400), dtype="float32") = lv205[2]
lv209: R.Tensor((1, 2, 400, 32), dtype="float32") = R.permute_dims(lv206, axes=[0, 1, 3, 2])
lv210: R.Tensor((1, 128, 20, 20), dtype="float32") = R.reshape(lv208, R.shape([1, 128, 20, 20]))
lv211: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv210, metadata["relax.expr.Constant"][70], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=128, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv212: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][71], R.shape([1, 128, 1, 1]))
lv213: R.Tensor((1, 2, 400, 400), dtype="float32") = R.matmul(lv209, lv207, out_dtype="void")
lv214: R.Tensor((1, 2, 400, 400), dtype="float32") = R.multiply(lv213, R.const(0.1767766922712326, "float32"))
lv215: R.Tensor((1, 2, 400, 400), dtype="float32") = R.nn.softmax(lv214, axis=-1)
lv216: R.Tensor((1, 2, 400, 400), dtype="float32") = R.permute_dims(lv215, axes=[0, 1, 3, 2])
lv217: R.Tensor((1, 2, 64, 400), dtype="float32") = R.matmul(lv208, lv216, out_dtype="void")
lv218: R.Tensor((1, 128, 20, 20), dtype="float32") = R.reshape(lv217, R.shape([1, 128, 20, 20]))
lv219: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv211, lv212)
lv220: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv218, lv219)
lv221: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv220, metadata["relax.expr.Constant"][72], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv222: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][73], R.shape([1, 128, 1, 1]))
lv223: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv221, lv222)
lv224: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv200, lv223)
lv225: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv224, metadata["relax.expr.Constant"][74], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv226: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][75], R.shape([1, 256, 1, 1]))
lv227: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv225, lv226)
lv228: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv227)
lv229: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv227, lv228)
lv230: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv229, metadata["relax.expr.Constant"][76], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv231: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][77], R.shape([1, 128, 1, 1]))
lv232: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv230, lv231)
lv233: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv224, lv232)
lv234: R.Tensor((1, 256, 20, 20), dtype="float32") = R.concat((lv199, lv233), axis=1)
lv235: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv234, metadata["relax.expr.Constant"][78], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv236: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][79], R.shape([1, 256, 1, 1]))
lv237: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv235, lv236)
lv238: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv237)
lv239: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv237, lv238)
lv240: R.Tensor((1, 256, 40, 40), dtype="float32") = R.image.resize2d(lv239, R.shape([40, 40]), roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], layout="NCHW", method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="floor", cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0.0, out_dtype="void")
lv241: R.Tensor((1, 384, 40, 40), dtype="float32") = R.concat((lv240, lv121), axis=1)
lv242: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv241, metadata["relax.expr.Constant"][80], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv243: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][81], R.shape([1, 128, 1, 1]))
lv244: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv242, lv243)
lv245: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv244)
lv246: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv244, lv245)
lv247 = R.call_tir(cls.split2, (lv246,), out_sinfo=[R.Tensor((1, 64, 40, 40), dtype="float32"), R.Tensor((1, 64, 40, 40), dtype="float32")])
lv248: R.Tensor((1, 64, 40, 40), dtype="float32") = lv247[0]
lv249: R.Tensor((1, 64, 40, 40), dtype="float32") = lv247[1]
lv250: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv249, metadata["relax.expr.Constant"][82], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv251: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][83], R.shape([1, 32, 1, 1]))
lv252: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv250, lv251)
lv253: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv252)
lv254: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv252, lv253)
lv255: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv254, metadata["relax.expr.Constant"][84], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv256: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][85], R.shape([1, 64, 1, 1]))
lv257: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv255, lv256)
lv258: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv257)
lv259: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv257, lv258)
lv260: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv249, lv259)
lv261: R.Tensor((1, 192, 40, 40), dtype="float32") = R.concat((lv248, lv249, lv260), axis=1)
lv262: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv261, metadata["relax.expr.Constant"][86], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv263: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][87], R.shape([1, 128, 1, 1]))
lv264: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv262, lv263)
lv265: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv264)
lv266: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv264, lv265)
lv267: R.Tensor((1, 128, 80, 80), dtype="float32") = R.image.resize2d(lv266, R.shape([80, 80]), roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], layout="NCHW", method="nearest_neighbor", coordinate_transformation_mode="asymmetric", rounding_method="floor", cubic_alpha=-0.75, cubic_exclude=0, extrapolation_value=0.0, out_dtype="void")
lv268: R.Tensor((1, 256, 80, 80), dtype="float32") = R.concat((lv267, lv64), axis=1)
lv269: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv268, metadata["relax.expr.Constant"][88], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv270: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][89], R.shape([1, 64, 1, 1]))
lv271: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv269, lv270)
lv272: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv271)
lv273: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv271, lv272)
lv274 = R.call_tir(cls.split1, (lv273,), out_sinfo=[R.Tensor((1, 32, 80, 80), dtype="float32"), R.Tensor((1, 32, 80, 80), dtype="float32")])
lv275: R.Tensor((1, 32, 80, 80), dtype="float32") = lv274[0]
lv276: R.Tensor((1, 32, 80, 80), dtype="float32") = lv274[1]
lv277: R.Tensor((1, 16, 80, 80), dtype="float32") = R.nn.conv2d(lv276, metadata["relax.expr.Constant"][90], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv278: R.Tensor((1, 16, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][91], R.shape([1, 16, 1, 1]))
lv279: R.Tensor((1, 16, 80, 80), dtype="float32") = R.add(lv277, lv278)
lv280: R.Tensor((1, 16, 80, 80), dtype="float32") = R.sigmoid(lv279)
lv281: R.Tensor((1, 16, 80, 80), dtype="float32") = R.multiply(lv279, lv280)
lv282: R.Tensor((1, 32, 80, 80), dtype="float32") = R.nn.conv2d(lv281, metadata["relax.expr.Constant"][92], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv283: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][93], R.shape([1, 32, 1, 1]))
lv284: R.Tensor((1, 32, 80, 80), dtype="float32") = R.add(lv282, lv283)
lv285: R.Tensor((1, 32, 80, 80), dtype="float32") = R.sigmoid(lv284)
lv286: R.Tensor((1, 32, 80, 80), dtype="float32") = R.multiply(lv284, lv285)
lv287: R.Tensor((1, 32, 80, 80), dtype="float32") = R.add(lv276, lv286)
lv288: R.Tensor((1, 96, 80, 80), dtype="float32") = R.concat((lv275, lv276, lv287), axis=1)
lv289: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv288, metadata["relax.expr.Constant"][94], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv290: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][95], R.shape([1, 64, 1, 1]))
lv291: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv289, lv290)
lv292: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv291)
lv293: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv291, lv292)
lv294: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv293, metadata["relax.expr.Constant"][96], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv295: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][97], R.shape([1, 64, 1, 1]))
lv296: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv293, metadata["relax.expr.Constant"][98], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv297: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][99], R.shape([1, 64, 1, 1]))
lv298: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv293, metadata["relax.expr.Constant"][100], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=64, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv299: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][101], R.shape([1, 64, 1, 1]))
lv300: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv294, lv295)
lv301: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv296, lv297)
lv302: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv298, lv299)
lv303: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv300)
lv304: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv301)
lv305: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv302)
lv306: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv300, lv303)
lv307: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv301, lv304)
lv308: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv307, metadata["relax.expr.Constant"][102], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv309: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][103], R.shape([1, 64, 1, 1]))
lv310: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv302, lv305)
lv311: R.Tensor((1, 80, 80, 80), dtype="float32") = R.nn.conv2d(lv310, metadata["relax.expr.Constant"][104], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv312: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][105], R.shape([1, 80, 1, 1]))
lv313: R.Tensor((1, 192, 40, 40), dtype="float32") = R.concat((lv306, lv266), axis=1)
lv314: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv313, metadata["relax.expr.Constant"][106], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv315: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][107], R.shape([1, 128, 1, 1]))
lv316: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv308, lv309)
lv317: R.Tensor((1, 80, 80, 80), dtype="float32") = R.add(lv311, lv312)
lv318: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv314, lv315)
lv319: R.Tensor((1, 64, 80, 80), dtype="float32") = R.sigmoid(lv316)
lv320: R.Tensor((1, 80, 80, 80), dtype="float32") = R.sigmoid(lv317)
lv321: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv318)
lv322: R.Tensor((1, 64, 80, 80), dtype="float32") = R.multiply(lv316, lv319)
lv323: R.Tensor((1, 64, 80, 80), dtype="float32") = R.nn.conv2d(lv322, metadata["relax.expr.Constant"][108], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv324: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][109], R.shape([1, 64, 1, 1]))
lv325: R.Tensor((1, 80, 80, 80), dtype="float32") = R.multiply(lv317, lv320)
lv326: R.Tensor((1, 80, 80, 80), dtype="float32") = R.nn.conv2d(lv325, metadata["relax.expr.Constant"][110], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=80, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv327: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][111], R.shape([1, 80, 1, 1]))
lv328: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv318, lv321)
lv329 = R.call_tir(cls.split2, (lv328,), out_sinfo=[R.Tensor((1, 64, 40, 40), dtype="float32"), R.Tensor((1, 64, 40, 40), dtype="float32")])
lv330: R.Tensor((1, 64, 40, 40), dtype="float32") = lv329[0]
lv331: R.Tensor((1, 64, 40, 40), dtype="float32") = lv329[1]
lv332: R.Tensor((1, 80, 80, 80), dtype="float32") = R.add(lv326, lv327)
lv333: R.Tensor((1, 32, 40, 40), dtype="float32") = R.nn.conv2d(lv331, metadata["relax.expr.Constant"][112], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv334: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][113], R.shape([1, 32, 1, 1]))
lv335: R.Tensor((1, 80, 80, 80), dtype="float32") = R.sigmoid(lv332)
lv336: R.Tensor((1, 32, 40, 40), dtype="float32") = R.add(lv333, lv334)
lv337: R.Tensor((1, 80, 80, 80), dtype="float32") = R.multiply(lv332, lv335)
lv338: R.Tensor((1, 80, 80, 80), dtype="float32") = R.nn.conv2d(lv337, metadata["relax.expr.Constant"][114], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv339: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][115], R.shape([1, 80, 1, 1]))
lv340: R.Tensor((1, 32, 40, 40), dtype="float32") = R.sigmoid(lv336)
lv341: R.Tensor((1, 80, 80, 80), dtype="float32") = R.add(lv338, lv339)
lv342: R.Tensor((1, 32, 40, 40), dtype="float32") = R.multiply(lv336, lv340)
lv343: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv342, metadata["relax.expr.Constant"][116], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv344: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][117], R.shape([1, 64, 1, 1]))
lv345: R.Tensor((1, 80, 80, 80), dtype="float32") = R.sigmoid(lv341)
lv346: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv343, lv344)
lv347: R.Tensor((1, 80, 80, 80), dtype="float32") = R.multiply(lv341, lv345)
lv348: R.Tensor((1, 80, 80, 80), dtype="float32") = R.nn.conv2d(lv347, metadata["relax.expr.Constant"][118], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv349: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][119], R.shape([1, 80, 1, 1]))
lv350: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv346)
lv351: R.Tensor((1, 64, 80, 80), dtype="float32") = R.add(lv323, lv324)
lv352: R.Tensor((1, 80, 80, 80), dtype="float32") = R.add(lv348, lv349)
lv353: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv346, lv350)
lv354: R.Tensor((1, 144, 80, 80), dtype="float32") = R.concat((lv351, lv352), axis=1)
lv355: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv331, lv353)
lv356: R.Tensor((1, 192, 40, 40), dtype="float32") = R.concat((lv330, lv331, lv355), axis=1)
lv357: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv356, metadata["relax.expr.Constant"][120], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv358: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][121], R.shape([1, 128, 1, 1]))
lv359: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv357, lv358)
lv360: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv359)
lv361: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv359, lv360)
lv362: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv361, metadata["relax.expr.Constant"][122], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv363: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][123], R.shape([1, 128, 1, 1]))
lv364: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv361, metadata["relax.expr.Constant"][124], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv365: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][125], R.shape([1, 64, 1, 1]))
lv366: R.Tensor((1, 128, 40, 40), dtype="float32") = R.nn.conv2d(lv361, metadata["relax.expr.Constant"][126], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=128, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv367: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][127], R.shape([1, 128, 1, 1]))
lv368: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv362, lv363)
lv369: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv364, lv365)
lv370: R.Tensor((1, 128, 40, 40), dtype="float32") = R.add(lv366, lv367)
lv371: R.Tensor((1, 128, 20, 20), dtype="float32") = R.sigmoid(lv368)
lv372: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv369)
lv373: R.Tensor((1, 128, 40, 40), dtype="float32") = R.sigmoid(lv370)
lv374: R.Tensor((1, 128, 20, 20), dtype="float32") = R.multiply(lv368, lv371)
lv375: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv369, lv372)
lv376: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv375, metadata["relax.expr.Constant"][128], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv377: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][129], R.shape([1, 64, 1, 1]))
lv378: R.Tensor((1, 128, 40, 40), dtype="float32") = R.multiply(lv370, lv373)
lv379: R.Tensor((1, 80, 40, 40), dtype="float32") = R.nn.conv2d(lv378, metadata["relax.expr.Constant"][130], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv380: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][131], R.shape([1, 80, 1, 1]))
lv381: R.Tensor((1, 384, 20, 20), dtype="float32") = R.concat((lv374, lv239), axis=1)
lv382: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv381, metadata["relax.expr.Constant"][132], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv383: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][133], R.shape([1, 256, 1, 1]))
lv384: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv376, lv377)
lv385: R.Tensor((1, 80, 40, 40), dtype="float32") = R.add(lv379, lv380)
lv386: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv382, lv383)
lv387: R.Tensor((1, 64, 40, 40), dtype="float32") = R.sigmoid(lv384)
lv388: R.Tensor((1, 80, 40, 40), dtype="float32") = R.sigmoid(lv385)
lv389: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv386)
lv390: R.Tensor((1, 64, 40, 40), dtype="float32") = R.multiply(lv384, lv387)
lv391: R.Tensor((1, 64, 40, 40), dtype="float32") = R.nn.conv2d(lv390, metadata["relax.expr.Constant"][134], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv392: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][135], R.shape([1, 64, 1, 1]))
lv393: R.Tensor((1, 80, 40, 40), dtype="float32") = R.multiply(lv385, lv388)
lv394: R.Tensor((1, 80, 40, 40), dtype="float32") = R.nn.conv2d(lv393, metadata["relax.expr.Constant"][136], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=80, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv395: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][137], R.shape([1, 80, 1, 1]))
lv396: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv386, lv389)
lv397 = R.call_tir(cls.split3, (lv396,), out_sinfo=[R.Tensor((1, 128, 20, 20), dtype="float32"), R.Tensor((1, 128, 20, 20), dtype="float32")])
lv398: R.Tensor((1, 128, 20, 20), dtype="float32") = lv397[0]
lv399: R.Tensor((1, 128, 20, 20), dtype="float32") = lv397[1]
lv400: R.Tensor((1, 80, 40, 40), dtype="float32") = R.add(lv394, lv395)
lv401: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv399, metadata["relax.expr.Constant"][138], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv402: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][139], R.shape([1, 64, 1, 1]))
lv403: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv399, metadata["relax.expr.Constant"][140], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv404: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][141], R.shape([1, 64, 1, 1]))
lv405: R.Tensor((1, 80, 40, 40), dtype="float32") = R.sigmoid(lv400)
lv406: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv401, lv402)
lv407: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv403, lv404)
lv408: R.Tensor((1, 80, 40, 40), dtype="float32") = R.multiply(lv400, lv405)
lv409: R.Tensor((1, 80, 40, 40), dtype="float32") = R.nn.conv2d(lv408, metadata["relax.expr.Constant"][142], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv410: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][143], R.shape([1, 80, 1, 1]))
lv411: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv406)
lv412: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv407)
lv413: R.Tensor((1, 80, 40, 40), dtype="float32") = R.add(lv409, lv410)
lv414: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv406, lv411)
lv415: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv414, metadata["relax.expr.Constant"][144], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv416: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][145], R.shape([1, 64, 1, 1]))
lv417: R.Tensor((1, 80, 40, 40), dtype="float32") = R.sigmoid(lv413)
lv418: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv415, lv416)
lv419: R.Tensor((1, 80, 40, 40), dtype="float32") = R.multiply(lv413, lv417)
lv420: R.Tensor((1, 80, 40, 40), dtype="float32") = R.nn.conv2d(lv419, metadata["relax.expr.Constant"][146], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv421: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][147], R.shape([1, 80, 1, 1]))
lv422: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv418)
lv423: R.Tensor((1, 64, 40, 40), dtype="float32") = R.add(lv391, lv392)
lv424: R.Tensor((1, 80, 40, 40), dtype="float32") = R.add(lv420, lv421)
lv425: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv418, lv422)
lv426: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv425, metadata["relax.expr.Constant"][148], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv427: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][149], R.shape([1, 64, 1, 1]))
lv428: R.Tensor((1, 144, 40, 40), dtype="float32") = R.concat((lv423, lv424), axis=1)
lv429: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv426, lv427)
lv430: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv429)
lv431: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv429, lv430)
lv432: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv414, lv431)
lv433: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv432, metadata["relax.expr.Constant"][150], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv434: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][151], R.shape([1, 64, 1, 1]))
lv435: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv433, lv434)
lv436: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv435)
lv437: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv435, lv436)
lv438: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv437, metadata["relax.expr.Constant"][152], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv439: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][153], R.shape([1, 64, 1, 1]))
lv440: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv438, lv439)
lv441: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv440)
lv442: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv440, lv441)
lv443: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv432, lv442)
lv444: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv407, lv412)
lv445: R.Tensor((1, 128, 20, 20), dtype="float32") = R.concat((lv443, lv444), axis=1)
lv446: R.Tensor((1, 128, 20, 20), dtype="float32") = R.nn.conv2d(lv445, metadata["relax.expr.Constant"][154], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv447: R.Tensor((1, 128, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][155], R.shape([1, 128, 1, 1]))
lv448: R.Tensor((1, 128, 20, 20), dtype="float32") = R.add(lv446, lv447)
lv449: R.Tensor((1, 128, 20, 20), dtype="float32") = R.sigmoid(lv448)
lv450: R.Tensor((1, 128, 20, 20), dtype="float32") = R.multiply(lv448, lv449)
lv451: R.Tensor((1, 384, 20, 20), dtype="float32") = R.concat((lv398, lv399, lv450), axis=1)
lv452: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv451, metadata["relax.expr.Constant"][156], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv453: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][157], R.shape([1, 256, 1, 1]))
lv454: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv452, lv453)
lv455: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv454)
lv456: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv454, lv455)
lv457: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv456, metadata["relax.expr.Constant"][158], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv458: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][159], R.shape([1, 64, 1, 1]))
lv459: R.Tensor((1, 256, 20, 20), dtype="float32") = R.nn.conv2d(lv456, metadata["relax.expr.Constant"][160], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=256, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv460: R.Tensor((1, 256, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][161], R.shape([1, 256, 1, 1]))
lv461: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv457, lv458)
lv462: R.Tensor((1, 256, 20, 20), dtype="float32") = R.add(lv459, lv460)
lv463: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv461)
lv464: R.Tensor((1, 256, 20, 20), dtype="float32") = R.sigmoid(lv462)
lv465: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv461, lv463)
lv466: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv465, metadata["relax.expr.Constant"][162], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv467: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][163], R.shape([1, 64, 1, 1]))
lv468: R.Tensor((1, 256, 20, 20), dtype="float32") = R.multiply(lv462, lv464)
lv469: R.Tensor((1, 80, 20, 20), dtype="float32") = R.nn.conv2d(lv468, metadata["relax.expr.Constant"][164], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv470: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][165], R.shape([1, 80, 1, 1]))
lv471: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv466, lv467)
lv472: R.Tensor((1, 80, 20, 20), dtype="float32") = R.add(lv469, lv470)
lv473: R.Tensor((1, 64, 20, 20), dtype="float32") = R.sigmoid(lv471)
lv474: R.Tensor((1, 80, 20, 20), dtype="float32") = R.sigmoid(lv472)
lv475: R.Tensor((1, 64, 20, 20), dtype="float32") = R.multiply(lv471, lv473)
lv476: R.Tensor((1, 64, 20, 20), dtype="float32") = R.nn.conv2d(lv475, metadata["relax.expr.Constant"][166], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv477: R.Tensor((1, 64, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][167], R.shape([1, 64, 1, 1]))
lv478: R.Tensor((1, 80, 20, 20), dtype="float32") = R.multiply(lv472, lv474)
lv479: R.Tensor((1, 80, 20, 20), dtype="float32") = R.nn.conv2d(lv478, metadata["relax.expr.Constant"][168], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=80, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv480: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][169], R.shape([1, 80, 1, 1]))
lv481: R.Tensor((1, 80, 20, 20), dtype="float32") = R.add(lv479, lv480)
lv482: R.Tensor((1, 80, 20, 20), dtype="float32") = R.sigmoid(lv481)
lv483: R.Tensor((1, 80, 20, 20), dtype="float32") = R.multiply(lv481, lv482)
lv484: R.Tensor((1, 80, 20, 20), dtype="float32") = R.nn.conv2d(lv483, metadata["relax.expr.Constant"][170], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv485: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][171], R.shape([1, 80, 1, 1]))
lv486: R.Tensor((1, 80, 20, 20), dtype="float32") = R.add(lv484, lv485)
lv487: R.Tensor((1, 80, 20, 20), dtype="float32") = R.sigmoid(lv486)
lv488: R.Tensor((1, 80, 20, 20), dtype="float32") = R.multiply(lv486, lv487)
lv489: R.Tensor((1, 80, 20, 20), dtype="float32") = R.nn.conv2d(lv488, metadata["relax.expr.Constant"][172], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv490: R.Tensor((1, 80, 1, 1), dtype="float32") = R.reshape(metadata["relax.expr.Constant"][173], R.shape([1, 80, 1, 1]))
lv491: R.Tensor((1, 64, 20, 20), dtype="float32") = R.add(lv476, lv477)
lv492: R.Tensor((1, 80, 20, 20), dtype="float32") = R.add(lv489, lv490)
lv493: R.Tensor((1, 144, 20, 20), dtype="float32") = R.concat((lv491, lv492), axis=1)
lv494: R.Tensor((1, 144, 6400), dtype="float32") = R.reshape(lv354, R.shape([1, 144, 6400]))
lv495: R.Tensor((1, 144, 1600), dtype="float32") = R.reshape(lv428, R.shape([1, 144, 1600]))
lv496: R.Tensor((1, 144, 400), dtype="float32") = R.reshape(lv493, R.shape([1, 144, 400]))
lv497: R.Tensor((1, 144, 8400), dtype="float32") = R.concat((lv494, lv495, lv496), axis=2)
lv498 = R.call_tir(cls.split5, (lv497,), out_sinfo=[R.Tensor((1, 64, 8400), dtype="float32"), R.Tensor((1, 80, 8400), dtype="float32")])
lv499: R.Tensor((1, 64, 8400), dtype="float32") = lv498[0]
lv500: R.Tensor((1, 80, 8400), dtype="float32") = lv498[1]
lv501: R.Tensor((1, 4, 16, 8400), dtype="float32") = R.reshape(lv499, R.shape([1, 4, 16, 8400]))
lv502: R.Tensor((1, 16, 4, 8400), dtype="float32") = R.permute_dims(lv501, axes=[0, 2, 1, 3])
lv503: R.Tensor((1, 16, 4, 8400), dtype="float32") = R.nn.softmax(lv502, axis=1)
lv504: R.Tensor((1, 1, 4, 8400), dtype="float32") = R.nn.conv2d(lv503, metadata["relax.expr.Constant"][174], strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="void")
lv505: R.Tensor((1, 4, 8400), dtype="float32") = R.reshape(lv504, R.shape([1, 4, 8400]))
lv506: R.Tensor((1, 2, 8400), dtype="float32") = R.strided_slice(lv505, (R.prim_value(1),), (R.prim_value(0),), (R.prim_value(2),), (R.prim_value(1),), assume_inbound=False)
lv507: R.Tensor((1, 2, 8400), dtype="float32") = R.strided_slice(lv505, (R.prim_value(1),), (R.prim_value(2),), (R.prim_value(4),), (R.prim_value(1),), assume_inbound=False)
lv508: R.Tensor((1, 2, 8400), dtype="float32") = R.subtract(metadata["relax.expr.Constant"][175], lv506)
lv509: R.Tensor((1, 2, 8400), dtype="float32") = R.add(metadata["relax.expr.Constant"][176], lv507)
lv510: R.Tensor((1, 2, 8400), dtype="float32") = R.add(lv508, lv509)
lv511: R.Tensor((1, 2, 8400), dtype="float32") = R.divide(lv510, R.const(2.0, "float32"))
lv512: R.Tensor((1, 2, 8400), dtype="float32") = R.subtract(lv509, lv508)
lv513: R.Tensor((1, 4, 8400), dtype="float32") = R.concat((lv511, lv512), axis=1)
lv514: R.Tensor((1, 4, 8400), dtype="float32") = R.multiply(lv513, metadata["relax.expr.Constant"][177])
lv515: R.Tensor((1, 80, 8400), dtype="float32") = R.sigmoid(lv500)
gv: R.Tensor((1, 84, 8400), dtype="float32") = R.concat((lv514, lv515), axis=1)
R.output(gv)
return gv
# Metadata omitted. Use show_meta=True in script() method to show it.
# Inspect the imported Relax module's "main" function.
relax_func = tvm_model["main"]
relax_func.show()
tvm_model.show()
# NOTE(review): removed the stray REPL fragments `relax.op.split` and `relax.`
# from the original paste — the bare `relax.` is a SyntaxError and neither
# fragment had any effect.

# Rewrite training-oriented ops into their inference forms.
tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
tvm_model.show()
# Legalize any remaining Relax ops into TensorIR so they can be compiled.
tvm_model = relax.transform.LegalizeOps()(tvm_model)
# Detach the constant parameters from the module; they are fed back at runtime.
tvm_model, params = relax.frontend.detach_params(tvm_model)
# Compile the Relax graph into a VM executable, then instantiate the VM on CPU.
with tvm.transform.PassContext(opt_level=3):
    ex = relax.build(tvm_model, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
准备输入:
# Gather VM inputs in the order declared by the compiled function's signature;
# parameters not present in `inputs` are the detached weights appended below.
input_list = [
    inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
]
# Append the detached model weights for "main", if any were split out.
if params:
    input_list += params["main"]
运行模型并检查输出:
# Bind inputs, run the stateful VM entry point, and fetch its outputs.
vm.set_input("main", *input_list)
vm.invoke_stateful("main")
tvm_output = vm.get_outputs("main")
# If ONNX Runtime reports a single output, wrap the TVM result in a list:
# TVM does not track output counts — for sequence outputs TVM returns a
# tuple, while ONNX counts the whole sequence as one (list-form) output.
if len(ort_output) == 1:
    tvm_output = [tvm_output]
def _check_output(tvm_out: list, ort_out: list, rtol: float = 1e-7, atol: float = 1e-5,) -> None:
    """Recursively compare one TVM output against its ONNX Runtime reference.

    Args:
        tvm_out: TVM result — a tuple (sequence output), ``tvm.nd.NDArray``,
            ``tvm.runtime.ShapeTuple``, or a Python scalar.
        ort_out: ONNX Runtime reference — a list (for sequence outputs) or a
            ``np.ndarray``.
        rtol: Relative tolerance for ``np.testing.assert_allclose``.
        atol: Absolute tolerance for ``np.testing.assert_allclose``.

    Raises:
        AssertionError: If lengths differ or values are not close.
        ValueError: If the pair of types is not a supported combination.
    """
    if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)):
        assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
        for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
            # BUGFIX: propagate the caller's tolerances into the recursion;
            # previously they were silently reset to the 1e-7/1e-5 defaults.
            _check_output(tvm_out_i, ort_out_i, rtol=rtol, atol=atol)
    elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
        np.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
    elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
        # Shape outputs: materialize the ShapeTuple as an array before comparing.
        shape_out = tvm.nd.array([int(i) for i in tvm_out])
        np.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
    elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
        np.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
    else:
        raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")
# Check that the number of outputs matches before comparing element-wise.
assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
for tvm_out, ort_out in zip(tvm_output, ort_output):
    # TODO Allow configurable tolerance.
    # Some ONNX reference outputs may be absent (None); skip those entries.
    if ort_out is not None:
        _check_output(tvm_out, ort_out, rtol=1e-4, atol=1e-5)