# Translate a Relay program to Relax
# (original title: 翻译 Relay 程序为 Relax)
#
# Translate a Relay program to Relax

import numpy as np

import tvm
from tvm import relax, relay
from tvm.relay import testing
from tvm.relax.testing import relay_translator, nn
from tvm.runtime import vm as vm_rt
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
relay_mod, _ = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32")
# translate the ResNet model from Relay to Relax
target = tvm.target.Target("llvm", host="llvm")
relax_mod = relay_translator.from_relay(relay_mod["main"], target)

# print the ResNet IRmodule got translated
relax_mod.show()

# build the IRModule and create relax vm
ex = relax.build(relax_mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())

# init weights and run the model on relax vm
shape = (1, 3, 224, 224)
data = tvm.nd.array(np.random.rand(*shape).astype(np.float32))
params = nn.init_params(relax_mod)
res = vm["main"](data, *params)

# check correctness by comparing with relay result
exe = relay.vm.compile(relay_mod, target)
relay_vm = vm_rt.VirtualMachine(exe, tvm.cpu())
inputs = [data] + params
expected_output = relay_vm.run(*inputs)
np.testing.assert_allclose(res.numpy(), expected_output.numpy(), rtol=1e-4, atol=1e-4)
# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func
    def add(
        A: T.Buffer((T.int64(3),), "float32"),
        B: T.Buffer((), "float32"),
        T_add: T.Buffer((T.int64(3),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(3), ax0)
                T.reads(A[v_ax0], B[()])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[()]

    @T.prim_func
    def add1(
        A: T.Buffer((T.int64(3),), "float32"),
        B: T.Buffer((T.int64(3),), "float32"),
        T_add: T.Buffer((T.int64(3),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(3), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[v_ax0]

    @T.prim_func
    def add10(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    @T.prim_func
    def add11(
        A: T.Buffer((T.int64(128),), "float32"),
        B: T.Buffer((), "float32"),
        T_add: T.Buffer((T.int64(128),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(128)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(128), ax0)
                T.reads(A[v_ax0], B[()])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[()]

    @T.prim_func
    def add12(
        A: T.Buffer((T.int64(128),), "float32"),
        B: T.Buffer((T.int64(128),), "float32"),
        T_add: T.Buffer((T.int64(128),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(128)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(128), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[v_ax0]

    @T.prim_func
    def add13(
        A: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    @T.prim_func
    def add14(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                )

    @T.prim_func
    def add15(
        A: T.Buffer((T.int64(512),), "float32"),
        B: T.Buffer((), "float32"),
        T_add: T.Buffer((T.int64(512),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(512)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(512), ax0)
                T.reads(A[v_ax0], B[()])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[()]

    @T.prim_func
    def add16(
        A: T.Buffer((T.int64(512),), "float32"),
        B: T.Buffer((T.int64(512),), "float32"),
        T_add: T.Buffer((T.int64(512),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(512)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(512), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[v_ax0]

    @T.prim_func
    def add17(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    @T.prim_func
    def add18(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    @T.prim_func
    def add19(
        A: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                )

    @T.prim_func
    def add2(
        A: T.Buffer((T.int64(1), T.int64(3), T.int64(224), T.int64(224)), "float32"),
        B: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
        T_add: T.Buffer(
            (T.int64(1), T.int64(3), T.int64(224), T.int64(224)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(1), T.int64(3), T.int64(224), T.int64(224)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax1, T.int64(0), T.int64(0)])
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax1, T.int64(0), T.int64(0)]
                )

    @T.prim_func
    def add20(
        A: T.Buffer((T.int64(1024),), "float32"),
        B: T.Buffer((), "float32"),
        T_add: T.Buffer((T.int64(1024),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(1024)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(1024), ax0)
                T.reads(A[v_ax0], B[()])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[()]

    @T.prim_func
    def add21(
        A: T.Buffer((T.int64(1024),), "float32"),
        B: T.Buffer((T.int64(1024),), "float32"),
        T_add: T.Buffer((T.int64(1024),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(1024)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(1024), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[v_ax0]

    @T.prim_func
    def add22(
        A: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    @T.prim_func
    def add23(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    @T.prim_func
    def add24(
        A: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                )

    @T.prim_func
    def add25(
        A: T.Buffer((T.int64(2048),), "float32"),
        B: T.Buffer((), "float32"),
        T_add: T.Buffer((T.int64(2048),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2048)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(2048), ax0)
                T.reads(A[v_ax0], B[()])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[()]

    @T.prim_func
    def add26(
        A: T.Buffer((T.int64(2048),), "float32"),
        B: T.Buffer((T.int64(2048),), "float32"),
        T_add: T.Buffer((T.int64(2048),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2048)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(2048), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[v_ax0]

    @T.prim_func
    def add27(
        A: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    @T.prim_func
    def add28(
        A: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
        B: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
        T_add: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    @T.prim_func
    def add3(
        A: T.Buffer((T.int64(64),), "float32"),
        B: T.Buffer((), "float32"),
        T_add: T.Buffer((T.int64(64),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(64)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(64), ax0)
                T.reads(A[v_ax0], B[()])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[()]

    @T.prim_func
    def add4(
        A: T.Buffer((T.int64(64),), "float32"),
        B: T.Buffer((T.int64(64),), "float32"),
        T_add: T.Buffer((T.int64(64),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(64)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(64), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[v_ax0]

    @T.prim_func
    def add5(
        A: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    @T.prim_func
    def add6(
        A: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    @T.prim_func
    def add7(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        T_add: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)
        ):
            with T.block("T_add"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                )
                T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                )

    @T.prim_func
    def add8(
        A: T.Buffer((T.int64(256),), "float32"),
        B: T.Buffer((), "float32"),
        T_add: T.Buffer((T.int64(256),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(256)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(256), ax0)
                T.reads(A[v_ax0], B[()])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[()]

    @T.prim_func
    def add9(
        A: T.Buffer((T.int64(256),), "float32"),
        B: T.Buffer((T.int64(256),), "float32"),
        T_add: T.Buffer((T.int64(256),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(256)):
            with T.block("T_add"):
                v_ax0 = T.axis.spatial(T.int64(256), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_add[v_ax0])
                T_add[v_ax0] = A[v_ax0] + B[v_ax0]

    @T.prim_func
    def batch_flatten(
        A: T.Buffer((T.int64(1), T.int64(2048), T.int64(1), T.int64(1)), "float32"),
        tensor: T.Buffer((T.int64(1), T.int64(2048)), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(2048)):
            with T.block("tensor"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1 % T.int64(2048), T.int64(0), T.int64(0)])
                T.writes(tensor[v_ax0, v_ax1])
                tensor[v_ax0, v_ax1] = A[
                    v_ax0, v_ax1 % T.int64(2048), T.int64(0), T.int64(0)
                ]

    @T.prim_func
    def contrib_conv2d_NCHWc(
        A: T.Buffer(
            (T.int64(1), T.int64(1), T.int64(224), T.int64(224), T.int64(3)), "float32"
        ),
        B: T.Buffer(
            (T.int64(16), T.int64(1), T.int64(7), T.int64(7), T.int64(3), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        data_pad = T.alloc_buffer(
            (T.int64(1), T.int64(1), T.int64(230), T.int64(230), T.int64(3))
        )
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(1), T.int64(230), T.int64(230), T.int64(3)
        ):
            with T.block("data_pad"):
                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap(
                    "SSSSS", [i0, i1, i2, i3, i4]
                )
                T.reads(A[v_i0, v_i1, v_i2 - T.int64(3), v_i3 - T.int64(3), v_i4])
                T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
                data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(
                    T.int64(3) <= v_i2
                    and v_i2 < T.int64(227)
                    and T.int64(3) <= v_i3
                    and v_i3 < T.int64(227),
                    A[v_i0, v_i1, v_i2 - T.int64(3), v_i3 - T.int64(3), v_i4],
                    T.float32(0),
                )
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(16),
            T.int64(112),
            T.int64(112),
            T.int64(4),
            T.int64(3),
            T.int64(7),
            T.int64(7),
        ):
            with T.block("conv2d_NCHWc"):
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                T.reads(
                    data_pad[
                        v_n,
                        v_ic // T.int64(3),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(3),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(3),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(3),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + data_pad[
                        v_n,
                        v_ic // T.int64(3),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(3),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(3),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(3),
                        v_oc_block,
                    ]
                )

    # 1x1, stride-1 conv2d in packed NCHWc layout (machine-printed TVMScript;
    # auto-generated TIR from the Relay -> Relax translation in this script).
    # A: data (1, 16, 56, 56, 4) -> 16 channel chunks x 4 lanes = 64 channels.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 16, 56, 56, 4) -> 64 channels on a 56x56 map.
    @T.prim_func
    def contrib_conv2d_NCHWc1(
        A: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(16), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(56), T.int64(56), T.int64(4)) if False else (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Spatial axes n/oc_chunk/oh/ow/oc_block, then reduction axes: ic runs
        # over all 64 unpacked input channels, kh/kw over the 1x1 kernel.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(16),
            T.int64(56),
            T.int64(56),
            T.int64(4),
            T.int64(64),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations consumed by the TIR scheduler.
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate; v_ic // 4 selects the input channel
                # chunk and v_ic % 4 the lane within the chunk.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1, stride-2 conv2d in packed NCHWc layout (auto-generated TIR).
    # A: data (1, 128, 28, 28, 4) -> 512 channels on a 28x28 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 64, 14, 14, 4) -> 256 channels, downsampled to
    # 14x14 by the stride-2 input indexing (v_oh * 2 + v_kh).
    @T.prim_func
    def contrib_conv2d_NCHWc10(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(64), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (512
        # unpacked input channels) and the 1x1 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(64),
            T.int64(14),
            T.int64(14),
            T.int64(4),
            T.int64(512),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations; note the stride-2 spatial access.
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate; v_ic // 4 selects the input channel
                # chunk and v_ic % 4 the lane within the chunk.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 3x3, stride-1 conv2d in packed NCHWc layout with zero padding of 1 on
    # each side of H and W (auto-generated TIR).
    # A: data (1, 64, 14, 14, 4) -> 256 channels on a 14x14 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 64, 14, 14, 4) -> 256 channels, same 14x14 map.
    @T.prim_func
    def contrib_conv2d_NCHWc11(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(64), T.int64(64), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Padded intermediate: 14x14 input grows to 16x16 (14 + 2*1).
        data_pad = T.alloc_buffer(
            (T.int64(1), T.int64(64), T.int64(16), T.int64(16), T.int64(4))
        )
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(64), T.int64(16), T.int64(16), T.int64(4)
        ):
            with T.block("data_pad"):
                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap(
                    "SSSSS", [i0, i1, i2, i3, i4]
                )
                T.reads(A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4])
                T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
                # Copy A for interior positions (1 <= i < 15), write 0 on the
                # one-pixel border.
                data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(
                    T.int64(1) <= v_i2
                    and v_i2 < T.int64(15)
                    and T.int64(1) <= v_i3
                    and v_i3 < T.int64(15),
                    A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4],
                    T.float32(0),
                )
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (256
        # unpacked input channels) and the 3x3 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(64),
            T.int64(14),
            T.int64(14),
            T.int64(4),
            T.int64(256),
            T.int64(3),
            T.int64(3),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations consumed by the TIR scheduler.
                T.reads(
                    data_pad[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate over the padded input; v_ic // 4 selects
                # the input channel chunk and v_ic % 4 the lane.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + data_pad[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1, stride-1 conv2d in packed NCHWc layout (auto-generated TIR).
    # A: data (1, 64, 14, 14, 4) -> 256 channels on a 14x14 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 256, 14, 14, 4) -> 1024 channels, same map.
    @T.prim_func
    def contrib_conv2d_NCHWc12(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(256), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (256
        # unpacked input channels) and the 1x1 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(256),
            T.int64(14),
            T.int64(14),
            T.int64(4),
            T.int64(256),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations consumed by the TIR scheduler.
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate; v_ic // 4 selects the input channel
                # chunk and v_ic % 4 the lane within the chunk.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1, stride-2 conv2d in packed NCHWc layout (auto-generated TIR).
    # A: data (1, 128, 28, 28, 4) -> 512 channels on a 28x28 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 256, 14, 14, 4) -> 1024 channels, downsampled
    # to 14x14 by the stride-2 input indexing (v_oh * 2 + v_kh).
    @T.prim_func
    def contrib_conv2d_NCHWc13(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (
                T.int64(256),
                T.int64(128),
                T.int64(1),
                T.int64(1),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (512
        # unpacked input channels) and the 1x1 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(256),
            T.int64(14),
            T.int64(14),
            T.int64(4),
            T.int64(512),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations; note the stride-2 spatial access.
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate; v_ic // 4 selects the input channel
                # chunk and v_ic % 4 the lane within the chunk.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1, stride-1 conv2d in packed NCHWc layout (auto-generated TIR).
    # A: data (1, 256, 14, 14, 4) -> 1024 channels on a 14x14 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 64, 14, 14, 4) -> 256 channels, same map.
    @T.prim_func
    def contrib_conv2d_NCHWc14(
        A: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(64), T.int64(256), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (1024
        # unpacked input channels) and the 1x1 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(64),
            T.int64(14),
            T.int64(14),
            T.int64(4),
            T.int64(1024),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations consumed by the TIR scheduler.
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate; v_ic // 4 selects the input channel
                # chunk and v_ic % 4 the lane within the chunk.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1, stride-2 conv2d in packed NCHWc layout (auto-generated TIR).
    # A: data (1, 256, 14, 14, 4) -> 1024 channels on a 14x14 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 128, 7, 7, 4) -> 512 channels, downsampled to
    # 7x7 by the stride-2 input indexing (v_oh * 2 + v_kh).
    @T.prim_func
    def contrib_conv2d_NCHWc15(
        A: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (
                T.int64(128),
                T.int64(256),
                T.int64(1),
                T.int64(1),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (1024
        # unpacked input channels) and the 1x1 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(128),
            T.int64(7),
            T.int64(7),
            T.int64(4),
            T.int64(1024),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations; note the stride-2 spatial access.
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate; v_ic // 4 selects the input channel
                # chunk and v_ic % 4 the lane within the chunk.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 3x3, stride-1 conv2d in packed NCHWc layout with zero padding of 1 on
    # each side of H and W (auto-generated TIR).
    # A: data (1, 128, 7, 7, 4) -> 512 channels on a 7x7 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 128, 7, 7, 4) -> 512 channels, same 7x7 map.
    @T.prim_func
    def contrib_conv2d_NCHWc16(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (
                T.int64(128),
                T.int64(128),
                T.int64(3),
                T.int64(3),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Padded intermediate: 7x7 input grows to 9x9 (7 + 2*1).
        data_pad = T.alloc_buffer(
            (T.int64(1), T.int64(128), T.int64(9), T.int64(9), T.int64(4))
        )
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(128), T.int64(9), T.int64(9), T.int64(4)
        ):
            with T.block("data_pad"):
                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap(
                    "SSSSS", [i0, i1, i2, i3, i4]
                )
                T.reads(A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4])
                T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
                # Copy A for interior positions (1 <= i < 8), write 0 on the
                # one-pixel border.
                data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(
                    T.int64(1) <= v_i2
                    and v_i2 < T.int64(8)
                    and T.int64(1) <= v_i3
                    and v_i3 < T.int64(8),
                    A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4],
                    T.float32(0),
                )
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (512
        # unpacked input channels) and the 3x3 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(128),
            T.int64(7),
            T.int64(7),
            T.int64(4),
            T.int64(512),
            T.int64(3),
            T.int64(3),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations consumed by the TIR scheduler.
                T.reads(
                    data_pad[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate over the padded input; v_ic // 4 selects
                # the input channel chunk and v_ic % 4 the lane.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + data_pad[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1, stride-1 conv2d in packed NCHWc layout (auto-generated TIR).
    # A: data (1, 128, 7, 7, 4) -> 512 channels on a 7x7 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 512, 7, 7, 4) -> 2048 channels, same map.
    @T.prim_func
    def contrib_conv2d_NCHWc17(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (
                T.int64(512),
                T.int64(128),
                T.int64(1),
                T.int64(1),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (512
        # unpacked input channels) and the 1x1 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(512),
            T.int64(7),
            T.int64(7),
            T.int64(4),
            T.int64(512),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations consumed by the TIR scheduler.
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate; v_ic // 4 selects the input channel
                # chunk and v_ic % 4 the lane within the chunk.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1, stride-2 conv2d in packed NCHWc layout (auto-generated TIR).
    # A: data (1, 256, 14, 14, 4) -> 1024 channels on a 14x14 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 512, 7, 7, 4) -> 2048 channels, downsampled to
    # 7x7 by the stride-2 input indexing (v_oh * 2 + v_kh).
    @T.prim_func
    def contrib_conv2d_NCHWc18(
        A: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (
                T.int64(512),
                T.int64(256),
                T.int64(1),
                T.int64(1),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (1024
        # unpacked input channels) and the 1x1 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(512),
            T.int64(7),
            T.int64(7),
            T.int64(4),
            T.int64(1024),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations; note the stride-2 spatial access.
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate; v_ic // 4 selects the input channel
                # chunk and v_ic % 4 the lane within the chunk.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1, stride-1 conv2d in packed NCHWc layout (auto-generated TIR).
    # A: data (1, 512, 7, 7, 4) -> 2048 channels on a 7x7 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 128, 7, 7, 4) -> 512 channels, same map.
    @T.prim_func
    def contrib_conv2d_NCHWc19(
        A: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (
                T.int64(128),
                T.int64(512),
                T.int64(1),
                T.int64(1),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (2048
        # unpacked input channels) and the 1x1 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(128),
            T.int64(7),
            T.int64(7),
            T.int64(4),
            T.int64(2048),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations consumed by the TIR scheduler.
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate; v_ic // 4 selects the input channel
                # chunk and v_ic % 4 the lane within the chunk.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 3x3, stride-1 conv2d in packed NCHWc layout with zero padding of 1 on
    # each side of H and W (auto-generated TIR).
    # A: data (1, 16, 56, 56, 4) -> 64 channels on a 56x56 map.
    # B: kernel, indexed [oc_chunk, ic_chunk, kh, kw, ic_block, oc_block].
    # conv2d_NCHWc: output (1, 16, 56, 56, 4) -> 64 channels, same 56x56 map.
    @T.prim_func
    def contrib_conv2d_NCHWc2(
        A: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(16), T.int64(16), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Padded intermediate: 56x56 input grows to 58x58 (56 + 2*1).
        data_pad = T.alloc_buffer(
            (T.int64(1), T.int64(16), T.int64(58), T.int64(58), T.int64(4))
        )
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(16), T.int64(58), T.int64(58), T.int64(4)
        ):
            with T.block("data_pad"):
                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap(
                    "SSSSS", [i0, i1, i2, i3, i4]
                )
                T.reads(A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4])
                T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
                # Copy A for interior positions (1 <= i < 57), write 0 on the
                # one-pixel border.
                data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(
                    T.int64(1) <= v_i2
                    and v_i2 < T.int64(57)
                    and T.int64(1) <= v_i3
                    and v_i3 < T.int64(57),
                    A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4],
                    T.float32(0),
                )
        # Spatial axes n/oc_chunk/oh/ow/oc_block; reduction over ic (64
        # unpacked input channels) and the 3x3 kernel window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(16),
            T.int64(56),
            T.int64(56),
            T.int64(4),
            T.int64(64),
            T.int64(3),
            T.int64(3),
        ):
            with T.block("conv2d_NCHWc"):
                # "SSSSSRRR": first five axes spatial, last three reduction.
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                # Access-region annotations consumed by the TIR scheduler.
                T.reads(
                    data_pad[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                # Zero the accumulator on the first reduction iteration.
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                # Multiply-accumulate over the padded input; v_ic // 4 selects
                # the input channel chunk and v_ic % 4 the lane.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + data_pad[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1 conv2d in NCHWc layout (channel block of 4):
    # input (1, 16, 56, 56, 4) = 64 channels -> output (1, 64, 56, 56, 4)
    # = 256 channels; stride 1, no padding (kh = kw = 1).
    @T.prim_func
    def contrib_conv2d_NCHWc3(
        A: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(64), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Reduce over the 64 flattened input channels (ic//4 = chunk, ic%4 = block).
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(64),
            T.int64(56),
            T.int64(56),
            T.int64(4),
            T.int64(64),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1 conv2d in NCHWc layout (channel block of 4):
    # input (1, 64, 56, 56, 4) = 256 channels -> output (1, 16, 56, 56, 4)
    # = 64 channels; stride 1, no padding.
    @T.prim_func
    def contrib_conv2d_NCHWc4(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(16), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Reduce over the 256 flattened input channels (ic//4 = chunk, ic%4 = block).
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(16),
            T.int64(56),
            T.int64(56),
            T.int64(4),
            T.int64(256),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1 conv2d in NCHWc layout with spatial stride 2 (note the
    # v_oh * 2 / v_ow * 2 input indexing): input (1, 64, 56, 56, 4)
    # = 256 channels -> output (1, 32, 28, 28, 4) = 128 channels; no padding.
    @T.prim_func
    def contrib_conv2d_NCHWc5(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(32), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Reduce over the 256 flattened input channels (ic//4 = chunk, ic%4 = block).
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(32),
            T.int64(28),
            T.int64(28),
            T.int64(4),
            T.int64(256),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 3x3 conv2d in NCHWc layout (channel block of 4):
    # input (1, 32, 28, 28, 4) -> output (1, 32, 28, 28, 4), i.e. 128 in/out
    # channels; spatial padding 1, stride 1.
    @T.prim_func
    def contrib_conv2d_NCHWc6(
        A: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(32), T.int64(32), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Zero-pad the spatial dims by 1 on each side: 28x28 -> 30x30.
        data_pad = T.alloc_buffer(
            (T.int64(1), T.int64(32), T.int64(30), T.int64(30), T.int64(4))
        )
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(32), T.int64(30), T.int64(30), T.int64(4)
        ):
            with T.block("data_pad"):
                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap(
                    "SSSSS", [i0, i1, i2, i3, i4]
                )
                T.reads(A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4])
                T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
                # Interior copies A shifted by the pad offset; the 1-wide border is 0.
                data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(
                    T.int64(1) <= v_i2
                    and v_i2 < T.int64(29)
                    and T.int64(1) <= v_i3
                    and v_i3 < T.int64(29),
                    A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4],
                    T.float32(0),
                )
        # Accumulate over the 128 flattened input channels and the 3x3 window.
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(32),
            T.int64(28),
            T.int64(28),
            T.int64(4),
            T.int64(128),
            T.int64(3),
            T.int64(3),
        ):
            with T.block("conv2d_NCHWc"):
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                T.reads(
                    data_pad[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + data_pad[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1 conv2d in NCHWc layout (channel block of 4):
    # input (1, 32, 28, 28, 4) = 128 channels -> output (1, 128, 28, 28, 4)
    # = 512 channels; stride 1, no padding.
    @T.prim_func
    def contrib_conv2d_NCHWc7(
        A: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(128), T.int64(32), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Reduce over the 128 flattened input channels (ic//4 = chunk, ic%4 = block).
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(128),
            T.int64(28),
            T.int64(28),
            T.int64(4),
            T.int64(128),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1 conv2d in NCHWc layout with spatial stride 2 (note the
    # v_oh * 2 / v_ow * 2 input indexing): input (1, 64, 56, 56, 4)
    # = 256 channels -> output (1, 128, 28, 28, 4) = 512 channels; no padding.
    @T.prim_func
    def contrib_conv2d_NCHWc8(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(128), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Reduce over the 256 flattened input channels (ic//4 = chunk, ic%4 = block).
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(128),
            T.int64(28),
            T.int64(28),
            T.int64(4),
            T.int64(256),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh * T.int64(2) + v_kh,
                        v_ow * T.int64(2) + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # 1x1 conv2d in NCHWc layout (channel block of 4):
    # input (1, 128, 28, 28, 4) = 512 channels -> output (1, 32, 28, 28, 4)
    # = 128 channels; stride 1, no padding.
    @T.prim_func
    def contrib_conv2d_NCHWc9(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(32), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
        conv2d_NCHWc: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Reduce over the 512 flattened input channels (ic//4 = chunk, ic%4 = block).
        for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
            T.int64(1),
            T.int64(32),
            T.int64(28),
            T.int64(28),
            T.int64(4),
            T.int64(512),
            T.int64(1),
            T.int64(1),
        ):
            with T.block("conv2d_NCHWc"):
                (
                    v_n,
                    v_oc_chunk,
                    v_oh,
                    v_ow,
                    v_oc_block,
                    v_ic,
                    v_kh,
                    v_kw,
                ) = T.axis.remap(
                    "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
                )
                T.reads(
                    A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ],
                    B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ],
                )
                T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
                with T.init():
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                    conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                    + A[
                        v_n,
                        v_ic // T.int64(4),
                        v_oh + v_kh,
                        v_ow + v_kw,
                        v_ic % T.int64(4),
                    ]
                    * B[
                        v_oc_chunk,
                        v_ic // T.int64(4),
                        v_kh,
                        v_kw,
                        v_ic % T.int64(4),
                        v_oc_block,
                    ]
                )

    # Dense (fully connected) layer as a transposed matmul:
    # T_matmul_NT = A @ B.T, with A (1, 2048) and B (1000, 2048) -> (1, 1000).
    # "layout_free_buffers": [1] marks B (the weight) as layout-rewritable.
    @T.prim_func
    def dense(
        A: T.Buffer((T.int64(1), T.int64(2048)), "float32"),
        B: T.Buffer((T.int64(1000), T.int64(2048)), "float32"),
        T_matmul_NT: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
    ):
        T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i, j, k in T.grid(T.int64(1), T.int64(1000), T.int64(2048)):
            with T.block("T_matmul_NT"):
                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
                T.reads(A[v_i, v_k], B[v_j, v_k])
                T.writes(T_matmul_NT[v_i, v_j])
                with T.init():
                    T_matmul_NT[v_i, v_j] = T.float32(0)
                # Reduce over k: dot product of row i of A with row j of B.
                T_matmul_NT[v_i, v_j] = (
                    T_matmul_NT[v_i, v_j] + A[v_i, v_k] * B[v_j, v_k]
                )

    # Elementwise division of two (3, 1, 1) float32 tensors: T_divide = A / B.
    @T.prim_func
    def divide(
        A: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
        B: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
        T_divide: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(1), T.int64(1)):
            with T.block("T_divide"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0, v_ax1, v_ax2], B[v_ax0, v_ax1, v_ax2])
                T.writes(T_divide[v_ax0, v_ax1, v_ax2])
                T_divide[v_ax0, v_ax1, v_ax2] = (
                    A[v_ax0, v_ax1, v_ax2] / B[v_ax0, v_ax1, v_ax2]
                )

    # Append two singleton axes: copy A of shape (3,) into (3, 1, 1).
    @T.prim_func
    def expand_dims(
        A: T.Buffer((T.int64(3),), "float32"),
        T_expand_dims: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(1), T.int64(1)):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
                T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]

    # Append two singleton axes: copy A of shape (64,) into (64, 1, 1).
    @T.prim_func
    def expand_dims1(
        A: T.Buffer((T.int64(64),), "float32"),
        T_expand_dims: T.Buffer((T.int64(64), T.int64(1), T.int64(1)), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(64), T.int64(1), T.int64(1)):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
                T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]

    # Prepend a singleton axis: copy A of shape (512, 1, 1) into (1, 512, 1, 1).
    @T.prim_func
    def expand_dims10(
        A: T.Buffer((T.int64(512), T.int64(1), T.int64(1)), "float32"),
        T_expand_dims: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(1), T.int64(512), T.int64(1), T.int64(1)
        ):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax1, v_ax2, v_ax3])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
                T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]

    # Append three singleton axes: copy A of shape (256,) into (256, 1, 1, 1).
    @T.prim_func
    def expand_dims11(
        A: T.Buffer((T.int64(256),), "float32"),
        T_expand_dims: T.Buffer(
            (T.int64(256), T.int64(1), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(256), T.int64(1), T.int64(1), T.int64(1)
        ):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
                T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0]

    # Append two singleton axes: copy A of shape (1024,) into (1024, 1, 1).
    @T.prim_func
    def expand_dims12(
        A: T.Buffer((T.int64(1024),), "float32"),
        T_expand_dims: T.Buffer((T.int64(1024), T.int64(1), T.int64(1)), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(1024), T.int64(1), T.int64(1)):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
                T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]

    # Prepend a singleton axis: copy A of shape (1024, 1, 1) into (1, 1024, 1, 1).
    @T.prim_func
    def expand_dims13(
        A: T.Buffer((T.int64(1024), T.int64(1), T.int64(1)), "float32"),
        T_expand_dims: T.Buffer(
            (T.int64(1), T.int64(1024), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(1), T.int64(1024), T.int64(1), T.int64(1)
        ):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax1, v_ax2, v_ax3])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
                T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]

    # Append three singleton axes: copy A of shape (512,) into (512, 1, 1, 1).
    @T.prim_func
    def expand_dims14(
        A: T.Buffer((T.int64(512),), "float32"),
        T_expand_dims: T.Buffer(
            (T.int64(512), T.int64(1), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(512), T.int64(1), T.int64(1), T.int64(1)
        ):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
                T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0]

    # Append two singleton axes: copy A of shape (2048,) into (2048, 1, 1).
    @T.prim_func
    def expand_dims15(
        A: T.Buffer((T.int64(2048),), "float32"),
        T_expand_dims: T.Buffer((T.int64(2048), T.int64(1), T.int64(1)), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(2048), T.int64(1), T.int64(1)):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
                T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]

    # Prepend a singleton axis: copy A of shape (2048, 1, 1) into (1, 2048, 1, 1).
    @T.prim_func
    def expand_dims16(
        A: T.Buffer((T.int64(2048), T.int64(1), T.int64(1)), "float32"),
        T_expand_dims: T.Buffer(
            (T.int64(1), T.int64(2048), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(1), T.int64(2048), T.int64(1), T.int64(1)
        ):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax1, v_ax2, v_ax3])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
                T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]

    # Prepend a singleton (batch) axis: copy A of shape (1000,) into (1, 1000).
    @T.prim_func
    def expand_dims17(
        A: T.Buffer((T.int64(1000),), "float32"),
        T_expand_dims: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax1])
                T.writes(T_expand_dims[v_ax0, v_ax1])
                T_expand_dims[v_ax0, v_ax1] = A[v_ax1]

    # Append three singleton axes: copy A of shape (64,) into (64, 1, 1, 1).
    @T.prim_func
    def expand_dims2(
        A: T.Buffer((T.int64(64),), "float32"),
        T_expand_dims: T.Buffer(
            (T.int64(64), T.int64(1), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(64), T.int64(1), T.int64(1), T.int64(1)
        ):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
                T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0]

    # Prepend a singleton axis: copy A of shape (64, 1, 1) into (1, 64, 1, 1).
    @T.prim_func
    def expand_dims3(
        A: T.Buffer((T.int64(64), T.int64(1), T.int64(1)), "float32"),
        T_expand_dims: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(1), T.int64(64), T.int64(1), T.int64(1)
        ):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax1, v_ax2, v_ax3])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
                T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]

    # Append two singleton axes: copy A of shape (256,) into (256, 1, 1).
    @T.prim_func
    def expand_dims4(
        A: T.Buffer((T.int64(256),), "float32"),
        T_expand_dims: T.Buffer((T.int64(256), T.int64(1), T.int64(1)), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(256), T.int64(1), T.int64(1)):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
                T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]

    # Prepend a singleton axis: copy A of shape (256, 1, 1) into (1, 256, 1, 1).
    @T.prim_func
    def expand_dims5(
        A: T.Buffer((T.int64(256), T.int64(1), T.int64(1)), "float32"),
        T_expand_dims: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(1), T.int64(256), T.int64(1), T.int64(1)
        ):
            with T.block("T_expand_dims"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax1, v_ax2, v_ax3])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
                T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]

    @T.prim_func
    def expand_dims6(
        A: T.Buffer((T.int64(128),), "float32"),
        T_expand_dims: T.Buffer((T.int64(128), T.int64(1), T.int64(1)), "float32"),
    ):
        # Append two trailing unit axes: copy A (128,) into
        # T_expand_dims (128, 1, 1); pure data movement.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(1), T.int64(1)):
            with T.block("T_expand_dims"):
                # All axes spatial; only ax0 indexes the source.
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
                T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]

    @T.prim_func
    def expand_dims7(
        A: T.Buffer((T.int64(128),), "float32"),
        T_expand_dims: T.Buffer(
            (T.int64(128), T.int64(1), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        # Append three trailing unit axes: copy A (128,) into
        # T_expand_dims (128, 1, 1, 1); pure data movement.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(128), T.int64(1), T.int64(1), T.int64(1)
        ):
            with T.block("T_expand_dims"):
                # All axes spatial; only ax0 indexes the source.
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
                T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0]

    @T.prim_func
    def expand_dims8(
        A: T.Buffer((T.int64(128), T.int64(1), T.int64(1)), "float32"),
        T_expand_dims: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        # Prepend a leading unit axis: copy A (128, 1, 1) into
        # T_expand_dims (1, 128, 1, 1); pure data movement.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(1), T.int64(128), T.int64(1), T.int64(1)
        ):
            with T.block("T_expand_dims"):
                # All four axes spatial; ax0 is the new unit axis.
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax1, v_ax2, v_ax3])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
                T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]

    @T.prim_func
    def expand_dims9(
        A: T.Buffer((T.int64(512),), "float32"),
        T_expand_dims: T.Buffer((T.int64(512), T.int64(1), T.int64(1)), "float32"),
    ):
        # Append two trailing unit axes: copy A (512,) into
        # T_expand_dims (512, 1, 1); pure data movement.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(1), T.int64(1)):
            with T.block("T_expand_dims"):
                # All axes spatial; only ax0 indexes the source.
                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                T.reads(A[v_ax0])
                T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
                T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]

    @T.prim_func
    def global_avg_pool2d(
        A: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        adaptive_pool_avg: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
    ):
        # Global 2D average pooling over the 7x7 spatial extent of a
        # channel-blocked (NCHWc, c=4) feature map: first a sum reduction
        # into adaptive_pool_sum, then scaling by 1/49.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Intermediate buffer holding the per-channel spatial sums.
        adaptive_pool_sum = T.alloc_buffer(
            (T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4))
        )
        for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(
            T.int64(1),
            T.int64(512),
            T.int64(1),
            T.int64(1),
            T.int64(4),
            T.int64(7),
            T.int64(7),
        ):
            with T.block("adaptive_pool_sum"):
                # rv0/rv1 are the 7x7 spatial reduction axes ("RR").
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap(
                    "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1]
                )
                T.reads(
                    A[
                        v_ax0,
                        v_ax1,
                        v_ax2 * T.int64(7) + v_rv0,
                        v_ax3 * T.int64(7) + v_rv1,
                        v_ax4,
                    ]
                )
                T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                with T.init():
                    # Sum accumulator starts at zero.
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(0)
                adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    + A[
                        v_ax0,
                        v_ax1,
                        v_ax2 * T.int64(7) + v_rv0,
                        v_ax3 * T.int64(7) + v_rv1,
                        v_ax4,
                    ]
                )
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)
        ):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                # Multiply the sum by 1/49 (the 7*7 window size) to get the mean.
                adaptive_pool_avg[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4
                ] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] * T.float32(
                    0.020408163265306121
                )

    @T.prim_func
    def layout_transform(
        A: T.Buffer((T.int64(1), T.int64(3), T.int64(224), T.int64(224)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(1), T.int64(1), T.int64(224), T.int64(224), T.int64(3)), "float32"
        ),
    ):
        # Repack the NCHW input image (1, 3, 224, 224) into channel-blocked
        # NCHW3c layout (1, 1, 224, 224, 3); source channel = ax1*3 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(1), T.int64(224), T.int64(224), T.int64(3)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1 * T.int64(3) + v_ax4, v_ax2, v_ax3])
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.block_attr(
                    {
                        "dst_layout": "NCHW3c",
                        "input_shape": [
                            T.int64(1),
                            T.int64(3),
                            T.int64(224),
                            T.int64(224),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "NCHW",
                    }
                )
                # if_then_else guards against reading past the unpacked source
                # shape, padding with 0.0 (no-op here since 1*3 covers 3 channels).
                T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
                    v_ax0 < T.int64(1)
                    and v_ax1 * T.int64(3) + v_ax4 < T.int64(3)
                    and v_ax2 < T.int64(224)
                    and v_ax3 < T.int64(224),
                    A[v_ax0, v_ax1 * T.int64(3) + v_ax4, v_ax2, v_ax3],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform1(
        A: T.Buffer((T.int64(64), T.int64(3), T.int64(7), T.int64(7)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(16), T.int64(1), T.int64(7), T.int64(7), T.int64(3), T.int64(4)),
            "float32",
        ),
    ):
        # Repack OIHW conv weights (64, 3, 7, 7) into blocked OIHW3i4o
        # layout (16, 1, 7, 7, 3, 4): out-channel = ax0*4 + ax5,
        # in-channel = ax1*3 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(16), T.int64(1), T.int64(7), T.int64(7), T.int64(3), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(3) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW3i4o",
                        "input_shape": [
                            T.int64(64),
                            T.int64(3),
                            T.int64(7),
                            T.int64(7),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(64)
                    and v_ax1 * T.int64(3) + v_ax4 < T.int64(3)
                    and v_ax2 < T.int64(7)
                    and v_ax3 < T.int64(7),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(3) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform10(
        A: T.Buffer((T.int64(128), T.int64(128), T.int64(3), T.int64(3)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(32), T.int64(32), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        # Repack OIHW conv weights (128, 128, 3, 3) into blocked OIHW4i4o
        # layout (32, 32, 3, 3, 4, 4): out-channel = ax0*4 + ax5,
        # in-channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(32), T.int64(32), T.int64(3), T.int64(3), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(128),
                            T.int64(128),
                            T.int64(3),
                            T.int64(3),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(128)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(128)
                    and v_ax2 < T.int64(3)
                    and v_ax3 < T.int64(3),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform11(
        A: T.Buffer((T.int64(512), T.int64(128), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(128), T.int64(32), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        # Repack OIHW 1x1-conv weights (512, 128, 1, 1) into blocked OIHW4i4o
        # layout (128, 32, 1, 1, 4, 4): out-channel = ax0*4 + ax5,
        # in-channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(128), T.int64(32), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(512),
                            T.int64(128),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(512)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(128)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform12(
        A: T.Buffer((T.int64(512), T.int64(256), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(128), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        # Repack OIHW 1x1-conv weights (512, 256, 1, 1) into blocked OIHW4i4o
        # layout (128, 64, 1, 1, 4, 4): out-channel = ax0*4 + ax5,
        # in-channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(128), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(512),
                            T.int64(256),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(512)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform13(
        A: T.Buffer((T.int64(1), T.int64(512), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
    ):
        # Repack an NCHW tensor (1, 512, 1, 1) into channel-blocked NCHW4c
        # layout (1, 128, 1, 1, 4); source channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(128), T.int64(1), T.int64(1), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.block_attr(
                    {
                        "dst_layout": "NCHW4c",
                        "input_shape": [
                            T.int64(1),
                            T.int64(512),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "NCHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
                    v_ax0 < T.int64(1)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform14(
        A: T.Buffer((T.int64(128), T.int64(512), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(32), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        # Repack OIHW 1x1-conv weights (128, 512, 1, 1) into blocked OIHW4i4o
        # layout (32, 128, 1, 1, 4, 4): out-channel = ax0*4 + ax5,
        # in-channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(32), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(128),
                            T.int64(512),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(128)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform15(
        A: T.Buffer((T.int64(256), T.int64(512), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(64), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        # Repack OIHW 1x1-conv weights (256, 512, 1, 1) into blocked OIHW4i4o
        # layout (64, 128, 1, 1, 4, 4): out-channel = ax0*4 + ax5,
        # in-channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(64), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(256),
                            T.int64(512),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(256)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform16(
        A: T.Buffer((T.int64(256), T.int64(256), T.int64(3), T.int64(3)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(64), T.int64(64), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        # Repack OIHW 3x3-conv weights (256, 256, 3, 3) into blocked OIHW4i4o
        # layout (64, 64, 3, 3, 4, 4): out-channel = ax0*4 + ax5,
        # in-channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(64), T.int64(64), T.int64(3), T.int64(3), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(256),
                            T.int64(256),
                            T.int64(3),
                            T.int64(3),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(256)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
                    and v_ax2 < T.int64(3)
                    and v_ax3 < T.int64(3),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform17(
        A: T.Buffer((T.int64(1024), T.int64(256), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(256), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        # Repack OIHW 1x1-conv weights (1024, 256, 1, 1) into blocked OIHW4i4o
        # layout (256, 64, 1, 1, 4, 4): out-channel = ax0*4 + ax5,
        # in-channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(256), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(1024),
                            T.int64(256),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(1024)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform18(
        A: T.Buffer((T.int64(1024), T.int64(512), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (
                T.int64(256),
                T.int64(128),
                T.int64(1),
                T.int64(1),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
    ):
        # Repack OIHW 1x1-conv weights (1024, 512, 1, 1) into blocked OIHW4i4o
        # layout (256, 128, 1, 1, 4, 4): out-channel = ax0*4 + ax5,
        # in-channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(256), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(1024),
                            T.int64(512),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(1024)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform19(
        A: T.Buffer((T.int64(1), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
    ):
        # Repack an NCHW tensor (1, 1024, 1, 1) into channel-blocked NCHW4c
        # layout (1, 256, 1, 1, 4); source channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(256), T.int64(1), T.int64(1), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.block_attr(
                    {
                        "dst_layout": "NCHW4c",
                        "input_shape": [
                            T.int64(1),
                            T.int64(1024),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "NCHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
                    v_ax0 < T.int64(1)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(1024)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform2(
        A: T.Buffer((T.int64(1), T.int64(64), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
    ):
        # Repack an NCHW tensor (1, 64, 1, 1) into channel-blocked NCHW4c
        # layout (1, 16, 1, 1, 4); source channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.block_attr(
                    {
                        "dst_layout": "NCHW4c",
                        "input_shape": [
                            T.int64(1),
                            T.int64(64),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "NCHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
                    v_ax0 < T.int64(1)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(64)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
                    T.float32(0),
                )

    @T.prim_func
    def layout_transform20(
        A: T.Buffer((T.int64(256), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(64), T.int64(256), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        # Repack OIHW 1x1-conv weights (256, 1024, 1, 1) into blocked OIHW4i4o
        # layout (64, 256, 1, 1, 4, 4): out-channel = ax0*4 + ax5,
        # in-channel = ax1*4 + ax4.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(64), T.int64(256), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(256),
                            T.int64(1024),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                # Bounds-checked read; out-of-range positions are padded with 0.0.
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(256)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(1024)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Repack conv weights from OIHW (512, 1024, 1, 1) to OIHW4i4o
    # (128, 256, 1, 1, 4, 4): o = v_ax0 * 4 + v_ax5, i = v_ax1 * 4 + v_ax4.
    # Guard zero-fills out-of-range reads; always in-bounds here since both
    # extents divide by 4.
    @T.prim_func
    def layout_transform21(
        A: T.Buffer((T.int64(512), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (
                T.int64(128),
                T.int64(256),
                T.int64(1),
                T.int64(1),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(128), T.int64(256), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(512),
                            T.int64(1024),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(512)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(1024)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Repack 3x3 conv weights from OIHW (512, 512, 3, 3) to OIHW4i4o
    # (128, 128, 3, 3, 4, 4): o = v_ax0 * 4 + v_ax5, i = v_ax1 * 4 + v_ax4;
    # the 3x3 spatial taps pass through unchanged. Guard zero-fills
    # out-of-range reads; always in-bounds here since 512 divides by 4.
    @T.prim_func
    def layout_transform22(
        A: T.Buffer((T.int64(512), T.int64(512), T.int64(3), T.int64(3)), "float32"),
        T_layout_trans: T.Buffer(
            (
                T.int64(128),
                T.int64(128),
                T.int64(3),
                T.int64(3),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(128), T.int64(128), T.int64(3), T.int64(3), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(512),
                            T.int64(512),
                            T.int64(3),
                            T.int64(3),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(512)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
                    and v_ax2 < T.int64(3)
                    and v_ax3 < T.int64(3),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Repack conv weights from OIHW (2048, 512, 1, 1) to OIHW4i4o
    # (512, 128, 1, 1, 4, 4): o = v_ax0 * 4 + v_ax5, i = v_ax1 * 4 + v_ax4.
    # Guard zero-fills out-of-range reads; always in-bounds here since both
    # extents divide by 4.
    @T.prim_func
    def layout_transform23(
        A: T.Buffer((T.int64(2048), T.int64(512), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (
                T.int64(512),
                T.int64(128),
                T.int64(1),
                T.int64(1),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(512), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(2048),
                            T.int64(512),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(2048)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Repack conv weights from OIHW (2048, 1024, 1, 1) to OIHW4i4o
    # (512, 256, 1, 1, 4, 4): o = v_ax0 * 4 + v_ax5, i = v_ax1 * 4 + v_ax4.
    # Guard zero-fills out-of-range reads; always in-bounds here since both
    # extents divide by 4.
    @T.prim_func
    def layout_transform24(
        A: T.Buffer((T.int64(2048), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (
                T.int64(512),
                T.int64(256),
                T.int64(1),
                T.int64(1),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(512), T.int64(256), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(2048),
                            T.int64(1024),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(2048)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(1024)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Repack activations from NCHW (1, 2048, 1, 1) to NCHW4c (1, 512, 1, 1, 4):
    # source channel c = v_ax1 * 4 + v_ax4. Guard zero-fills out-of-range
    # reads; 2048 divides by 4, so it is always true here.
    @T.prim_func
    def layout_transform25(
        A: T.Buffer((T.int64(1), T.int64(2048), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.block_attr(
                    {
                        "dst_layout": "NCHW4c",
                        "input_shape": [
                            T.int64(1),
                            T.int64(2048),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "NCHW",
                    }
                )
                T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
                    v_ax0 < T.int64(1)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(2048)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
                    T.float32(0),
                )

    # Repack conv weights from OIHW (512, 2048, 1, 1) to OIHW4i4o
    # (128, 512, 1, 1, 4, 4): o = v_ax0 * 4 + v_ax5, i = v_ax1 * 4 + v_ax4.
    # Guard zero-fills out-of-range reads; always in-bounds here since both
    # extents divide by 4.
    @T.prim_func
    def layout_transform26(
        A: T.Buffer((T.int64(512), T.int64(2048), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (
                T.int64(128),
                T.int64(512),
                T.int64(1),
                T.int64(1),
                T.int64(4),
                T.int64(4),
            ),
            "float32",
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(128), T.int64(512), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(512),
                            T.int64(2048),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(512)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(2048)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Inverse repack: NCHW4c (1, 512, 1, 1, 4) back to flat NCHW
    # (1, 2048, 1, 1). Output channel c reads packed element
    # (c // 4, c % 4) of the input. Every output index is in range by
    # construction, so the guard is always true (emitted unconditionally
    # by the generator).
    @T.prim_func
    def layout_transform27(
        A: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_layout_trans: T.Buffer(
            (T.int64(1), T.int64(2048), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(1), T.int64(2048), T.int64(1), T.int64(1)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1 // T.int64(4), v_ax2, v_ax3, v_ax1 % T.int64(4)])
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr(
                    {
                        "dst_layout": "NCHW",
                        "input_shape": [
                            T.int64(1),
                            T.int64(512),
                            T.int64(1),
                            T.int64(1),
                            T.int64(4),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "NCHW4c",
                    }
                )
                T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
                    v_ax0 < T.int64(1)
                    and v_ax1 < T.int64(2048)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[v_ax0, v_ax1 // T.int64(4), v_ax2, v_ax3, v_ax1 % T.int64(4)],
                    T.float32(0),
                )

    # Repack conv weights from OIHW (64, 64, 1, 1) to OIHW4i4o
    # (16, 16, 1, 1, 4, 4): o = v_ax0 * 4 + v_ax5, i = v_ax1 * 4 + v_ax4.
    # Guard zero-fills out-of-range reads; always in-bounds here since 64
    # divides by 4.
    @T.prim_func
    def layout_transform3(
        A: T.Buffer((T.int64(64), T.int64(64), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(16), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(16), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(64),
                            T.int64(64),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(64)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(64)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Repack 3x3 conv weights from OIHW (64, 64, 3, 3) to OIHW4i4o
    # (16, 16, 3, 3, 4, 4): o = v_ax0 * 4 + v_ax5, i = v_ax1 * 4 + v_ax4; the
    # 3x3 spatial taps pass through unchanged. Guard zero-fills out-of-range
    # reads; always in-bounds here since 64 divides by 4.
    @T.prim_func
    def layout_transform4(
        A: T.Buffer((T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(16), T.int64(16), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(16), T.int64(16), T.int64(3), T.int64(3), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(64),
                            T.int64(64),
                            T.int64(3),
                            T.int64(3),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(64)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(64)
                    and v_ax2 < T.int64(3)
                    and v_ax3 < T.int64(3),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Repack conv weights from OIHW (256, 64, 1, 1) to OIHW4i4o
    # (64, 16, 1, 1, 4, 4): o = v_ax0 * 4 + v_ax5, i = v_ax1 * 4 + v_ax4.
    # Guard zero-fills out-of-range reads; always in-bounds here since both
    # extents divide by 4.
    @T.prim_func
    def layout_transform5(
        A: T.Buffer((T.int64(256), T.int64(64), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(64), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(64), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(256),
                            T.int64(64),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(256)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(64)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Repack activations from NCHW (1, 256, 1, 1) to NCHW4c (1, 64, 1, 1, 4):
    # source channel c = v_ax1 * 4 + v_ax4. Guard zero-fills out-of-range
    # reads; 256 divides by 4, so it is always true here.
    @T.prim_func
    def layout_transform6(
        A: T.Buffer((T.int64(1), T.int64(256), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(64), T.int64(1), T.int64(1), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.block_attr(
                    {
                        "dst_layout": "NCHW4c",
                        "input_shape": [
                            T.int64(1),
                            T.int64(256),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "NCHW",
                    }
                )
                T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
                    v_ax0 < T.int64(1)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
                    T.float32(0),
                )

    # Repack conv weights from OIHW (64, 256, 1, 1) to OIHW4i4o
    # (16, 64, 1, 1, 4, 4): o = v_ax0 * 4 + v_ax5, i = v_ax1 * 4 + v_ax4.
    # Guard zero-fills out-of-range reads; always in-bounds here since both
    # extents divide by 4.
    @T.prim_func
    def layout_transform7(
        A: T.Buffer((T.int64(64), T.int64(256), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(16), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(16), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(64),
                            T.int64(256),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(64)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Repack conv weights from OIHW (128, 256, 1, 1) to OIHW4i4o
    # (32, 64, 1, 1, 4, 4): o = v_ax0 * 4 + v_ax5, i = v_ax1 * 4 + v_ax4.
    # Guard zero-fills out-of-range reads; always in-bounds here since both
    # extents divide by 4.
    @T.prim_func
    def layout_transform8(
        A: T.Buffer((T.int64(128), T.int64(256), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(32), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
            "float32",
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
            T.int64(32), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                    "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
                )
                T.reads(
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ]
                )
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
                T.block_attr(
                    {
                        "dst_layout": "OIHW4i4o",
                        "input_shape": [
                            T.int64(128),
                            T.int64(256),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "OIHW",
                    }
                )
                T_layout_trans[
                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
                ] = T.if_then_else(
                    v_ax0 * T.int64(4) + v_ax5 < T.int64(128)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[
                        v_ax0 * T.int64(4) + v_ax5,
                        v_ax1 * T.int64(4) + v_ax4,
                        v_ax2,
                        v_ax3,
                    ],
                    T.float32(0),
                )

    # Repack activations from NCHW (1, 128, 1, 1) to NCHW4c (1, 32, 1, 1, 4):
    # source channel c = v_ax1 * 4 + v_ax4. Guard zero-fills out-of-range
    # reads; 128 divides by 4, so it is always true here.
    @T.prim_func
    def layout_transform9(
        A: T.Buffer((T.int64(1), T.int64(128), T.int64(1), T.int64(1)), "float32"),
        T_layout_trans: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(32), T.int64(1), T.int64(1), T.int64(4)
        ):
            with T.block("T_layout_trans"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
                T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.block_attr(
                    {
                        "dst_layout": "NCHW4c",
                        "input_shape": [
                            T.int64(1),
                            T.int64(128),
                            T.int64(1),
                            T.int64(1),
                        ],
                        "schedule_rule": "None",
                        "src_layout": "NCHW",
                    }
                )
                T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
                    v_ax0 < T.int64(1)
                    and v_ax1 * T.int64(4) + v_ax4 < T.int64(128)
                    and v_ax2 < T.int64(1)
                    and v_ax3 < T.int64(1),
                    A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
                    T.float32(0),
                )

    # 3x3 max pooling with stride 2 and padding 1 on NCHW4c data:
    # (1, 16, 112, 112, 4) -> (1, 16, 56, 56, 4).
    # Stage 1 materializes a (114, 114) spatially padded copy, filling the
    # 1-pixel border with -3.4028234663852886e38 (float32 lowest) so padded
    # cells can never win the max. Stage 2 reduces each 3x3 window; rv0/rv1
    # are reduction axes (the trailing "RR" in the remap pattern), and the
    # output element starts from float32-lowest in T.init().
    @T.prim_func
    def max_pool2d(
        A: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
        ),
        pool_max: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer(
            (T.int64(1), T.int64(16), T.int64(114), T.int64(114), T.int64(4))
        )
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(16), T.int64(114), T.int64(114), T.int64(4)
        ):
            with T.block("pad_temp"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1, v_ax2 - T.int64(1), v_ax3 - T.int64(1), v_ax4])
                T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                pad_temp[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
                    T.int64(1) <= v_ax2
                    and v_ax2 < T.int64(113)
                    and T.int64(1) <= v_ax3
                    and v_ax3 < T.int64(113),
                    A[v_ax0, v_ax1, v_ax2 - T.int64(1), v_ax3 - T.int64(1), v_ax4],
                    T.float32(-3.4028234663852886e38),
                )
        for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(
            T.int64(1),
            T.int64(16),
            T.int64(56),
            T.int64(56),
            T.int64(4),
            T.int64(3),
            T.int64(3),
        ):
            with T.block("pool_max"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap(
                    "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1]
                )
                T.reads(
                    pad_temp[
                        v_ax0,
                        v_ax1,
                        v_ax2 * T.int64(2) + v_rv0,
                        v_ax3 * T.int64(2) + v_rv1,
                        v_ax4,
                    ]
                )
                T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.block_attr({"schedule_rule": "meta_schedule.pool_max"})
                with T.init():
                    pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(
                        -3.4028234663852886e38
                    )
                pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                    pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    pad_temp[
                        v_ax0,
                        v_ax1,
                        v_ax2 * T.int64(2) + v_rv0,
                        v_ax3 * T.int64(2) + v_rv1,
                        v_ax4,
                    ],
                )

    # Elementwise multiply of two length-3 vectors: T_multiply[i] = A[i] * B[i].
    @T.prim_func
    def multiply(
        A: T.Buffer((T.int64(3),), "float32"),
        B: T.Buffer((T.int64(3),), "float32"),
        T_multiply: T.Buffer((T.int64(3),), "float32"),
    ):
        # tir.noalias asserts the argument buffers do not overlap.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3)):
            with T.block("T_multiply"):
                v_ax0 = T.axis.spatial(T.int64(3), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_multiply[v_ax0])
                T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]

    # Elementwise multiply of two length-64 vectors: T_multiply[i] = A[i] * B[i].
    @T.prim_func
    def multiply1(
        A: T.Buffer((T.int64(64),), "float32"),
        B: T.Buffer((T.int64(64),), "float32"),
        T_multiply: T.Buffer((T.int64(64),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(64)):
            with T.block("T_multiply"):
                v_ax0 = T.axis.spatial(T.int64(64), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_multiply[v_ax0])
                T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]

    # Elementwise multiply of two length-128 vectors: T_multiply[i] = A[i] * B[i].
    @T.prim_func
    def multiply10(
        A: T.Buffer((T.int64(128),), "float32"),
        B: T.Buffer((T.int64(128),), "float32"),
        T_multiply: T.Buffer((T.int64(128),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(128)):
            with T.block("T_multiply"):
                v_ax0 = T.axis.spatial(T.int64(128), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_multiply[v_ax0])
                T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]

    # Broadcast multiply over a (128, 256, 1, 1) tensor: B's dim-1 is a
    # singleton (always indexed at 0), so each dim-0 slice of A is scaled
    # by the matching per-dim-0 value of B.
    @T.prim_func
    def multiply11(
        A: T.Buffer((T.int64(128), T.int64(256), T.int64(1), T.int64(1)), "float32"),
        B: T.Buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(128), T.int64(256), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(128), T.int64(256), T.int64(1), T.int64(1)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )

    # Broadcast multiply over a (128, 128, 3, 3) tensor: B's last three dims
    # are singletons (indexed at 0), so every element of A's dim-0 slice i
    # is scaled by the single value B[i, 0, 0, 0].
    @T.prim_func
    def multiply12(
        A: T.Buffer((T.int64(128), T.int64(128), T.int64(3), T.int64(3)), "float32"),
        B: T.Buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(128), T.int64(128), T.int64(3), T.int64(3)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(128), T.int64(128), T.int64(3), T.int64(3)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3],
                    B[v_ax0, T.int64(0), T.int64(0), T.int64(0)],
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3]
                    * B[v_ax0, T.int64(0), T.int64(0), T.int64(0)]
                )

    # Elementwise multiply of two length-512 vectors: T_multiply[i] = A[i] * B[i].
    @T.prim_func
    def multiply13(
        A: T.Buffer((T.int64(512),), "float32"),
        B: T.Buffer((T.int64(512),), "float32"),
        T_multiply: T.Buffer((T.int64(512),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(512)):
            with T.block("T_multiply"):
                v_ax0 = T.axis.spatial(T.int64(512), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_multiply[v_ax0])
                T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]

    # Broadcast multiply on a 5-D packed feature map (1, 128, 28, 28, 4):
    # B's two spatial dims are singletons (indexed at 0), so each (dim-1,
    # dim-4) channel lane of A is scaled by a single per-lane value of B
    # — presumably a per-channel scale in NCHWc layout; TODO confirm.
    @T.prim_func
    def multiply14(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_multiply: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    * B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    # Broadcast multiply over a (128, 512, 1, 1) tensor: B's dim-1 is a
    # singleton (always indexed at 0), so each dim-0 slice of A is scaled
    # by the matching per-dim-0 value of B.
    @T.prim_func
    def multiply15(
        A: T.Buffer((T.int64(128), T.int64(512), T.int64(1), T.int64(1)), "float32"),
        B: T.Buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(128), T.int64(512), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(128), T.int64(512), T.int64(1), T.int64(1)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )

    # Broadcast multiply over a (256, 512, 1, 1) tensor: B's dim-1 is a
    # singleton (always indexed at 0), so each dim-0 slice of A is scaled
    # by the matching per-dim-0 value of B.
    @T.prim_func
    def multiply16(
        A: T.Buffer((T.int64(256), T.int64(512), T.int64(1), T.int64(1)), "float32"),
        B: T.Buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(256), T.int64(512), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(256), T.int64(512), T.int64(1), T.int64(1)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )

    # Broadcast multiply over a (256, 256, 3, 3) tensor: B's last three dims
    # are singletons (indexed at 0), so every element of A's dim-0 slice i
    # is scaled by the single value B[i, 0, 0, 0].
    @T.prim_func
    def multiply17(
        A: T.Buffer((T.int64(256), T.int64(256), T.int64(3), T.int64(3)), "float32"),
        B: T.Buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(256), T.int64(256), T.int64(3), T.int64(3)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(256), T.int64(256), T.int64(3), T.int64(3)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3],
                    B[v_ax0, T.int64(0), T.int64(0), T.int64(0)],
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3]
                    * B[v_ax0, T.int64(0), T.int64(0), T.int64(0)]
                )

    # Elementwise multiply of two length-1024 vectors: T_multiply[i] = A[i] * B[i].
    @T.prim_func
    def multiply18(
        A: T.Buffer((T.int64(1024),), "float32"),
        B: T.Buffer((T.int64(1024),), "float32"),
        T_multiply: T.Buffer((T.int64(1024),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(1024)):
            with T.block("T_multiply"):
                v_ax0 = T.axis.spatial(T.int64(1024), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_multiply[v_ax0])
                T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]

    # Broadcast multiply on a 5-D packed feature map (1, 256, 14, 14, 4):
    # B's two spatial dims are singletons (indexed at 0), so each (dim-1,
    # dim-4) channel lane of A is scaled by a single per-lane value of B
    # — presumably a per-channel scale in NCHWc layout; TODO confirm.
    @T.prim_func
    def multiply19(
        A: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_multiply: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    * B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    # Broadcast multiply over a (64, 3, 7, 7) tensor: B's last three dims
    # are singletons (indexed at 0), so every element of A's dim-0 slice i
    # is scaled by the single value B[i, 0, 0, 0].
    @T.prim_func
    def multiply2(
        A: T.Buffer((T.int64(64), T.int64(3), T.int64(7), T.int64(7)), "float32"),
        B: T.Buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(64), T.int64(3), T.int64(7), T.int64(7)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(64), T.int64(3), T.int64(7), T.int64(7)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3],
                    B[v_ax0, T.int64(0), T.int64(0), T.int64(0)],
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3]
                    * B[v_ax0, T.int64(0), T.int64(0), T.int64(0)]
                )

    # Broadcast multiply over a (256, 1024, 1, 1) tensor: B's dim-1 is a
    # singleton (always indexed at 0), so each dim-0 slice of A is scaled
    # by the matching per-dim-0 value of B.
    @T.prim_func
    def multiply20(
        A: T.Buffer((T.int64(256), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
        B: T.Buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(256), T.int64(1024), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(256), T.int64(1024), T.int64(1), T.int64(1)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )

    # Broadcast multiply over a (512, 1024, 1, 1) tensor: B's dim-1 is a
    # singleton (always indexed at 0), so each dim-0 slice of A is scaled
    # by the matching per-dim-0 value of B.
    @T.prim_func
    def multiply21(
        A: T.Buffer((T.int64(512), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
        B: T.Buffer((T.int64(512), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(512), T.int64(1024), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(512), T.int64(1024), T.int64(1), T.int64(1)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )

    # Broadcast multiply over a (512, 512, 3, 3) tensor: B's last three dims
    # are singletons (indexed at 0), so every element of A's dim-0 slice i
    # is scaled by the single value B[i, 0, 0, 0].
    @T.prim_func
    def multiply22(
        A: T.Buffer((T.int64(512), T.int64(512), T.int64(3), T.int64(3)), "float32"),
        B: T.Buffer((T.int64(512), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(512), T.int64(512), T.int64(3), T.int64(3)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(512), T.int64(512), T.int64(3), T.int64(3)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3],
                    B[v_ax0, T.int64(0), T.int64(0), T.int64(0)],
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3]
                    * B[v_ax0, T.int64(0), T.int64(0), T.int64(0)]
                )

    # Elementwise multiply of two length-2048 vectors: T_multiply[i] = A[i] * B[i].
    @T.prim_func
    def multiply23(
        A: T.Buffer((T.int64(2048),), "float32"),
        B: T.Buffer((T.int64(2048),), "float32"),
        T_multiply: T.Buffer((T.int64(2048),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2048)):
            with T.block("T_multiply"):
                v_ax0 = T.axis.spatial(T.int64(2048), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_multiply[v_ax0])
                T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]

    # Broadcast multiply on a 5-D packed feature map (1, 512, 7, 7, 4):
    # B's two spatial dims are singletons (indexed at 0), so each (dim-1,
    # dim-4) channel lane of A is scaled by a single per-lane value of B
    # — presumably a per-channel scale in NCHWc layout; TODO confirm.
    @T.prim_func
    def multiply24(
        A: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_multiply: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    * B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    # Broadcast multiply over a (512, 2048, 1, 1) tensor: B's dim-1 is a
    # singleton (always indexed at 0), so each dim-0 slice of A is scaled
    # by the matching per-dim-0 value of B.
    @T.prim_func
    def multiply25(
        A: T.Buffer((T.int64(512), T.int64(2048), T.int64(1), T.int64(1)), "float32"),
        B: T.Buffer((T.int64(512), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(512), T.int64(2048), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(512), T.int64(2048), T.int64(1), T.int64(1)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )

    # Broadcast multiply of a (64, 3, 7, 7) tensor by a (3, 1, 1) tensor:
    # B aligns with A's trailing dims, its last two dims are singletons,
    # so A[o, i, h, w] is scaled by B[i, 0, 0] (a per-dim-1 scalar shared
    # across dim-0 and both trailing dims).
    @T.prim_func
    def multiply3(
        A: T.Buffer((T.int64(64), T.int64(3), T.int64(7), T.int64(7)), "float32"),
        B: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(64), T.int64(3), T.int64(7), T.int64(7)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(64), T.int64(3), T.int64(7), T.int64(7)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax1, T.int64(0), T.int64(0)])
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax1, T.int64(0), T.int64(0)]
                )

    # Broadcast multiply on a 5-D packed feature map (1, 16, 56, 56, 4):
    # B's two spatial dims are singletons (indexed at 0), so each (dim-1,
    # dim-4) channel lane of A is scaled by a single per-lane value of B
    # — presumably a per-channel scale in NCHWc layout; TODO confirm.
    @T.prim_func
    def multiply4(
        A: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_multiply: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    * B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    # Broadcast multiply over a (64, 64, 1, 1) tensor: B's dim-1 is a
    # singleton (always indexed at 0), so each dim-0 slice of A is scaled
    # by the matching per-dim-0 value of B.
    @T.prim_func
    def multiply5(
        A: T.Buffer((T.int64(64), T.int64(64), T.int64(1), T.int64(1)), "float32"),
        B: T.Buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(64), T.int64(64), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(64), T.int64(64), T.int64(1), T.int64(1)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )

    # Broadcast multiply over a (64, 64, 3, 3) tensor: B's last three dims
    # are singletons (indexed at 0), so every element of A's dim-0 slice i
    # is scaled by the single value B[i, 0, 0, 0].
    @T.prim_func
    def multiply6(
        A: T.Buffer((T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"),
        B: T.Buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(64), T.int64(64), T.int64(3), T.int64(3)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3],
                    B[v_ax0, T.int64(0), T.int64(0), T.int64(0)],
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3]
                    * B[v_ax0, T.int64(0), T.int64(0), T.int64(0)]
                )

    # Elementwise multiply of two length-256 vectors: T_multiply[i] = A[i] * B[i].
    @T.prim_func
    def multiply7(
        A: T.Buffer((T.int64(256),), "float32"),
        B: T.Buffer((T.int64(256),), "float32"),
        T_multiply: T.Buffer((T.int64(256),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(256)):
            with T.block("T_multiply"):
                v_ax0 = T.axis.spatial(T.int64(256), ax0)
                T.reads(A[v_ax0], B[v_ax0])
                T.writes(T_multiply[v_ax0])
                T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]

    # Broadcast multiply on a 5-D packed feature map (1, 64, 56, 56, 4):
    # B's two spatial dims are singletons (indexed at 0), so each (dim-1,
    # dim-4) channel lane of A is scaled by a single per-lane value of B
    # — presumably a per-channel scale in NCHWc layout; TODO confirm.
    @T.prim_func
    def multiply8(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        B: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(1), T.int64(1), T.int64(4)), "float32"
        ),
        T_multiply: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                    B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                    * B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
                )

    # Broadcast multiply over a (64, 256, 1, 1) tensor: B's dim-1 is a
    # singleton (always indexed at 0), so each dim-0 slice of A is scaled
    # by the matching per-dim-0 value of B.
    @T.prim_func
    def multiply9(
        A: T.Buffer((T.int64(64), T.int64(256), T.int64(1), T.int64(1)), "float32"),
        B: T.Buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(1)), "float32"),
        T_multiply: T.Buffer(
            (T.int64(64), T.int64(256), T.int64(1), T.int64(1)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(
            T.int64(64), T.int64(256), T.int64(1), T.int64(1)
        ):
            with T.block("T_multiply"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(
                    A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )
                T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
                T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                    A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
                )

    # Elementwise negation of a length-3 vector: T_negative[i] = 0 - A[i].
    @T.prim_func
    def negative(
        A: T.Buffer((T.int64(3),), "float32"),
        T_negative: T.Buffer((T.int64(3),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(3)):
            with T.block("T_negative"):
                v_ax0 = T.axis.spatial(T.int64(3), ax0)
                T.reads(A[v_ax0])
                T.writes(T_negative[v_ax0])
                T_negative[v_ax0] = T.float32(0) - A[v_ax0]

    # Elementwise negation of a length-64 vector: T_negative[i] = 0 - A[i].
    @T.prim_func
    def negative1(
        A: T.Buffer((T.int64(64),), "float32"),
        T_negative: T.Buffer((T.int64(64),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(64)):
            with T.block("T_negative"):
                v_ax0 = T.axis.spatial(T.int64(64), ax0)
                T.reads(A[v_ax0])
                T.writes(T_negative[v_ax0])
                T_negative[v_ax0] = T.float32(0) - A[v_ax0]

    # Elementwise negation of a length-256 vector: T_negative[i] = 0 - A[i].
    @T.prim_func
    def negative2(
        A: T.Buffer((T.int64(256),), "float32"),
        T_negative: T.Buffer((T.int64(256),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(256)):
            with T.block("T_negative"):
                v_ax0 = T.axis.spatial(T.int64(256), ax0)
                T.reads(A[v_ax0])
                T.writes(T_negative[v_ax0])
                T_negative[v_ax0] = T.float32(0) - A[v_ax0]

    # Elementwise negation of a length-128 vector: T_negative[i] = 0 - A[i].
    @T.prim_func
    def negative3(
        A: T.Buffer((T.int64(128),), "float32"),
        T_negative: T.Buffer((T.int64(128),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(128)):
            with T.block("T_negative"):
                v_ax0 = T.axis.spatial(T.int64(128), ax0)
                T.reads(A[v_ax0])
                T.writes(T_negative[v_ax0])
                T_negative[v_ax0] = T.float32(0) - A[v_ax0]

    # Elementwise negation of a length-512 vector: T_negative[i] = 0 - A[i].
    @T.prim_func
    def negative4(
        A: T.Buffer((T.int64(512),), "float32"),
        T_negative: T.Buffer((T.int64(512),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(512)):
            with T.block("T_negative"):
                v_ax0 = T.axis.spatial(T.int64(512), ax0)
                T.reads(A[v_ax0])
                T.writes(T_negative[v_ax0])
                T_negative[v_ax0] = T.float32(0) - A[v_ax0]

    # Elementwise negation of a length-1024 vector: T_negative[i] = 0 - A[i].
    @T.prim_func
    def negative5(
        A: T.Buffer((T.int64(1024),), "float32"),
        T_negative: T.Buffer((T.int64(1024),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(1024)):
            with T.block("T_negative"):
                v_ax0 = T.axis.spatial(T.int64(1024), ax0)
                T.reads(A[v_ax0])
                T.writes(T_negative[v_ax0])
                T_negative[v_ax0] = T.float32(0) - A[v_ax0]

    # Elementwise negation of a length-2048 vector: T_negative[i] = 0 - A[i].
    @T.prim_func
    def negative6(
        A: T.Buffer((T.int64(2048),), "float32"),
        T_negative: T.Buffer((T.int64(2048),), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0 in range(T.int64(2048)):
            with T.block("T_negative"):
                v_ax0 = T.axis.spatial(T.int64(2048), ax0)
                T.reads(A[v_ax0])
                T.writes(T_negative[v_ax0])
                T_negative[v_ax0] = T.float32(0) - A[v_ax0]

    # Elementwise ReLU over a (1, 16, 112, 112, 4) packed feature map:
    # T_relu[...] = max(A[...], 0).
    @T.prim_func
    def relu(
        A: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
        ),
        T_relu: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)
        ):
            with T.block("T_relu"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
                )

    # Elementwise ReLU over a (1, 16, 56, 56, 4) packed feature map:
    # T_relu[...] = max(A[...], 0).
    @T.prim_func
    def relu1(
        A: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        T_relu: T.Buffer(
            (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)
        ):
            with T.block("T_relu"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
                )

    # Elementwise ReLU over a (1, 64, 56, 56, 4) packed feature map:
    # T_relu[...] = max(A[...], 0).
    @T.prim_func
    def relu2(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
        T_relu: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
        ),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, ax2, ax3, ax4 in T.grid(
            T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)
        ):
            with T.block("T_relu"):
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                    "SSSSS", [ax0, ax1, ax2, ax3, ax4]
                )
                T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                    A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
                )

    @T.prim_func
    def relu3(
        A: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        T_relu: T.Buffer(
            (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        # Elementwise ReLU: T_relu[idx] = max(A[idx], 0) over every element of a
        # 5-D tensor (presumably NCHW with an inner channel block - TODO confirm).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)
        ):
            with T.block("T_relu"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                T.reads(A[v0, v1, v2, v3, v4])
                T.writes(T_relu[v0, v1, v2, v3, v4])
                T_relu[v0, v1, v2, v3, v4] = T.max(
                    A[v0, v1, v2, v3, v4], T.float32(0)
                )

    @T.prim_func
    def relu4(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
        T_relu: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
        ),
    ):
        # Elementwise ReLU: T_relu[idx] = max(A[idx], 0) over every element of a
        # 5-D tensor (presumably NCHW with an inner channel block - TODO confirm).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)
        ):
            with T.block("T_relu"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                T.reads(A[v0, v1, v2, v3, v4])
                T.writes(T_relu[v0, v1, v2, v3, v4])
                T_relu[v0, v1, v2, v3, v4] = T.max(
                    A[v0, v1, v2, v3, v4], T.float32(0)
                )

    @T.prim_func
    def relu5(
        A: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        T_relu: T.Buffer(
            (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        # Elementwise ReLU: T_relu[idx] = max(A[idx], 0) over every element of a
        # 5-D tensor (presumably NCHW with an inner channel block - TODO confirm).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)
        ):
            with T.block("T_relu"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                T.reads(A[v0, v1, v2, v3, v4])
                T.writes(T_relu[v0, v1, v2, v3, v4])
                T_relu[v0, v1, v2, v3, v4] = T.max(
                    A[v0, v1, v2, v3, v4], T.float32(0)
                )

    @T.prim_func
    def relu6(
        A: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
        T_relu: T.Buffer(
            (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
        ),
    ):
        # Elementwise ReLU: T_relu[idx] = max(A[idx], 0) over every element of a
        # 5-D tensor (presumably NCHW with an inner channel block - TODO confirm).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)
        ):
            with T.block("T_relu"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                T.reads(A[v0, v1, v2, v3, v4])
                T.writes(T_relu[v0, v1, v2, v3, v4])
                T_relu[v0, v1, v2, v3, v4] = T.max(
                    A[v0, v1, v2, v3, v4], T.float32(0)
                )

    @T.prim_func
    def relu7(
        A: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        T_relu: T.Buffer(
            (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        # Elementwise ReLU: T_relu[idx] = max(A[idx], 0) over every element of a
        # 5-D tensor (presumably NCHW with an inner channel block - TODO confirm).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)
        ):
            with T.block("T_relu"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                T.reads(A[v0, v1, v2, v3, v4])
                T.writes(T_relu[v0, v1, v2, v3, v4])
                T_relu[v0, v1, v2, v3, v4] = T.max(
                    A[v0, v1, v2, v3, v4], T.float32(0)
                )

    @T.prim_func
    def relu8(
        A: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
        T_relu: T.Buffer(
            (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
        ),
    ):
        # Elementwise ReLU: T_relu[idx] = max(A[idx], 0) over every element of a
        # 5-D tensor (presumably NCHW with an inner channel block - TODO confirm).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, i3, i4 in T.grid(
            T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)
        ):
            with T.block("T_relu"):
                v0, v1, v2, v3, v4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
                T.reads(A[v0, v1, v2, v3, v4])
                T.writes(T_relu[v0, v1, v2, v3, v4])
                T_relu[v0, v1, v2, v3, v4] = T.max(
                    A[v0, v1, v2, v3, v4], T.float32(0)
                )

    @T.prim_func
    def rsqrt(
        A: T.Buffer((T.int64(3),), "float32"),
        tensor: T.Buffer((T.int64(3),), "float32"),
    ):
        # Elementwise reciprocal square root: tensor[i] = 1 / sqrt(A[i]).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(3)):
            with T.block("tensor"):
                vi = T.axis.spatial(T.int64(3), i)
                T.reads(A[vi])
                T.writes(tensor[vi])
                tensor[vi] = T.float32(1) / T.sqrt(A[vi])

    @T.prim_func
    def rsqrt1(
        A: T.Buffer((T.int64(64),), "float32"),
        tensor: T.Buffer((T.int64(64),), "float32"),
    ):
        # Elementwise reciprocal square root: tensor[i] = 1 / sqrt(A[i]).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(64)):
            with T.block("tensor"):
                vi = T.axis.spatial(T.int64(64), i)
                T.reads(A[vi])
                T.writes(tensor[vi])
                tensor[vi] = T.float32(1) / T.sqrt(A[vi])

    @T.prim_func
    def rsqrt2(
        A: T.Buffer((T.int64(256),), "float32"),
        tensor: T.Buffer((T.int64(256),), "float32"),
    ):
        # Elementwise reciprocal square root: tensor[i] = 1 / sqrt(A[i]).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(256)):
            with T.block("tensor"):
                vi = T.axis.spatial(T.int64(256), i)
                T.reads(A[vi])
                T.writes(tensor[vi])
                tensor[vi] = T.float32(1) / T.sqrt(A[vi])

    @T.prim_func
    def rsqrt3(
        A: T.Buffer((T.int64(128),), "float32"),
        tensor: T.Buffer((T.int64(128),), "float32"),
    ):
        # Elementwise reciprocal square root: tensor[i] = 1 / sqrt(A[i]).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(128)):
            with T.block("tensor"):
                vi = T.axis.spatial(T.int64(128), i)
                T.reads(A[vi])
                T.writes(tensor[vi])
                tensor[vi] = T.float32(1) / T.sqrt(A[vi])

    @T.prim_func
    def rsqrt4(
        A: T.Buffer((T.int64(512),), "float32"),
        tensor: T.Buffer((T.int64(512),), "float32"),
    ):
        # Elementwise reciprocal square root: tensor[i] = 1 / sqrt(A[i]).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(512)):
            with T.block("tensor"):
                vi = T.axis.spatial(T.int64(512), i)
                T.reads(A[vi])
                T.writes(tensor[vi])
                tensor[vi] = T.float32(1) / T.sqrt(A[vi])

    @T.prim_func
    def rsqrt5(
        A: T.Buffer((T.int64(1024),), "float32"),
        tensor: T.Buffer((T.int64(1024),), "float32"),
    ):
        # Elementwise reciprocal square root: tensor[i] = 1 / sqrt(A[i]).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(1024)):
            with T.block("tensor"):
                vi = T.axis.spatial(T.int64(1024), i)
                T.reads(A[vi])
                T.writes(tensor[vi])
                tensor[vi] = T.float32(1) / T.sqrt(A[vi])

    @T.prim_func
    def rsqrt6(
        A: T.Buffer((T.int64(2048),), "float32"),
        tensor: T.Buffer((T.int64(2048),), "float32"),
    ):
        # Elementwise reciprocal square root: tensor[i] = 1 / sqrt(A[i]).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(2048)):
            with T.block("tensor"):
                vi = T.axis.spatial(T.int64(2048), i)
                T.reads(A[vi])
                T.writes(tensor[vi])
                tensor[vi] = T.float32(1) / T.sqrt(A[vi])

    @T.prim_func
    def softmax(
        A: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
        T_softmax_norm: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
    ):
        # Numerically stable softmax along axis 1 of a (1, 1000) tensor:
        #   out[i, j] = exp(A[i, j] - max_k A[i, k]) / sum_k exp(A[i, k] - max_k A[i, k])
        # Four blocks: row-max reduction, shifted exp, exp-sum reduction, normalize.
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        # Intermediates: per-row max, shifted exponentials, per-row exp sum.
        T_softmax_maxelem = T.alloc_buffer((T.int64(1),))
        T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1000)))
        T_softmax_expsum = T.alloc_buffer((T.int64(1),))
        for i0, k in T.grid(T.int64(1), T.int64(1000)):
            with T.block("T_softmax_maxelem"):
                # k is a reduction ("R") axis; init seeds with -FLT_MAX.
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(A[v_i0, v_k])
                T.writes(T_softmax_maxelem[v_i0])
                with T.init():
                    T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e38)
                T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k])
        for i0, i1 in T.grid(T.int64(1), T.int64(1000)):
            with T.block("T_softmax_exp"):
                # Subtract the row max before exponentiating to avoid overflow.
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0])
                T.writes(T_softmax_exp[v_i0, v_i1])
                T_softmax_exp[v_i0, v_i1] = T.exp(
                    A[v_i0, v_i1] - T_softmax_maxelem[v_i0]
                )
        for i0, k in T.grid(T.int64(1), T.int64(1000)):
            with T.block("T_softmax_expsum"):
                # Sum-reduction of the shifted exponentials over axis 1.
                v_i0, v_k = T.axis.remap("SR", [i0, k])
                T.reads(T_softmax_exp[v_i0, v_k])
                T.writes(T_softmax_expsum[v_i0])
                with T.init():
                    T_softmax_expsum[v_i0] = T.float32(0)
                T_softmax_expsum[v_i0] = (
                    T_softmax_expsum[v_i0] + T_softmax_exp[v_i0, v_k]
                )
        for i0, i1 in T.grid(T.int64(1), T.int64(1000)):
            with T.block("T_softmax_norm"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0])
                T.writes(T_softmax_norm[v_i0, v_i1])
                # Block attr recording the normalized axis (1) for later passes.
                T.block_attr({"axis": 1})
                T_softmax_norm[v_i0, v_i1] = (
                    T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0]
                )

    @T.prim_func
    def squeeze(
        A: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
        T_squeeze: T.Buffer((T.int64(3),), "float32"),
    ):
        # Drop the two trailing unit dimensions: (3, 1, 1) -> (3,).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(3)):
            with T.block("T_squeeze"):
                vi = T.axis.spatial(T.int64(3), i)
                T.reads(A[vi, T.int64(0), T.int64(0)])
                T.writes(T_squeeze[vi])
                T_squeeze[vi] = A[vi, T.int64(0), T.int64(0)]

    @T.prim_func
    def squeeze1(
        A: T.Buffer((T.int64(64), T.int64(1), T.int64(1)), "float32"),
        T_squeeze: T.Buffer((T.int64(64),), "float32"),
    ):
        # Drop the two trailing unit dimensions: (64, 1, 1) -> (64,).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(64)):
            with T.block("T_squeeze"):
                vi = T.axis.spatial(T.int64(64), i)
                T.reads(A[vi, T.int64(0), T.int64(0)])
                T.writes(T_squeeze[vi])
                T_squeeze[vi] = A[vi, T.int64(0), T.int64(0)]

    @T.prim_func
    def squeeze2(
        A: T.Buffer((T.int64(128), T.int64(1), T.int64(1)), "float32"),
        T_squeeze: T.Buffer((T.int64(128),), "float32"),
    ):
        # Drop the two trailing unit dimensions: (128, 1, 1) -> (128,).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(128)):
            with T.block("T_squeeze"):
                vi = T.axis.spatial(T.int64(128), i)
                T.reads(A[vi, T.int64(0), T.int64(0)])
                T.writes(T_squeeze[vi])
                T_squeeze[vi] = A[vi, T.int64(0), T.int64(0)]

    @T.prim_func
    def squeeze3(
        A: T.Buffer((T.int64(256), T.int64(1), T.int64(1)), "float32"),
        T_squeeze: T.Buffer((T.int64(256),), "float32"),
    ):
        # Drop the two trailing unit dimensions: (256, 1, 1) -> (256,).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(256)):
            with T.block("T_squeeze"):
                vi = T.axis.spatial(T.int64(256), i)
                T.reads(A[vi, T.int64(0), T.int64(0)])
                T.writes(T_squeeze[vi])
                T_squeeze[vi] = A[vi, T.int64(0), T.int64(0)]

    @T.prim_func
    def squeeze4(
        A: T.Buffer((T.int64(512), T.int64(1), T.int64(1)), "float32"),
        T_squeeze: T.Buffer((T.int64(512),), "float32"),
    ):
        # Drop the two trailing unit dimensions: (512, 1, 1) -> (512,).
        T.func_attr({"tir.noalias": T.bool(True)})
        for i in range(T.int64(512)):
            with T.block("T_squeeze"):
                vi = T.axis.spatial(T.int64(512), i)
                T.reads(A[vi, T.int64(0), T.int64(0)])
                T.writes(T_squeeze[vi])
                T_squeeze[vi] = A[vi, T.int64(0), T.int64(0)]

    @R.function
    def main(
        data: R.Tensor((1, 3, 224, 224), dtype="float32"),
        bn_data_gamma: R.Tensor((3,), dtype="float32"),
        bn_data_beta: R.Tensor((3,), dtype="float32"),
        bn_data_moving_mean: R.Tensor((3,), dtype="float32"),
        bn_data_moving_var: R.Tensor((3,), dtype="float32"),
        conv0_weight: R.Tensor((64, 3, 7, 7), dtype="float32"),
        bn0_gamma: R.Tensor((64,), dtype="float32"),
        bn0_beta: R.Tensor((64,), dtype="float32"),
        bn0_moving_mean: R.Tensor((64,), dtype="float32"),
        bn0_moving_var: R.Tensor((64,), dtype="float32"),
        stage1_unit1_bn1_gamma: R.Tensor((64,), dtype="float32"),
        stage1_unit1_bn1_beta: R.Tensor((64,), dtype="float32"),
        stage1_unit1_bn1_moving_mean: R.Tensor((64,), dtype="float32"),
        stage1_unit1_bn1_moving_var: R.Tensor((64,), dtype="float32"),
        stage1_unit1_conv1_weight: R.Tensor((64, 64, 1, 1), dtype="float32"),
        stage1_unit1_bn2_gamma: R.Tensor((64,), dtype="float32"),
        stage1_unit1_bn2_beta: R.Tensor((64,), dtype="float32"),
        stage1_unit1_bn2_moving_mean: R.Tensor((64,), dtype="float32"),
        stage1_unit1_bn2_moving_var: R.Tensor((64,), dtype="float32"),
        stage1_unit1_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
        stage1_unit1_bn3_gamma: R.Tensor((64,), dtype="float32"),
        stage1_unit1_bn3_beta: R.Tensor((64,), dtype="float32"),
        stage1_unit1_bn3_moving_mean: R.Tensor((64,), dtype="float32"),
        stage1_unit1_bn3_moving_var: R.Tensor((64,), dtype="float32"),
        stage1_unit1_conv3_weight: R.Tensor((256, 64, 1, 1), dtype="float32"),
        stage1_unit1_sc_weight: R.Tensor((256, 64, 1, 1), dtype="float32"),
        stage1_unit2_bn1_gamma: R.Tensor((256,), dtype="float32"),
        stage1_unit2_bn1_beta: R.Tensor((256,), dtype="float32"),
        stage1_unit2_bn1_moving_mean: R.Tensor((256,), dtype="float32"),
        stage1_unit2_bn1_moving_var: R.Tensor((256,), dtype="float32"),
        stage1_unit2_conv1_weight: R.Tensor((64, 256, 1, 1), dtype="float32"),
        stage1_unit2_bn2_gamma: R.Tensor((64,), dtype="float32"),
        stage1_unit2_bn2_beta: R.Tensor((64,), dtype="float32"),
        stage1_unit2_bn2_moving_mean: R.Tensor((64,), dtype="float32"),
        stage1_unit2_bn2_moving_var: R.Tensor((64,), dtype="float32"),
        stage1_unit2_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
        stage1_unit2_bn3_gamma: R.Tensor((64,), dtype="float32"),
        stage1_unit2_bn3_beta: R.Tensor((64,), dtype="float32"),
        stage1_unit2_bn3_moving_mean: R.Tensor((64,), dtype="float32"),
        stage1_unit2_bn3_moving_var: R.Tensor((64,), dtype="float32"),
        stage1_unit2_conv3_weight: R.Tensor((256, 64, 1, 1), dtype="float32"),
        stage1_unit3_bn1_gamma: R.Tensor((256,), dtype="float32"),
        stage1_unit3_bn1_beta: R.Tensor((256,), dtype="float32"),
        stage1_unit3_bn1_moving_mean: R.Tensor((256,), dtype="float32"),
        stage1_unit3_bn1_moving_var: R.Tensor((256,), dtype="float32"),
        stage1_unit3_conv1_weight: R.Tensor((64, 256, 1, 1), dtype="float32"),
        stage1_unit3_bn2_gamma: R.Tensor((64,), dtype="float32"),
        stage1_unit3_bn2_beta: R.Tensor((64,), dtype="float32"),
        stage1_unit3_bn2_moving_mean: R.Tensor((64,), dtype="float32"),
        stage1_unit3_bn2_moving_var: R.Tensor((64,), dtype="float32"),
        stage1_unit3_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
        stage1_unit3_bn3_gamma: R.Tensor((64,), dtype="float32"),
        stage1_unit3_bn3_beta: R.Tensor((64,), dtype="float32"),
        stage1_unit3_bn3_moving_mean: R.Tensor((64,), dtype="float32"),
        stage1_unit3_bn3_moving_var: R.Tensor((64,), dtype="float32"),
        stage1_unit3_conv3_weight: R.Tensor((256, 64, 1, 1), dtype="float32"),
        stage2_unit1_bn1_gamma: R.Tensor((256,), dtype="float32"),
        stage2_unit1_bn1_beta: R.Tensor((256,), dtype="float32"),
        stage2_unit1_bn1_moving_mean: R.Tensor((256,), dtype="float32"),
        stage2_unit1_bn1_moving_var: R.Tensor((256,), dtype="float32"),
        stage2_unit1_conv1_weight: R.Tensor((128, 256, 1, 1), dtype="float32"),
        stage2_unit1_bn2_gamma: R.Tensor((128,), dtype="float32"),
        stage2_unit1_bn2_beta: R.Tensor((128,), dtype="float32"),
        stage2_unit1_bn2_moving_mean: R.Tensor((128,), dtype="float32"),
        stage2_unit1_bn2_moving_var: R.Tensor((128,), dtype="float32"),
        stage2_unit1_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"),
        stage2_unit1_bn3_gamma: R.Tensor((128,), dtype="float32"),
        stage2_unit1_bn3_beta: R.Tensor((128,), dtype="float32"),
        stage2_unit1_bn3_moving_mean: R.Tensor((128,), dtype="float32"),
        stage2_unit1_bn3_moving_var: R.Tensor((128,), dtype="float32"),
        stage2_unit1_conv3_weight: R.Tensor((512, 128, 1, 1), dtype="float32"),
        stage2_unit1_sc_weight: R.Tensor((512, 256, 1, 1), dtype="float32"),
        stage2_unit2_bn1_gamma: R.Tensor((512,), dtype="float32"),
        stage2_unit2_bn1_beta: R.Tensor((512,), dtype="float32"),
        stage2_unit2_bn1_moving_mean: R.Tensor((512,), dtype="float32"),
        stage2_unit2_bn1_moving_var: R.Tensor((512,), dtype="float32"),
        stage2_unit2_conv1_weight: R.Tensor((128, 512, 1, 1), dtype="float32"),
        stage2_unit2_bn2_gamma: R.Tensor((128,), dtype="float32"),
        stage2_unit2_bn2_beta: R.Tensor((128,), dtype="float32"),
        stage2_unit2_bn2_moving_mean: R.Tensor((128,), dtype="float32"),
        stage2_unit2_bn2_moving_var: R.Tensor((128,), dtype="float32"),
        stage2_unit2_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"),
        stage2_unit2_bn3_gamma: R.Tensor((128,), dtype="float32"),
        stage2_unit2_bn3_beta: R.Tensor((128,), dtype="float32"),
        stage2_unit2_bn3_moving_mean: R.Tensor((128,), dtype="float32"),
        stage2_unit2_bn3_moving_var: R.Tensor((128,), dtype="float32"),
        stage2_unit2_conv3_weight: R.Tensor((512, 128, 1, 1), dtype="float32"),
        stage2_unit3_bn1_gamma: R.Tensor((512,), dtype="float32"),
        stage2_unit3_bn1_beta: R.Tensor((512,), dtype="float32"),
        stage2_unit3_bn1_moving_mean: R.Tensor((512,), dtype="float32"),
        stage2_unit3_bn1_moving_var: R.Tensor((512,), dtype="float32"),
        stage2_unit3_conv1_weight: R.Tensor((128, 512, 1, 1), dtype="float32"),
        stage2_unit3_bn2_gamma: R.Tensor((128,), dtype="float32"),
        stage2_unit3_bn2_beta: R.Tensor((128,), dtype="float32"),
        stage2_unit3_bn2_moving_mean: R.Tensor((128,), dtype="float32"),
        stage2_unit3_bn2_moving_var: R.Tensor((128,), dtype="float32"),
        stage2_unit3_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"),
        stage2_unit3_bn3_gamma: R.Tensor((128,), dtype="float32"),
        stage2_unit3_bn3_beta: R.Tensor((128,), dtype="float32"),
        stage2_unit3_bn3_moving_mean: R.Tensor((128,), dtype="float32"),
        stage2_unit3_bn3_moving_var: R.Tensor((128,), dtype="float32"),
        stage2_unit3_conv3_weight: R.Tensor((512, 128, 1, 1), dtype="float32"),
        stage2_unit4_bn1_gamma: R.Tensor((512,), dtype="float32"),
        stage2_unit4_bn1_beta: R.Tensor((512,), dtype="float32"),
        stage2_unit4_bn1_moving_mean: R.Tensor((512,), dtype="float32"),
        stage2_unit4_bn1_moving_var: R.Tensor((512,), dtype="float32"),
        stage2_unit4_conv1_weight: R.Tensor((128, 512, 1, 1), dtype="float32"),
        stage2_unit4_bn2_gamma: R.Tensor((128,), dtype="float32"),
        stage2_unit4_bn2_beta: R.Tensor((128,), dtype="float32"),
        stage2_unit4_bn2_moving_mean: R.Tensor((128,), dtype="float32"),
        stage2_unit4_bn2_moving_var: R.Tensor((128,), dtype="float32"),
        stage2_unit4_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"),
        stage2_unit4_bn3_gamma: R.Tensor((128,), dtype="float32"),
        stage2_unit4_bn3_beta: R.Tensor((128,), dtype="float32"),
        stage2_unit4_bn3_moving_mean: R.Tensor((128,), dtype="float32"),
        stage2_unit4_bn3_moving_var: R.Tensor((128,), dtype="float32"),
        stage2_unit4_conv3_weight: R.Tensor((512, 128, 1, 1), dtype="float32"),
        stage3_unit1_bn1_gamma: R.Tensor((512,), dtype="float32"),
        stage3_unit1_bn1_beta: R.Tensor((512,), dtype="float32"),
        stage3_unit1_bn1_moving_mean: R.Tensor((512,), dtype="float32"),
        stage3_unit1_bn1_moving_var: R.Tensor((512,), dtype="float32"),
        stage3_unit1_conv1_weight: R.Tensor((256, 512, 1, 1), dtype="float32"),
        stage3_unit1_bn2_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit1_bn2_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit1_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit1_bn2_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit1_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
        stage3_unit1_bn3_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit1_bn3_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit1_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit1_bn3_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit1_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
        stage3_unit1_sc_weight: R.Tensor((1024, 512, 1, 1), dtype="float32"),
        stage3_unit2_bn1_gamma: R.Tensor((1024,), dtype="float32"),
        stage3_unit2_bn1_beta: R.Tensor((1024,), dtype="float32"),
        stage3_unit2_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
        stage3_unit2_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
        stage3_unit2_conv1_weight: R.Tensor((256, 1024, 1, 1), dtype="float32"),
        stage3_unit2_bn2_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit2_bn2_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit2_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit2_bn2_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit2_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
        stage3_unit2_bn3_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit2_bn3_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit2_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit2_bn3_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit2_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
        stage3_unit3_bn1_gamma: R.Tensor((1024,), dtype="float32"),
        stage3_unit3_bn1_beta: R.Tensor((1024,), dtype="float32"),
        stage3_unit3_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
        stage3_unit3_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
        stage3_unit3_conv1_weight: R.Tensor((256, 1024, 1, 1), dtype="float32"),
        stage3_unit3_bn2_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit3_bn2_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit3_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit3_bn2_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit3_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
        stage3_unit3_bn3_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit3_bn3_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit3_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit3_bn3_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit3_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
        stage3_unit4_bn1_gamma: R.Tensor((1024,), dtype="float32"),
        stage3_unit4_bn1_beta: R.Tensor((1024,), dtype="float32"),
        stage3_unit4_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
        stage3_unit4_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
        stage3_unit4_conv1_weight: R.Tensor((256, 1024, 1, 1), dtype="float32"),
        stage3_unit4_bn2_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit4_bn2_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit4_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit4_bn2_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit4_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
        stage3_unit4_bn3_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit4_bn3_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit4_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit4_bn3_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit4_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
        stage3_unit5_bn1_gamma: R.Tensor((1024,), dtype="float32"),
        stage3_unit5_bn1_beta: R.Tensor((1024,), dtype="float32"),
        stage3_unit5_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
        stage3_unit5_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
        stage3_unit5_conv1_weight: R.Tensor((256, 1024, 1, 1), dtype="float32"),
        stage3_unit5_bn2_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit5_bn2_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit5_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit5_bn2_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit5_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
        stage3_unit5_bn3_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit5_bn3_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit5_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit5_bn3_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit5_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
        stage3_unit6_bn1_gamma: R.Tensor((1024,), dtype="float32"),
        stage3_unit6_bn1_beta: R.Tensor((1024,), dtype="float32"),
        stage3_unit6_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
        stage3_unit6_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
        stage3_unit6_conv1_weight: R.Tensor((256, 1024, 1, 1), dtype="float32"),
        stage3_unit6_bn2_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit6_bn2_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit6_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit6_bn2_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit6_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
        stage3_unit6_bn3_gamma: R.Tensor((256,), dtype="float32"),
        stage3_unit6_bn3_beta: R.Tensor((256,), dtype="float32"),
        stage3_unit6_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
        stage3_unit6_bn3_moving_var: R.Tensor((256,), dtype="float32"),
        stage3_unit6_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
        stage4_unit1_bn1_gamma: R.Tensor((1024,), dtype="float32"),
        stage4_unit1_bn1_beta: R.Tensor((1024,), dtype="float32"),
        stage4_unit1_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
        stage4_unit1_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
        stage4_unit1_conv1_weight: R.Tensor((512, 1024, 1, 1), dtype="float32"),
        stage4_unit1_bn2_gamma: R.Tensor((512,), dtype="float32"),
        stage4_unit1_bn2_beta: R.Tensor((512,), dtype="float32"),
        stage4_unit1_bn2_moving_mean: R.Tensor((512,), dtype="float32"),
        stage4_unit1_bn2_moving_var: R.Tensor((512,), dtype="float32"),
        stage4_unit1_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"),
        stage4_unit1_bn3_gamma: R.Tensor((512,), dtype="float32"),
        stage4_unit1_bn3_beta: R.Tensor((512,), dtype="float32"),
        stage4_unit1_bn3_moving_mean: R.Tensor((512,), dtype="float32"),
        stage4_unit1_bn3_moving_var: R.Tensor((512,), dtype="float32"),
        stage4_unit1_conv3_weight: R.Tensor((2048, 512, 1, 1), dtype="float32"),
        stage4_unit1_sc_weight: R.Tensor((2048, 1024, 1, 1), dtype="float32"),
        stage4_unit2_bn1_gamma: R.Tensor((2048,), dtype="float32"),
        stage4_unit2_bn1_beta: R.Tensor((2048,), dtype="float32"),
        stage4_unit2_bn1_moving_mean: R.Tensor((2048,), dtype="float32"),
        stage4_unit2_bn1_moving_var: R.Tensor((2048,), dtype="float32"),
        stage4_unit2_conv1_weight: R.Tensor((512, 2048, 1, 1), dtype="float32"),
        stage4_unit2_bn2_gamma: R.Tensor((512,), dtype="float32"),
        stage4_unit2_bn2_beta: R.Tensor((512,), dtype="float32"),
        stage4_unit2_bn2_moving_mean: R.Tensor((512,), dtype="float32"),
        stage4_unit2_bn2_moving_var: R.Tensor((512,), dtype="float32"),
        stage4_unit2_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"),
        stage4_unit2_bn3_gamma: R.Tensor((512,), dtype="float32"),
        stage4_unit2_bn3_beta: R.Tensor((512,), dtype="float32"),
        stage4_unit2_bn3_moving_mean: R.Tensor((512,), dtype="float32"),
        stage4_unit2_bn3_moving_var: R.Tensor((512,), dtype="float32"),
        stage4_unit2_conv3_weight: R.Tensor((2048, 512, 1, 1), dtype="float32"),
        stage4_unit3_bn1_gamma: R.Tensor((2048,), dtype="float32"),
        stage4_unit3_bn1_beta: R.Tensor((2048,), dtype="float32"),
        stage4_unit3_bn1_moving_mean: R.Tensor((2048,), dtype="float32"),
        stage4_unit3_bn1_moving_var: R.Tensor((2048,), dtype="float32"),
        stage4_unit3_conv1_weight: R.Tensor((512, 2048, 1, 1), dtype="float32"),
        stage4_unit3_bn2_gamma: R.Tensor((512,), dtype="float32"),
        stage4_unit3_bn2_beta: R.Tensor((512,), dtype="float32"),
        stage4_unit3_bn2_moving_mean: R.Tensor((512,), dtype="float32"),
        stage4_unit3_bn2_moving_var: R.Tensor((512,), dtype="float32"),
        stage4_unit3_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"),
        stage4_unit3_bn3_gamma: R.Tensor((512,), dtype="float32"),
        stage4_unit3_bn3_beta: R.Tensor((512,), dtype="float32"),
        stage4_unit3_bn3_moving_mean: R.Tensor((512,), dtype="float32"),
        stage4_unit3_bn3_moving_var: R.Tensor((512,), dtype="float32"),
        stage4_unit3_conv3_weight: R.Tensor((2048, 512, 1, 1), dtype="float32"),
        bn1_gamma: R.Tensor((2048,), dtype="float32"),
        bn1_beta: R.Tensor((2048,), dtype="float32"),
        bn1_moving_mean: R.Tensor((2048,), dtype="float32"),
        bn1_moving_var: R.Tensor((2048,), dtype="float32"),
        fc1_weight: R.Tensor((1000, 2048), dtype="float32"),
        fc1_bias: R.Tensor((1000,), dtype="float32"),
    ) -> R.Tensor((1, 1000), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(
                cls.negative,
                (bn_data_moving_mean,),
                out_sinfo=R.Tensor((3,), dtype="float32"),
            )
            lv1 = R.call_tir(
                cls.add,
                (bn_data_moving_var, R.const(1.9999999494757503e-05, "float32")),
                out_sinfo=R.Tensor((3,), dtype="float32"),
            )
            lv2 = R.call_tir(
                cls.rsqrt, (lv1,), out_sinfo=R.Tensor((3,), dtype="float32")
            )
            lv3 = R.call_tir(
                cls.multiply, (lv, lv2), out_sinfo=R.Tensor((3,), dtype="float32")
            )
            lv4 = R.call_tir(
                cls.add1, (lv3, bn_data_beta), out_sinfo=R.Tensor((3,), dtype="float32")
            )
            lv5 = R.call_tir(
                cls.expand_dims, (lv4,), out_sinfo=R.Tensor((3, 1, 1), dtype="float32")
            )
            lv6 = R.call_tir(
                cls.expand_dims, (lv2,), out_sinfo=R.Tensor((3, 1, 1), dtype="float32")
            )
            lv7 = R.call_tir(
                cls.squeeze, (lv6,), out_sinfo=R.Tensor((3,), dtype="float32")
            )
            lv8 = R.call_tir(
                cls.expand_dims, (lv7,), out_sinfo=R.Tensor((3, 1, 1), dtype="float32")
            )
            lv9 = R.call_tir(
                cls.divide, (lv5, lv8), out_sinfo=R.Tensor((3, 1, 1), dtype="float32")
            )
            lv10 = R.call_tir(
                cls.add2,
                (data, lv9),
                out_sinfo=R.Tensor((1, 3, 224, 224), dtype="float32"),
            )
            lv11 = R.call_tir(
                cls.layout_transform,
                (lv10,),
                out_sinfo=R.Tensor((1, 1, 224, 224, 3), dtype="float32"),
            )
            lv12 = R.call_tir(
                cls.add3,
                (bn0_moving_var, R.const(1.9999999494757503e-05, "float32")),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv13 = R.call_tir(
                cls.rsqrt1, (lv12,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv14 = R.call_tir(
                cls.multiply1,
                (lv13, bn0_gamma),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv15 = R.call_tir(
                cls.expand_dims1,
                (lv14,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv16 = R.call_tir(
                cls.squeeze1, (lv15,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv17 = R.call_tir(
                cls.expand_dims2,
                (lv16,),
                out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
            )
            lv18 = R.call_tir(
                cls.multiply2,
                (conv0_weight, lv17),
                out_sinfo=R.Tensor((64, 3, 7, 7), dtype="float32"),
            )
            lv19 = R.call_tir(
                cls.expand_dims, (lv7,), out_sinfo=R.Tensor((3, 1, 1), dtype="float32")
            )
            lv20 = R.call_tir(
                cls.multiply3,
                (lv18, lv19),
                out_sinfo=R.Tensor((64, 3, 7, 7), dtype="float32"),
            )
            lv21 = R.call_tir(
                cls.layout_transform1,
                (lv20,),
                out_sinfo=R.Tensor((16, 1, 7, 7, 3, 4), dtype="float32"),
            )
            lv22 = R.call_tir(
                cls.contrib_conv2d_NCHWc,
                (lv11, lv21),
                out_sinfo=R.Tensor((1, 16, 112, 112, 4), dtype="float32"),
            )
            lv23 = R.call_tir(
                cls.negative1,
                (bn0_moving_mean,),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv24 = R.call_tir(
                cls.multiply1, (lv23, lv14), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv25 = R.call_tir(
                cls.add4, (lv24, bn0_beta), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv26 = R.call_tir(
                cls.expand_dims1,
                (lv25,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv27 = R.call_tir(
                cls.expand_dims3,
                (lv26,),
                out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
            )
            lv28 = R.call_tir(
                cls.layout_transform2,
                (lv27,),
                out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
            )
            lv29 = R.call_tir(
                cls.add5,
                (lv22, lv28),
                out_sinfo=R.Tensor((1, 16, 112, 112, 4), dtype="float32"),
            )
            lv30 = R.call_tir(
                cls.relu,
                (lv29,),
                out_sinfo=R.Tensor((1, 16, 112, 112, 4), dtype="float32"),
            )
            lv31 = R.call_tir(
                cls.max_pool2d,
                (lv30,),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv32 = R.call_tir(
                cls.add3,
                (
                    stage1_unit1_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv33 = R.call_tir(
                cls.rsqrt1, (lv32,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv34 = R.call_tir(
                cls.multiply1,
                (lv33, stage1_unit1_bn1_gamma),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv35 = R.call_tir(
                cls.expand_dims1,
                (lv34,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv36 = R.call_tir(
                cls.expand_dims3,
                (lv35,),
                out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
            )
            lv37 = R.call_tir(
                cls.layout_transform2,
                (lv36,),
                out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
            )
            lv38 = R.call_tir(
                cls.multiply4,
                (lv31, lv37),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv39 = R.call_tir(
                cls.negative1,
                (stage1_unit1_bn1_moving_mean,),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv40 = R.call_tir(
                cls.multiply1, (lv39, lv34), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv41 = R.call_tir(
                cls.add4,
                (lv40, stage1_unit1_bn1_beta),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv42 = R.call_tir(
                cls.expand_dims1,
                (lv41,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv43 = R.call_tir(
                cls.expand_dims3,
                (lv42,),
                out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
            )
            lv44 = R.call_tir(
                cls.layout_transform2,
                (lv43,),
                out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
            )
            lv45 = R.call_tir(
                cls.add6,
                (lv38, lv44),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv46 = R.call_tir(
                cls.relu1,
                (lv45,),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv47 = R.call_tir(
                cls.add3,
                (
                    stage1_unit1_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv48 = R.call_tir(
                cls.rsqrt1, (lv47,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv49 = R.call_tir(
                cls.multiply1,
                (lv48, stage1_unit1_bn2_gamma),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv50 = R.call_tir(
                cls.expand_dims1,
                (lv49,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv51 = R.call_tir(
                cls.squeeze1, (lv50,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv52 = R.call_tir(
                cls.expand_dims2,
                (lv51,),
                out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
            )
            lv53 = R.call_tir(
                cls.multiply5,
                (stage1_unit1_conv1_weight, lv52),
                out_sinfo=R.Tensor((64, 64, 1, 1), dtype="float32"),
            )
            lv54 = R.call_tir(
                cls.layout_transform3,
                (lv53,),
                out_sinfo=R.Tensor((16, 16, 1, 1, 4, 4), dtype="float32"),
            )
            lv55 = R.call_tir(
                cls.contrib_conv2d_NCHWc1,
                (lv46, lv54),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv56 = R.call_tir(
                cls.negative1,
                (stage1_unit1_bn2_moving_mean,),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv57 = R.call_tir(
                cls.multiply1, (lv56, lv49), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv58 = R.call_tir(
                cls.add4,
                (lv57, stage1_unit1_bn2_beta),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv59 = R.call_tir(
                cls.expand_dims1,
                (lv58,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv60 = R.call_tir(
                cls.expand_dims3,
                (lv59,),
                out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
            )
            lv61 = R.call_tir(
                cls.layout_transform2,
                (lv60,),
                out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
            )
            lv62 = R.call_tir(
                cls.add6,
                (lv55, lv61),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv63 = R.call_tir(
                cls.relu1,
                (lv62,),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv64 = R.call_tir(
                cls.add3,
                (
                    stage1_unit1_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv65 = R.call_tir(
                cls.rsqrt1, (lv64,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv66 = R.call_tir(
                cls.multiply1,
                (lv65, stage1_unit1_bn3_gamma),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv67 = R.call_tir(
                cls.expand_dims1,
                (lv66,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv68 = R.call_tir(
                cls.squeeze1, (lv67,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv69 = R.call_tir(
                cls.expand_dims2,
                (lv68,),
                out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
            )
            lv70 = R.call_tir(
                cls.multiply6,
                (stage1_unit1_conv2_weight, lv69),
                out_sinfo=R.Tensor((64, 64, 3, 3), dtype="float32"),
            )
            lv71 = R.call_tir(
                cls.layout_transform4,
                (lv70,),
                out_sinfo=R.Tensor((16, 16, 3, 3, 4, 4), dtype="float32"),
            )
            lv72 = R.call_tir(
                cls.contrib_conv2d_NCHWc2,
                (lv63, lv71),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv73 = R.call_tir(
                cls.negative1,
                (stage1_unit1_bn3_moving_mean,),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv74 = R.call_tir(
                cls.multiply1, (lv73, lv66), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv75 = R.call_tir(
                cls.add4,
                (lv74, stage1_unit1_bn3_beta),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv76 = R.call_tir(
                cls.expand_dims1,
                (lv75,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv77 = R.call_tir(
                cls.expand_dims3,
                (lv76,),
                out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
            )
            lv78 = R.call_tir(
                cls.layout_transform2,
                (lv77,),
                out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
            )
            lv79 = R.call_tir(
                cls.add6,
                (lv72, lv78),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv80 = R.call_tir(
                cls.relu1,
                (lv79,),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv81 = R.call_tir(
                cls.layout_transform5,
                (stage1_unit1_conv3_weight,),
                out_sinfo=R.Tensor((64, 16, 1, 1, 4, 4), dtype="float32"),
            )
            lv82 = R.call_tir(
                cls.contrib_conv2d_NCHWc3,
                (lv80, lv81),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv83 = R.call_tir(
                cls.layout_transform5,
                (stage1_unit1_sc_weight,),
                out_sinfo=R.Tensor((64, 16, 1, 1, 4, 4), dtype="float32"),
            )
            lv84 = R.call_tir(
                cls.contrib_conv2d_NCHWc3,
                (lv46, lv83),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv85 = R.call_tir(
                cls.add7,
                (lv82, lv84),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv86 = R.call_tir(
                cls.add8,
                (
                    stage1_unit2_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv87 = R.call_tir(
                cls.rsqrt2, (lv86,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv88 = R.call_tir(
                cls.multiply7,
                (lv87, stage1_unit2_bn1_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv89 = R.call_tir(
                cls.expand_dims4,
                (lv88,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv90 = R.call_tir(
                cls.expand_dims5,
                (lv89,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv91 = R.call_tir(
                cls.layout_transform6,
                (lv90,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv92 = R.call_tir(
                cls.multiply8,
                (lv85, lv91),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv93 = R.call_tir(
                cls.negative2,
                (stage1_unit2_bn1_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv94 = R.call_tir(
                cls.multiply7, (lv93, lv88), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv95 = R.call_tir(
                cls.add9,
                (lv94, stage1_unit2_bn1_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv96 = R.call_tir(
                cls.expand_dims4,
                (lv95,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv97 = R.call_tir(
                cls.expand_dims5,
                (lv96,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv98 = R.call_tir(
                cls.layout_transform6,
                (lv97,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv99 = R.call_tir(
                cls.add10,
                (lv92, lv98),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv100 = R.call_tir(
                cls.relu2,
                (lv99,),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv101 = R.call_tir(
                cls.add3,
                (
                    stage1_unit2_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv102 = R.call_tir(
                cls.rsqrt1, (lv101,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv103 = R.call_tir(
                cls.multiply1,
                (lv102, stage1_unit2_bn2_gamma),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv104 = R.call_tir(
                cls.expand_dims1,
                (lv103,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv105 = R.call_tir(
                cls.squeeze1, (lv104,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv106 = R.call_tir(
                cls.expand_dims2,
                (lv105,),
                out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
            )
            lv107 = R.call_tir(
                cls.multiply9,
                (stage1_unit2_conv1_weight, lv106),
                out_sinfo=R.Tensor((64, 256, 1, 1), dtype="float32"),
            )
            lv108 = R.call_tir(
                cls.layout_transform7,
                (lv107,),
                out_sinfo=R.Tensor((16, 64, 1, 1, 4, 4), dtype="float32"),
            )
            lv109 = R.call_tir(
                cls.contrib_conv2d_NCHWc4,
                (lv100, lv108),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv110 = R.call_tir(
                cls.negative1,
                (stage1_unit2_bn2_moving_mean,),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv111 = R.call_tir(
                cls.multiply1,
                (lv110, lv103),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv112 = R.call_tir(
                cls.add4,
                (lv111, stage1_unit2_bn2_beta),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv113 = R.call_tir(
                cls.expand_dims1,
                (lv112,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv114 = R.call_tir(
                cls.expand_dims3,
                (lv113,),
                out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
            )
            lv115 = R.call_tir(
                cls.layout_transform2,
                (lv114,),
                out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
            )
            lv116 = R.call_tir(
                cls.add6,
                (lv109, lv115),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv117 = R.call_tir(
                cls.relu1,
                (lv116,),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv118 = R.call_tir(
                cls.add3,
                (
                    stage1_unit2_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv119 = R.call_tir(
                cls.rsqrt1, (lv118,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv120 = R.call_tir(
                cls.multiply1,
                (lv119, stage1_unit2_bn3_gamma),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv121 = R.call_tir(
                cls.expand_dims1,
                (lv120,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv122 = R.call_tir(
                cls.squeeze1, (lv121,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv123 = R.call_tir(
                cls.expand_dims2,
                (lv122,),
                out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
            )
            lv124 = R.call_tir(
                cls.multiply6,
                (stage1_unit2_conv2_weight, lv123),
                out_sinfo=R.Tensor((64, 64, 3, 3), dtype="float32"),
            )
            lv125 = R.call_tir(
                cls.layout_transform4,
                (lv124,),
                out_sinfo=R.Tensor((16, 16, 3, 3, 4, 4), dtype="float32"),
            )
            lv126 = R.call_tir(
                cls.contrib_conv2d_NCHWc2,
                (lv117, lv125),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv127 = R.call_tir(
                cls.negative1,
                (stage1_unit2_bn3_moving_mean,),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv128 = R.call_tir(
                cls.multiply1,
                (lv127, lv120),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv129 = R.call_tir(
                cls.add4,
                (lv128, stage1_unit2_bn3_beta),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv130 = R.call_tir(
                cls.expand_dims1,
                (lv129,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv131 = R.call_tir(
                cls.expand_dims3,
                (lv130,),
                out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
            )
            lv132 = R.call_tir(
                cls.layout_transform2,
                (lv131,),
                out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
            )
            lv133 = R.call_tir(
                cls.add6,
                (lv126, lv132),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv134 = R.call_tir(
                cls.relu1,
                (lv133,),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv135 = R.call_tir(
                cls.layout_transform5,
                (stage1_unit2_conv3_weight,),
                out_sinfo=R.Tensor((64, 16, 1, 1, 4, 4), dtype="float32"),
            )
            lv136 = R.call_tir(
                cls.contrib_conv2d_NCHWc3,
                (lv134, lv135),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv137 = R.call_tir(
                cls.add7,
                (lv136, lv85),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv138 = R.call_tir(
                cls.add8,
                (
                    stage1_unit3_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv139 = R.call_tir(
                cls.rsqrt2, (lv138,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv140 = R.call_tir(
                cls.multiply7,
                (lv139, stage1_unit3_bn1_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv141 = R.call_tir(
                cls.expand_dims4,
                (lv140,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv142 = R.call_tir(
                cls.expand_dims5,
                (lv141,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv143 = R.call_tir(
                cls.layout_transform6,
                (lv142,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv144 = R.call_tir(
                cls.multiply8,
                (lv137, lv143),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv145 = R.call_tir(
                cls.negative2,
                (stage1_unit3_bn1_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv146 = R.call_tir(
                cls.multiply7,
                (lv145, lv140),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv147 = R.call_tir(
                cls.add9,
                (lv146, stage1_unit3_bn1_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv148 = R.call_tir(
                cls.expand_dims4,
                (lv147,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv149 = R.call_tir(
                cls.expand_dims5,
                (lv148,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv150 = R.call_tir(
                cls.layout_transform6,
                (lv149,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv151 = R.call_tir(
                cls.add10,
                (lv144, lv150),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv152 = R.call_tir(
                cls.relu2,
                (lv151,),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv153 = R.call_tir(
                cls.add3,
                (
                    stage1_unit3_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv154 = R.call_tir(
                cls.rsqrt1, (lv153,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv155 = R.call_tir(
                cls.multiply1,
                (lv154, stage1_unit3_bn2_gamma),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv156 = R.call_tir(
                cls.expand_dims1,
                (lv155,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv157 = R.call_tir(
                cls.squeeze1, (lv156,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv158 = R.call_tir(
                cls.expand_dims2,
                (lv157,),
                out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
            )
            lv159 = R.call_tir(
                cls.multiply9,
                (stage1_unit3_conv1_weight, lv158),
                out_sinfo=R.Tensor((64, 256, 1, 1), dtype="float32"),
            )
            lv160 = R.call_tir(
                cls.layout_transform7,
                (lv159,),
                out_sinfo=R.Tensor((16, 64, 1, 1, 4, 4), dtype="float32"),
            )
            lv161 = R.call_tir(
                cls.contrib_conv2d_NCHWc4,
                (lv152, lv160),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv162 = R.call_tir(
                cls.negative1,
                (stage1_unit3_bn2_moving_mean,),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv163 = R.call_tir(
                cls.multiply1,
                (lv162, lv155),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv164 = R.call_tir(
                cls.add4,
                (lv163, stage1_unit3_bn2_beta),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv165 = R.call_tir(
                cls.expand_dims1,
                (lv164,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv166 = R.call_tir(
                cls.expand_dims3,
                (lv165,),
                out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
            )
            lv167 = R.call_tir(
                cls.layout_transform2,
                (lv166,),
                out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
            )
            lv168 = R.call_tir(
                cls.add6,
                (lv161, lv167),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv169 = R.call_tir(
                cls.relu1,
                (lv168,),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv170 = R.call_tir(
                cls.add3,
                (
                    stage1_unit3_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv171 = R.call_tir(
                cls.rsqrt1, (lv170,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv172 = R.call_tir(
                cls.multiply1,
                (lv171, stage1_unit3_bn3_gamma),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv173 = R.call_tir(
                cls.expand_dims1,
                (lv172,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv174 = R.call_tir(
                cls.squeeze1, (lv173,), out_sinfo=R.Tensor((64,), dtype="float32")
            )
            lv175 = R.call_tir(
                cls.expand_dims2,
                (lv174,),
                out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
            )
            lv176 = R.call_tir(
                cls.multiply6,
                (stage1_unit3_conv2_weight, lv175),
                out_sinfo=R.Tensor((64, 64, 3, 3), dtype="float32"),
            )
            lv177 = R.call_tir(
                cls.layout_transform4,
                (lv176,),
                out_sinfo=R.Tensor((16, 16, 3, 3, 4, 4), dtype="float32"),
            )
            lv178 = R.call_tir(
                cls.contrib_conv2d_NCHWc2,
                (lv169, lv177),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv179 = R.call_tir(
                cls.negative1,
                (stage1_unit3_bn3_moving_mean,),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv180 = R.call_tir(
                cls.multiply1,
                (lv179, lv172),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv181 = R.call_tir(
                cls.add4,
                (lv180, stage1_unit3_bn3_beta),
                out_sinfo=R.Tensor((64,), dtype="float32"),
            )
            lv182 = R.call_tir(
                cls.expand_dims1,
                (lv181,),
                out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
            )
            lv183 = R.call_tir(
                cls.expand_dims3,
                (lv182,),
                out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
            )
            lv184 = R.call_tir(
                cls.layout_transform2,
                (lv183,),
                out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
            )
            lv185 = R.call_tir(
                cls.add6,
                (lv178, lv184),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv186 = R.call_tir(
                cls.relu1,
                (lv185,),
                out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
            )
            lv187 = R.call_tir(
                cls.layout_transform5,
                (stage1_unit3_conv3_weight,),
                out_sinfo=R.Tensor((64, 16, 1, 1, 4, 4), dtype="float32"),
            )
            lv188 = R.call_tir(
                cls.contrib_conv2d_NCHWc3,
                (lv186, lv187),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv189 = R.call_tir(
                cls.add7,
                (lv188, lv137),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv190 = R.call_tir(
                cls.add8,
                (
                    stage2_unit1_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv191 = R.call_tir(
                cls.rsqrt2, (lv190,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv192 = R.call_tir(
                cls.multiply7,
                (lv191, stage2_unit1_bn1_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv193 = R.call_tir(
                cls.expand_dims4,
                (lv192,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv194 = R.call_tir(
                cls.expand_dims5,
                (lv193,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv195 = R.call_tir(
                cls.layout_transform6,
                (lv194,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv196 = R.call_tir(
                cls.multiply8,
                (lv189, lv195),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv197 = R.call_tir(
                cls.negative2,
                (stage2_unit1_bn1_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv198 = R.call_tir(
                cls.multiply7,
                (lv197, lv192),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv199 = R.call_tir(
                cls.add9,
                (lv198, stage2_unit1_bn1_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv200 = R.call_tir(
                cls.expand_dims4,
                (lv199,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv201 = R.call_tir(
                cls.expand_dims5,
                (lv200,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv202 = R.call_tir(
                cls.layout_transform6,
                (lv201,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv203 = R.call_tir(
                cls.add10,
                (lv196, lv202),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv204 = R.call_tir(
                cls.relu2,
                (lv203,),
                out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
            )
            lv205 = R.call_tir(
                cls.add11,
                (
                    stage2_unit1_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv206 = R.call_tir(
                cls.rsqrt3, (lv205,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv207 = R.call_tir(
                cls.multiply10,
                (lv206, stage2_unit1_bn2_gamma),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv208 = R.call_tir(
                cls.expand_dims6,
                (lv207,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv209 = R.call_tir(
                cls.squeeze2, (lv208,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv210 = R.call_tir(
                cls.expand_dims7,
                (lv209,),
                out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
            )
            lv211 = R.call_tir(
                cls.multiply11,
                (stage2_unit1_conv1_weight, lv210),
                out_sinfo=R.Tensor((128, 256, 1, 1), dtype="float32"),
            )
            lv212 = R.call_tir(
                cls.layout_transform8,
                (lv211,),
                out_sinfo=R.Tensor((32, 64, 1, 1, 4, 4), dtype="float32"),
            )
            lv213 = R.call_tir(
                cls.contrib_conv2d_NCHWc5,
                (lv204, lv212),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv214 = R.call_tir(
                cls.negative3,
                (stage2_unit1_bn2_moving_mean,),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv215 = R.call_tir(
                cls.multiply10,
                (lv214, lv207),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv216 = R.call_tir(
                cls.add12,
                (lv215, stage2_unit1_bn2_beta),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv217 = R.call_tir(
                cls.expand_dims6,
                (lv216,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv218 = R.call_tir(
                cls.expand_dims8,
                (lv217,),
                out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
            )
            lv219 = R.call_tir(
                cls.layout_transform9,
                (lv218,),
                out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
            )
            lv220 = R.call_tir(
                cls.add13,
                (lv213, lv219),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv221 = R.call_tir(
                cls.relu3,
                (lv220,),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv222 = R.call_tir(
                cls.add11,
                (
                    stage2_unit1_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv223 = R.call_tir(
                cls.rsqrt3, (lv222,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv224 = R.call_tir(
                cls.multiply10,
                (lv223, stage2_unit1_bn3_gamma),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv225 = R.call_tir(
                cls.expand_dims6,
                (lv224,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv226 = R.call_tir(
                cls.squeeze2, (lv225,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv227 = R.call_tir(
                cls.expand_dims7,
                (lv226,),
                out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
            )
            lv228 = R.call_tir(
                cls.multiply12,
                (stage2_unit1_conv2_weight, lv227),
                out_sinfo=R.Tensor((128, 128, 3, 3), dtype="float32"),
            )
            lv229 = R.call_tir(
                cls.layout_transform10,
                (lv228,),
                out_sinfo=R.Tensor((32, 32, 3, 3, 4, 4), dtype="float32"),
            )
            lv230 = R.call_tir(
                cls.contrib_conv2d_NCHWc6,
                (lv221, lv229),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv231 = R.call_tir(
                cls.negative3,
                (stage2_unit1_bn3_moving_mean,),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv232 = R.call_tir(
                cls.multiply10,
                (lv231, lv224),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv233 = R.call_tir(
                cls.add12,
                (lv232, stage2_unit1_bn3_beta),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv234 = R.call_tir(
                cls.expand_dims6,
                (lv233,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv235 = R.call_tir(
                cls.expand_dims8,
                (lv234,),
                out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
            )
            lv236 = R.call_tir(
                cls.layout_transform9,
                (lv235,),
                out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
            )
            lv237 = R.call_tir(
                cls.add13,
                (lv230, lv236),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv238 = R.call_tir(
                cls.relu3,
                (lv237,),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv239 = R.call_tir(
                cls.layout_transform11,
                (stage2_unit1_conv3_weight,),
                out_sinfo=R.Tensor((128, 32, 1, 1, 4, 4), dtype="float32"),
            )
            lv240 = R.call_tir(
                cls.contrib_conv2d_NCHWc7,
                (lv238, lv239),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv241 = R.call_tir(
                cls.layout_transform12,
                (stage2_unit1_sc_weight,),
                out_sinfo=R.Tensor((128, 64, 1, 1, 4, 4), dtype="float32"),
            )
            lv242 = R.call_tir(
                cls.contrib_conv2d_NCHWc8,
                (lv204, lv241),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv243 = R.call_tir(
                cls.add14,
                (lv240, lv242),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv244 = R.call_tir(
                cls.add15,
                (
                    stage2_unit2_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv245 = R.call_tir(
                cls.rsqrt4, (lv244,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv246 = R.call_tir(
                cls.multiply13,
                (lv245, stage2_unit2_bn1_gamma),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv247 = R.call_tir(
                cls.expand_dims9,
                (lv246,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv248 = R.call_tir(
                cls.expand_dims10,
                (lv247,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv249 = R.call_tir(
                cls.layout_transform13,
                (lv248,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv250 = R.call_tir(
                cls.multiply14,
                (lv243, lv249),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv251 = R.call_tir(
                cls.negative4,
                (stage2_unit2_bn1_moving_mean,),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv252 = R.call_tir(
                cls.multiply13,
                (lv251, lv246),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv253 = R.call_tir(
                cls.add16,
                (lv252, stage2_unit2_bn1_beta),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv254 = R.call_tir(
                cls.expand_dims9,
                (lv253,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv255 = R.call_tir(
                cls.expand_dims10,
                (lv254,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv256 = R.call_tir(
                cls.layout_transform13,
                (lv255,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv257 = R.call_tir(
                cls.add17,
                (lv250, lv256),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv258 = R.call_tir(
                cls.relu4,
                (lv257,),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv259 = R.call_tir(
                cls.add11,
                (
                    stage2_unit2_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv260 = R.call_tir(
                cls.rsqrt3, (lv259,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv261 = R.call_tir(
                cls.multiply10,
                (lv260, stage2_unit2_bn2_gamma),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv262 = R.call_tir(
                cls.expand_dims6,
                (lv261,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv263 = R.call_tir(
                cls.squeeze2, (lv262,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv264 = R.call_tir(
                cls.expand_dims7,
                (lv263,),
                out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
            )
            lv265 = R.call_tir(
                cls.multiply15,
                (stage2_unit2_conv1_weight, lv264),
                out_sinfo=R.Tensor((128, 512, 1, 1), dtype="float32"),
            )
            lv266 = R.call_tir(
                cls.layout_transform14,
                (lv265,),
                out_sinfo=R.Tensor((32, 128, 1, 1, 4, 4), dtype="float32"),
            )
            lv267 = R.call_tir(
                cls.contrib_conv2d_NCHWc9,
                (lv258, lv266),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv268 = R.call_tir(
                cls.negative3,
                (stage2_unit2_bn2_moving_mean,),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv269 = R.call_tir(
                cls.multiply10,
                (lv268, lv261),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv270 = R.call_tir(
                cls.add12,
                (lv269, stage2_unit2_bn2_beta),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv271 = R.call_tir(
                cls.expand_dims6,
                (lv270,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv272 = R.call_tir(
                cls.expand_dims8,
                (lv271,),
                out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
            )
            lv273 = R.call_tir(
                cls.layout_transform9,
                (lv272,),
                out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
            )
            lv274 = R.call_tir(
                cls.add13,
                (lv267, lv273),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv275 = R.call_tir(
                cls.relu3,
                (lv274,),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv276 = R.call_tir(
                cls.add11,
                (
                    stage2_unit2_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv277 = R.call_tir(
                cls.rsqrt3, (lv276,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv278 = R.call_tir(
                cls.multiply10,
                (lv277, stage2_unit2_bn3_gamma),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv279 = R.call_tir(
                cls.expand_dims6,
                (lv278,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv280 = R.call_tir(
                cls.squeeze2, (lv279,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv281 = R.call_tir(
                cls.expand_dims7,
                (lv280,),
                out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
            )
            lv282 = R.call_tir(
                cls.multiply12,
                (stage2_unit2_conv2_weight, lv281),
                out_sinfo=R.Tensor((128, 128, 3, 3), dtype="float32"),
            )
            lv283 = R.call_tir(
                cls.layout_transform10,
                (lv282,),
                out_sinfo=R.Tensor((32, 32, 3, 3, 4, 4), dtype="float32"),
            )
            lv284 = R.call_tir(
                cls.contrib_conv2d_NCHWc6,
                (lv275, lv283),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv285 = R.call_tir(
                cls.negative3,
                (stage2_unit2_bn3_moving_mean,),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv286 = R.call_tir(
                cls.multiply10,
                (lv285, lv278),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv287 = R.call_tir(
                cls.add12,
                (lv286, stage2_unit2_bn3_beta),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv288 = R.call_tir(
                cls.expand_dims6,
                (lv287,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv289 = R.call_tir(
                cls.expand_dims8,
                (lv288,),
                out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
            )
            lv290 = R.call_tir(
                cls.layout_transform9,
                (lv289,),
                out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
            )
            lv291 = R.call_tir(
                cls.add13,
                (lv284, lv290),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv292 = R.call_tir(
                cls.relu3,
                (lv291,),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv293 = R.call_tir(
                cls.layout_transform11,
                (stage2_unit2_conv3_weight,),
                out_sinfo=R.Tensor((128, 32, 1, 1, 4, 4), dtype="float32"),
            )
            lv294 = R.call_tir(
                cls.contrib_conv2d_NCHWc7,
                (lv292, lv293),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv295 = R.call_tir(
                cls.add14,
                (lv294, lv243),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv296 = R.call_tir(
                cls.add15,
                (
                    stage2_unit3_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv297 = R.call_tir(
                cls.rsqrt4, (lv296,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv298 = R.call_tir(
                cls.multiply13,
                (lv297, stage2_unit3_bn1_gamma),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv299 = R.call_tir(
                cls.expand_dims9,
                (lv298,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv300 = R.call_tir(
                cls.expand_dims10,
                (lv299,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv301 = R.call_tir(
                cls.layout_transform13,
                (lv300,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv302 = R.call_tir(
                cls.multiply14,
                (lv295, lv301),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv303 = R.call_tir(
                cls.negative4,
                (stage2_unit3_bn1_moving_mean,),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv304 = R.call_tir(
                cls.multiply13,
                (lv303, lv298),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv305 = R.call_tir(
                cls.add16,
                (lv304, stage2_unit3_bn1_beta),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv306 = R.call_tir(
                cls.expand_dims9,
                (lv305,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv307 = R.call_tir(
                cls.expand_dims10,
                (lv306,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv308 = R.call_tir(
                cls.layout_transform13,
                (lv307,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv309 = R.call_tir(
                cls.add17,
                (lv302, lv308),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv310 = R.call_tir(
                cls.relu4,
                (lv309,),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv311 = R.call_tir(
                cls.add11,
                (
                    stage2_unit3_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv312 = R.call_tir(
                cls.rsqrt3, (lv311,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv313 = R.call_tir(
                cls.multiply10,
                (lv312, stage2_unit3_bn2_gamma),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv314 = R.call_tir(
                cls.expand_dims6,
                (lv313,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv315 = R.call_tir(
                cls.squeeze2, (lv314,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv316 = R.call_tir(
                cls.expand_dims7,
                (lv315,),
                out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
            )
            lv317 = R.call_tir(
                cls.multiply15,
                (stage2_unit3_conv1_weight, lv316),
                out_sinfo=R.Tensor((128, 512, 1, 1), dtype="float32"),
            )
            lv318 = R.call_tir(
                cls.layout_transform14,
                (lv317,),
                out_sinfo=R.Tensor((32, 128, 1, 1, 4, 4), dtype="float32"),
            )
            lv319 = R.call_tir(
                cls.contrib_conv2d_NCHWc9,
                (lv310, lv318),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv320 = R.call_tir(
                cls.negative3,
                (stage2_unit3_bn2_moving_mean,),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv321 = R.call_tir(
                cls.multiply10,
                (lv320, lv313),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv322 = R.call_tir(
                cls.add12,
                (lv321, stage2_unit3_bn2_beta),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv323 = R.call_tir(
                cls.expand_dims6,
                (lv322,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv324 = R.call_tir(
                cls.expand_dims8,
                (lv323,),
                out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
            )
            lv325 = R.call_tir(
                cls.layout_transform9,
                (lv324,),
                out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
            )
            lv326 = R.call_tir(
                cls.add13,
                (lv319, lv325),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv327 = R.call_tir(
                cls.relu3,
                (lv326,),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv328 = R.call_tir(
                cls.add11,
                (
                    stage2_unit3_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv329 = R.call_tir(
                cls.rsqrt3, (lv328,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv330 = R.call_tir(
                cls.multiply10,
                (lv329, stage2_unit3_bn3_gamma),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv331 = R.call_tir(
                cls.expand_dims6,
                (lv330,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv332 = R.call_tir(
                cls.squeeze2, (lv331,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv333 = R.call_tir(
                cls.expand_dims7,
                (lv332,),
                out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
            )
            lv334 = R.call_tir(
                cls.multiply12,
                (stage2_unit3_conv2_weight, lv333),
                out_sinfo=R.Tensor((128, 128, 3, 3), dtype="float32"),
            )
            lv335 = R.call_tir(
                cls.layout_transform10,
                (lv334,),
                out_sinfo=R.Tensor((32, 32, 3, 3, 4, 4), dtype="float32"),
            )
            lv336 = R.call_tir(
                cls.contrib_conv2d_NCHWc6,
                (lv327, lv335),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv337 = R.call_tir(
                cls.negative3,
                (stage2_unit3_bn3_moving_mean,),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv338 = R.call_tir(
                cls.multiply10,
                (lv337, lv330),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv339 = R.call_tir(
                cls.add12,
                (lv338, stage2_unit3_bn3_beta),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv340 = R.call_tir(
                cls.expand_dims6,
                (lv339,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv341 = R.call_tir(
                cls.expand_dims8,
                (lv340,),
                out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
            )
            lv342 = R.call_tir(
                cls.layout_transform9,
                (lv341,),
                out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
            )
            lv343 = R.call_tir(
                cls.add13,
                (lv336, lv342),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv344 = R.call_tir(
                cls.relu3,
                (lv343,),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv345 = R.call_tir(
                cls.layout_transform11,
                (stage2_unit3_conv3_weight,),
                out_sinfo=R.Tensor((128, 32, 1, 1, 4, 4), dtype="float32"),
            )
            lv346 = R.call_tir(
                cls.contrib_conv2d_NCHWc7,
                (lv344, lv345),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv347 = R.call_tir(
                cls.add14,
                (lv346, lv295),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv348 = R.call_tir(
                cls.add15,
                (
                    stage2_unit4_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv349 = R.call_tir(
                cls.rsqrt4, (lv348,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv350 = R.call_tir(
                cls.multiply13,
                (lv349, stage2_unit4_bn1_gamma),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv351 = R.call_tir(
                cls.expand_dims9,
                (lv350,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv352 = R.call_tir(
                cls.expand_dims10,
                (lv351,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv353 = R.call_tir(
                cls.layout_transform13,
                (lv352,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv354 = R.call_tir(
                cls.multiply14,
                (lv347, lv353),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv355 = R.call_tir(
                cls.negative4,
                (stage2_unit4_bn1_moving_mean,),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv356 = R.call_tir(
                cls.multiply13,
                (lv355, lv350),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv357 = R.call_tir(
                cls.add16,
                (lv356, stage2_unit4_bn1_beta),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv358 = R.call_tir(
                cls.expand_dims9,
                (lv357,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv359 = R.call_tir(
                cls.expand_dims10,
                (lv358,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv360 = R.call_tir(
                cls.layout_transform13,
                (lv359,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv361 = R.call_tir(
                cls.add17,
                (lv354, lv360),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv362 = R.call_tir(
                cls.relu4,
                (lv361,),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv363 = R.call_tir(
                cls.add11,
                (
                    stage2_unit4_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv364 = R.call_tir(
                cls.rsqrt3, (lv363,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv365 = R.call_tir(
                cls.multiply10,
                (lv364, stage2_unit4_bn2_gamma),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv366 = R.call_tir(
                cls.expand_dims6,
                (lv365,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv367 = R.call_tir(
                cls.squeeze2, (lv366,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv368 = R.call_tir(
                cls.expand_dims7,
                (lv367,),
                out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
            )
            lv369 = R.call_tir(
                cls.multiply15,
                (stage2_unit4_conv1_weight, lv368),
                out_sinfo=R.Tensor((128, 512, 1, 1), dtype="float32"),
            )
            lv370 = R.call_tir(
                cls.layout_transform14,
                (lv369,),
                out_sinfo=R.Tensor((32, 128, 1, 1, 4, 4), dtype="float32"),
            )
            lv371 = R.call_tir(
                cls.contrib_conv2d_NCHWc9,
                (lv362, lv370),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv372 = R.call_tir(
                cls.negative3,
                (stage2_unit4_bn2_moving_mean,),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv373 = R.call_tir(
                cls.multiply10,
                (lv372, lv365),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv374 = R.call_tir(
                cls.add12,
                (lv373, stage2_unit4_bn2_beta),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv375 = R.call_tir(
                cls.expand_dims6,
                (lv374,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv376 = R.call_tir(
                cls.expand_dims8,
                (lv375,),
                out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
            )
            lv377 = R.call_tir(
                cls.layout_transform9,
                (lv376,),
                out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
            )
            lv378 = R.call_tir(
                cls.add13,
                (lv371, lv377),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv379 = R.call_tir(
                cls.relu3,
                (lv378,),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv380 = R.call_tir(
                cls.add11,
                (
                    stage2_unit4_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv381 = R.call_tir(
                cls.rsqrt3, (lv380,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv382 = R.call_tir(
                cls.multiply10,
                (lv381, stage2_unit4_bn3_gamma),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv383 = R.call_tir(
                cls.expand_dims6,
                (lv382,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv384 = R.call_tir(
                cls.squeeze2, (lv383,), out_sinfo=R.Tensor((128,), dtype="float32")
            )
            lv385 = R.call_tir(
                cls.expand_dims7,
                (lv384,),
                out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
            )
            lv386 = R.call_tir(
                cls.multiply12,
                (stage2_unit4_conv2_weight, lv385),
                out_sinfo=R.Tensor((128, 128, 3, 3), dtype="float32"),
            )
            lv387 = R.call_tir(
                cls.layout_transform10,
                (lv386,),
                out_sinfo=R.Tensor((32, 32, 3, 3, 4, 4), dtype="float32"),
            )
            lv388 = R.call_tir(
                cls.contrib_conv2d_NCHWc6,
                (lv379, lv387),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv389 = R.call_tir(
                cls.negative3,
                (stage2_unit4_bn3_moving_mean,),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv390 = R.call_tir(
                cls.multiply10,
                (lv389, lv382),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv391 = R.call_tir(
                cls.add12,
                (lv390, stage2_unit4_bn3_beta),
                out_sinfo=R.Tensor((128,), dtype="float32"),
            )
            lv392 = R.call_tir(
                cls.expand_dims6,
                (lv391,),
                out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
            )
            lv393 = R.call_tir(
                cls.expand_dims8,
                (lv392,),
                out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
            )
            lv394 = R.call_tir(
                cls.layout_transform9,
                (lv393,),
                out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
            )
            lv395 = R.call_tir(
                cls.add13,
                (lv388, lv394),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv396 = R.call_tir(
                cls.relu3,
                (lv395,),
                out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
            )
            lv397 = R.call_tir(
                cls.layout_transform11,
                (stage2_unit4_conv3_weight,),
                out_sinfo=R.Tensor((128, 32, 1, 1, 4, 4), dtype="float32"),
            )
            lv398 = R.call_tir(
                cls.contrib_conv2d_NCHWc7,
                (lv396, lv397),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv399 = R.call_tir(
                cls.add14,
                (lv398, lv347),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv400 = R.call_tir(
                cls.add15,
                (
                    stage3_unit1_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv401 = R.call_tir(
                cls.rsqrt4, (lv400,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv402 = R.call_tir(
                cls.multiply13,
                (lv401, stage3_unit1_bn1_gamma),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv403 = R.call_tir(
                cls.expand_dims9,
                (lv402,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv404 = R.call_tir(
                cls.expand_dims10,
                (lv403,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv405 = R.call_tir(
                cls.layout_transform13,
                (lv404,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv406 = R.call_tir(
                cls.multiply14,
                (lv399, lv405),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv407 = R.call_tir(
                cls.negative4,
                (stage3_unit1_bn1_moving_mean,),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv408 = R.call_tir(
                cls.multiply13,
                (lv407, lv402),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv409 = R.call_tir(
                cls.add16,
                (lv408, stage3_unit1_bn1_beta),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv410 = R.call_tir(
                cls.expand_dims9,
                (lv409,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv411 = R.call_tir(
                cls.expand_dims10,
                (lv410,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv412 = R.call_tir(
                cls.layout_transform13,
                (lv411,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv413 = R.call_tir(
                cls.add17,
                (lv406, lv412),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv414 = R.call_tir(
                cls.relu4,
                (lv413,),
                out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
            )
            lv415 = R.call_tir(
                cls.add8,
                (
                    stage3_unit1_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv416 = R.call_tir(
                cls.rsqrt2, (lv415,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv417 = R.call_tir(
                cls.multiply7,
                (lv416, stage3_unit1_bn2_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv418 = R.call_tir(
                cls.expand_dims4,
                (lv417,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv419 = R.call_tir(
                cls.squeeze3, (lv418,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv420 = R.call_tir(
                cls.expand_dims11,
                (lv419,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv421 = R.call_tir(
                cls.multiply16,
                (stage3_unit1_conv1_weight, lv420),
                out_sinfo=R.Tensor((256, 512, 1, 1), dtype="float32"),
            )
            lv422 = R.call_tir(
                cls.layout_transform15,
                (lv421,),
                out_sinfo=R.Tensor((64, 128, 1, 1, 4, 4), dtype="float32"),
            )
            lv423 = R.call_tir(
                cls.contrib_conv2d_NCHWc10,
                (lv414, lv422),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv424 = R.call_tir(
                cls.negative2,
                (stage3_unit1_bn2_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv425 = R.call_tir(
                cls.multiply7,
                (lv424, lv417),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv426 = R.call_tir(
                cls.add9,
                (lv425, stage3_unit1_bn2_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv427 = R.call_tir(
                cls.expand_dims4,
                (lv426,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv428 = R.call_tir(
                cls.expand_dims5,
                (lv427,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv429 = R.call_tir(
                cls.layout_transform6,
                (lv428,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv430 = R.call_tir(
                cls.add18,
                (lv423, lv429),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv431 = R.call_tir(
                cls.relu5,
                (lv430,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv432 = R.call_tir(
                cls.add8,
                (
                    stage3_unit1_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv433 = R.call_tir(
                cls.rsqrt2, (lv432,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv434 = R.call_tir(
                cls.multiply7,
                (lv433, stage3_unit1_bn3_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv435 = R.call_tir(
                cls.expand_dims4,
                (lv434,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv436 = R.call_tir(
                cls.squeeze3, (lv435,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv437 = R.call_tir(
                cls.expand_dims11,
                (lv436,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv438 = R.call_tir(
                cls.multiply17,
                (stage3_unit1_conv2_weight, lv437),
                out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
            )
            lv439 = R.call_tir(
                cls.layout_transform16,
                (lv438,),
                out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
            )
            lv440 = R.call_tir(
                cls.contrib_conv2d_NCHWc11,
                (lv431, lv439),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv441 = R.call_tir(
                cls.negative2,
                (stage3_unit1_bn3_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv442 = R.call_tir(
                cls.multiply7,
                (lv441, lv434),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv443 = R.call_tir(
                cls.add9,
                (lv442, stage3_unit1_bn3_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv444 = R.call_tir(
                cls.expand_dims4,
                (lv443,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv445 = R.call_tir(
                cls.expand_dims5,
                (lv444,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv446 = R.call_tir(
                cls.layout_transform6,
                (lv445,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv447 = R.call_tir(
                cls.add18,
                (lv440, lv446),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv448 = R.call_tir(
                cls.relu5,
                (lv447,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv449 = R.call_tir(
                cls.layout_transform17,
                (stage3_unit1_conv3_weight,),
                out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
            )
            lv450 = R.call_tir(
                cls.contrib_conv2d_NCHWc12,
                (lv448, lv449),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv451 = R.call_tir(
                cls.layout_transform18,
                (stage3_unit1_sc_weight,),
                out_sinfo=R.Tensor((256, 128, 1, 1, 4, 4), dtype="float32"),
            )
            lv452 = R.call_tir(
                cls.contrib_conv2d_NCHWc13,
                (lv414, lv451),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv453 = R.call_tir(
                cls.add19,
                (lv450, lv452),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv454 = R.call_tir(
                cls.add20,
                (
                    stage3_unit2_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv455 = R.call_tir(
                cls.rsqrt5, (lv454,), out_sinfo=R.Tensor((1024,), dtype="float32")
            )
            lv456 = R.call_tir(
                cls.multiply18,
                (lv455, stage3_unit2_bn1_gamma),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv457 = R.call_tir(
                cls.expand_dims12,
                (lv456,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv458 = R.call_tir(
                cls.expand_dims13,
                (lv457,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv459 = R.call_tir(
                cls.layout_transform19,
                (lv458,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv460 = R.call_tir(
                cls.multiply19,
                (lv453, lv459),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv461 = R.call_tir(
                cls.negative5,
                (stage3_unit2_bn1_moving_mean,),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv462 = R.call_tir(
                cls.multiply18,
                (lv461, lv456),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv463 = R.call_tir(
                cls.add21,
                (lv462, stage3_unit2_bn1_beta),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv464 = R.call_tir(
                cls.expand_dims12,
                (lv463,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv465 = R.call_tir(
                cls.expand_dims13,
                (lv464,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv466 = R.call_tir(
                cls.layout_transform19,
                (lv465,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv467 = R.call_tir(
                cls.add22,
                (lv460, lv466),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv468 = R.call_tir(
                cls.relu6,
                (lv467,),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv469 = R.call_tir(
                cls.add8,
                (
                    stage3_unit2_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv470 = R.call_tir(
                cls.rsqrt2, (lv469,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv471 = R.call_tir(
                cls.multiply7,
                (lv470, stage3_unit2_bn2_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv472 = R.call_tir(
                cls.expand_dims4,
                (lv471,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv473 = R.call_tir(
                cls.squeeze3, (lv472,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv474 = R.call_tir(
                cls.expand_dims11,
                (lv473,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv475 = R.call_tir(
                cls.multiply20,
                (stage3_unit2_conv1_weight, lv474),
                out_sinfo=R.Tensor((256, 1024, 1, 1), dtype="float32"),
            )
            lv476 = R.call_tir(
                cls.layout_transform20,
                (lv475,),
                out_sinfo=R.Tensor((64, 256, 1, 1, 4, 4), dtype="float32"),
            )
            lv477 = R.call_tir(
                cls.contrib_conv2d_NCHWc14,
                (lv468, lv476),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv478 = R.call_tir(
                cls.negative2,
                (stage3_unit2_bn2_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv479 = R.call_tir(
                cls.multiply7,
                (lv478, lv471),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv480 = R.call_tir(
                cls.add9,
                (lv479, stage3_unit2_bn2_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv481 = R.call_tir(
                cls.expand_dims4,
                (lv480,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv482 = R.call_tir(
                cls.expand_dims5,
                (lv481,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv483 = R.call_tir(
                cls.layout_transform6,
                (lv482,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv484 = R.call_tir(
                cls.add18,
                (lv477, lv483),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv485 = R.call_tir(
                cls.relu5,
                (lv484,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv486 = R.call_tir(
                cls.add8,
                (
                    stage3_unit2_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv487 = R.call_tir(
                cls.rsqrt2, (lv486,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv488 = R.call_tir(
                cls.multiply7,
                (lv487, stage3_unit2_bn3_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv489 = R.call_tir(
                cls.expand_dims4,
                (lv488,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv490 = R.call_tir(
                cls.squeeze3, (lv489,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv491 = R.call_tir(
                cls.expand_dims11,
                (lv490,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv492 = R.call_tir(
                cls.multiply17,
                (stage3_unit2_conv2_weight, lv491),
                out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
            )
            lv493 = R.call_tir(
                cls.layout_transform16,
                (lv492,),
                out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
            )
            lv494 = R.call_tir(
                cls.contrib_conv2d_NCHWc11,
                (lv485, lv493),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv495 = R.call_tir(
                cls.negative2,
                (stage3_unit2_bn3_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv496 = R.call_tir(
                cls.multiply7,
                (lv495, lv488),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv497 = R.call_tir(
                cls.add9,
                (lv496, stage3_unit2_bn3_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv498 = R.call_tir(
                cls.expand_dims4,
                (lv497,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv499 = R.call_tir(
                cls.expand_dims5,
                (lv498,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv500 = R.call_tir(
                cls.layout_transform6,
                (lv499,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv501 = R.call_tir(
                cls.add18,
                (lv494, lv500),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv502 = R.call_tir(
                cls.relu5,
                (lv501,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv503 = R.call_tir(
                cls.layout_transform17,
                (stage3_unit2_conv3_weight,),
                out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
            )
            lv504 = R.call_tir(
                cls.contrib_conv2d_NCHWc12,
                (lv502, lv503),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv505 = R.call_tir(
                cls.add19,
                (lv504, lv453),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv506 = R.call_tir(
                cls.add20,
                (
                    stage3_unit3_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv507 = R.call_tir(
                cls.rsqrt5, (lv506,), out_sinfo=R.Tensor((1024,), dtype="float32")
            )
            lv508 = R.call_tir(
                cls.multiply18,
                (lv507, stage3_unit3_bn1_gamma),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv509 = R.call_tir(
                cls.expand_dims12,
                (lv508,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv510 = R.call_tir(
                cls.expand_dims13,
                (lv509,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv511 = R.call_tir(
                cls.layout_transform19,
                (lv510,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv512 = R.call_tir(
                cls.multiply19,
                (lv505, lv511),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv513 = R.call_tir(
                cls.negative5,
                (stage3_unit3_bn1_moving_mean,),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv514 = R.call_tir(
                cls.multiply18,
                (lv513, lv508),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv515 = R.call_tir(
                cls.add21,
                (lv514, stage3_unit3_bn1_beta),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv516 = R.call_tir(
                cls.expand_dims12,
                (lv515,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv517 = R.call_tir(
                cls.expand_dims13,
                (lv516,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv518 = R.call_tir(
                cls.layout_transform19,
                (lv517,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv519 = R.call_tir(
                cls.add22,
                (lv512, lv518),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv520 = R.call_tir(
                cls.relu6,
                (lv519,),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv521 = R.call_tir(
                cls.add8,
                (
                    stage3_unit3_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv522 = R.call_tir(
                cls.rsqrt2, (lv521,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv523 = R.call_tir(
                cls.multiply7,
                (lv522, stage3_unit3_bn2_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv524 = R.call_tir(
                cls.expand_dims4,
                (lv523,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv525 = R.call_tir(
                cls.squeeze3, (lv524,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv526 = R.call_tir(
                cls.expand_dims11,
                (lv525,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv527 = R.call_tir(
                cls.multiply20,
                (stage3_unit3_conv1_weight, lv526),
                out_sinfo=R.Tensor((256, 1024, 1, 1), dtype="float32"),
            )
            lv528 = R.call_tir(
                cls.layout_transform20,
                (lv527,),
                out_sinfo=R.Tensor((64, 256, 1, 1, 4, 4), dtype="float32"),
            )
            lv529 = R.call_tir(
                cls.contrib_conv2d_NCHWc14,
                (lv520, lv528),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv530 = R.call_tir(
                cls.negative2,
                (stage3_unit3_bn2_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv531 = R.call_tir(
                cls.multiply7,
                (lv530, lv523),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv532 = R.call_tir(
                cls.add9,
                (lv531, stage3_unit3_bn2_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv533 = R.call_tir(
                cls.expand_dims4,
                (lv532,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv534 = R.call_tir(
                cls.expand_dims5,
                (lv533,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv535 = R.call_tir(
                cls.layout_transform6,
                (lv534,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv536 = R.call_tir(
                cls.add18,
                (lv529, lv535),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv537 = R.call_tir(
                cls.relu5,
                (lv536,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv538 = R.call_tir(
                cls.add8,
                (
                    stage3_unit3_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv539 = R.call_tir(
                cls.rsqrt2, (lv538,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv540 = R.call_tir(
                cls.multiply7,
                (lv539, stage3_unit3_bn3_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv541 = R.call_tir(
                cls.expand_dims4,
                (lv540,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv542 = R.call_tir(
                cls.squeeze3, (lv541,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv543 = R.call_tir(
                cls.expand_dims11,
                (lv542,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv544 = R.call_tir(
                cls.multiply17,
                (stage3_unit3_conv2_weight, lv543),
                out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
            )
            lv545 = R.call_tir(
                cls.layout_transform16,
                (lv544,),
                out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
            )
            lv546 = R.call_tir(
                cls.contrib_conv2d_NCHWc11,
                (lv537, lv545),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv547 = R.call_tir(
                cls.negative2,
                (stage3_unit3_bn3_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv548 = R.call_tir(
                cls.multiply7,
                (lv547, lv540),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv549 = R.call_tir(
                cls.add9,
                (lv548, stage3_unit3_bn3_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv550 = R.call_tir(
                cls.expand_dims4,
                (lv549,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv551 = R.call_tir(
                cls.expand_dims5,
                (lv550,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv552 = R.call_tir(
                cls.layout_transform6,
                (lv551,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv553 = R.call_tir(
                cls.add18,
                (lv546, lv552),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv554 = R.call_tir(
                cls.relu5,
                (lv553,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv555 = R.call_tir(
                cls.layout_transform17,
                (stage3_unit3_conv3_weight,),
                out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
            )
            lv556 = R.call_tir(
                cls.contrib_conv2d_NCHWc12,
                (lv554, lv555),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv557 = R.call_tir(
                cls.add19,
                (lv556, lv505),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv558 = R.call_tir(
                cls.add20,
                (
                    stage3_unit4_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv559 = R.call_tir(
                cls.rsqrt5, (lv558,), out_sinfo=R.Tensor((1024,), dtype="float32")
            )
            lv560 = R.call_tir(
                cls.multiply18,
                (lv559, stage3_unit4_bn1_gamma),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv561 = R.call_tir(
                cls.expand_dims12,
                (lv560,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv562 = R.call_tir(
                cls.expand_dims13,
                (lv561,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv563 = R.call_tir(
                cls.layout_transform19,
                (lv562,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv564 = R.call_tir(
                cls.multiply19,
                (lv557, lv563),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv565 = R.call_tir(
                cls.negative5,
                (stage3_unit4_bn1_moving_mean,),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv566 = R.call_tir(
                cls.multiply18,
                (lv565, lv560),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv567 = R.call_tir(
                cls.add21,
                (lv566, stage3_unit4_bn1_beta),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv568 = R.call_tir(
                cls.expand_dims12,
                (lv567,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv569 = R.call_tir(
                cls.expand_dims13,
                (lv568,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv570 = R.call_tir(
                cls.layout_transform19,
                (lv569,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv571 = R.call_tir(
                cls.add22,
                (lv564, lv570),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv572 = R.call_tir(
                cls.relu6,
                (lv571,),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv573 = R.call_tir(
                cls.add8,
                (
                    stage3_unit4_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv574 = R.call_tir(
                cls.rsqrt2, (lv573,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv575 = R.call_tir(
                cls.multiply7,
                (lv574, stage3_unit4_bn2_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv576 = R.call_tir(
                cls.expand_dims4,
                (lv575,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv577 = R.call_tir(
                cls.squeeze3, (lv576,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv578 = R.call_tir(
                cls.expand_dims11,
                (lv577,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv579 = R.call_tir(
                cls.multiply20,
                (stage3_unit4_conv1_weight, lv578),
                out_sinfo=R.Tensor((256, 1024, 1, 1), dtype="float32"),
            )
            lv580 = R.call_tir(
                cls.layout_transform20,
                (lv579,),
                out_sinfo=R.Tensor((64, 256, 1, 1, 4, 4), dtype="float32"),
            )
            lv581 = R.call_tir(
                cls.contrib_conv2d_NCHWc14,
                (lv572, lv580),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv582 = R.call_tir(
                cls.negative2,
                (stage3_unit4_bn2_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv583 = R.call_tir(
                cls.multiply7,
                (lv582, lv575),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv584 = R.call_tir(
                cls.add9,
                (lv583, stage3_unit4_bn2_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv585 = R.call_tir(
                cls.expand_dims4,
                (lv584,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv586 = R.call_tir(
                cls.expand_dims5,
                (lv585,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv587 = R.call_tir(
                cls.layout_transform6,
                (lv586,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv588 = R.call_tir(
                cls.add18,
                (lv581, lv587),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv589 = R.call_tir(
                cls.relu5,
                (lv588,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv590 = R.call_tir(
                cls.add8,
                (
                    stage3_unit4_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv591 = R.call_tir(
                cls.rsqrt2, (lv590,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv592 = R.call_tir(
                cls.multiply7,
                (lv591, stage3_unit4_bn3_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv593 = R.call_tir(
                cls.expand_dims4,
                (lv592,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv594 = R.call_tir(
                cls.squeeze3, (lv593,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv595 = R.call_tir(
                cls.expand_dims11,
                (lv594,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv596 = R.call_tir(
                cls.multiply17,
                (stage3_unit4_conv2_weight, lv595),
                out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
            )
            lv597 = R.call_tir(
                cls.layout_transform16,
                (lv596,),
                out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
            )
            lv598 = R.call_tir(
                cls.contrib_conv2d_NCHWc11,
                (lv589, lv597),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv599 = R.call_tir(
                cls.negative2,
                (stage3_unit4_bn3_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv600 = R.call_tir(
                cls.multiply7,
                (lv599, lv592),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv601 = R.call_tir(
                cls.add9,
                (lv600, stage3_unit4_bn3_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv602 = R.call_tir(
                cls.expand_dims4,
                (lv601,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv603 = R.call_tir(
                cls.expand_dims5,
                (lv602,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv604 = R.call_tir(
                cls.layout_transform6,
                (lv603,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv605 = R.call_tir(
                cls.add18,
                (lv598, lv604),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv606 = R.call_tir(
                cls.relu5,
                (lv605,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv607 = R.call_tir(
                cls.layout_transform17,
                (stage3_unit4_conv3_weight,),
                out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
            )
            lv608 = R.call_tir(
                cls.contrib_conv2d_NCHWc12,
                (lv606, lv607),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv609 = R.call_tir(
                cls.add19,
                (lv608, lv557),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv610 = R.call_tir(
                cls.add20,
                (
                    stage3_unit5_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv611 = R.call_tir(
                cls.rsqrt5, (lv610,), out_sinfo=R.Tensor((1024,), dtype="float32")
            )
            lv612 = R.call_tir(
                cls.multiply18,
                (lv611, stage3_unit5_bn1_gamma),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv613 = R.call_tir(
                cls.expand_dims12,
                (lv612,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv614 = R.call_tir(
                cls.expand_dims13,
                (lv613,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv615 = R.call_tir(
                cls.layout_transform19,
                (lv614,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv616 = R.call_tir(
                cls.multiply19,
                (lv609, lv615),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv617 = R.call_tir(
                cls.negative5,
                (stage3_unit5_bn1_moving_mean,),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv618 = R.call_tir(
                cls.multiply18,
                (lv617, lv612),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv619 = R.call_tir(
                cls.add21,
                (lv618, stage3_unit5_bn1_beta),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv620 = R.call_tir(
                cls.expand_dims12,
                (lv619,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv621 = R.call_tir(
                cls.expand_dims13,
                (lv620,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv622 = R.call_tir(
                cls.layout_transform19,
                (lv621,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv623 = R.call_tir(
                cls.add22,
                (lv616, lv622),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv624 = R.call_tir(
                cls.relu6,
                (lv623,),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv625 = R.call_tir(
                cls.add8,
                (
                    stage3_unit5_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv626 = R.call_tir(
                cls.rsqrt2, (lv625,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv627 = R.call_tir(
                cls.multiply7,
                (lv626, stage3_unit5_bn2_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv628 = R.call_tir(
                cls.expand_dims4,
                (lv627,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv629 = R.call_tir(
                cls.squeeze3, (lv628,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv630 = R.call_tir(
                cls.expand_dims11,
                (lv629,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv631 = R.call_tir(
                cls.multiply20,
                (stage3_unit5_conv1_weight, lv630),
                out_sinfo=R.Tensor((256, 1024, 1, 1), dtype="float32"),
            )
            lv632 = R.call_tir(
                cls.layout_transform20,
                (lv631,),
                out_sinfo=R.Tensor((64, 256, 1, 1, 4, 4), dtype="float32"),
            )
            lv633 = R.call_tir(
                cls.contrib_conv2d_NCHWc14,
                (lv624, lv632),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv634 = R.call_tir(
                cls.negative2,
                (stage3_unit5_bn2_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv635 = R.call_tir(
                cls.multiply7,
                (lv634, lv627),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv636 = R.call_tir(
                cls.add9,
                (lv635, stage3_unit5_bn2_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv637 = R.call_tir(
                cls.expand_dims4,
                (lv636,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv638 = R.call_tir(
                cls.expand_dims5,
                (lv637,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv639 = R.call_tir(
                cls.layout_transform6,
                (lv638,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv640 = R.call_tir(
                cls.add18,
                (lv633, lv639),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv641 = R.call_tir(
                cls.relu5,
                (lv640,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv642 = R.call_tir(
                cls.add8,
                (
                    stage3_unit5_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv643 = R.call_tir(
                cls.rsqrt2, (lv642,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv644 = R.call_tir(
                cls.multiply7,
                (lv643, stage3_unit5_bn3_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv645 = R.call_tir(
                cls.expand_dims4,
                (lv644,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv646 = R.call_tir(
                cls.squeeze3, (lv645,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv647 = R.call_tir(
                cls.expand_dims11,
                (lv646,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv648 = R.call_tir(
                cls.multiply17,
                (stage3_unit5_conv2_weight, lv647),
                out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
            )
            lv649 = R.call_tir(
                cls.layout_transform16,
                (lv648,),
                out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
            )
            lv650 = R.call_tir(
                cls.contrib_conv2d_NCHWc11,
                (lv641, lv649),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv651 = R.call_tir(
                cls.negative2,
                (stage3_unit5_bn3_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv652 = R.call_tir(
                cls.multiply7,
                (lv651, lv644),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv653 = R.call_tir(
                cls.add9,
                (lv652, stage3_unit5_bn3_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv654 = R.call_tir(
                cls.expand_dims4,
                (lv653,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv655 = R.call_tir(
                cls.expand_dims5,
                (lv654,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv656 = R.call_tir(
                cls.layout_transform6,
                (lv655,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv657 = R.call_tir(
                cls.add18,
                (lv650, lv656),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv658 = R.call_tir(
                cls.relu5,
                (lv657,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv659 = R.call_tir(
                cls.layout_transform17,
                (stage3_unit5_conv3_weight,),
                out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
            )
            lv660 = R.call_tir(
                cls.contrib_conv2d_NCHWc12,
                (lv658, lv659),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv661 = R.call_tir(
                cls.add19,
                (lv660, lv609),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv662 = R.call_tir(
                cls.add20,
                (
                    stage3_unit6_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv663 = R.call_tir(
                cls.rsqrt5, (lv662,), out_sinfo=R.Tensor((1024,), dtype="float32")
            )
            lv664 = R.call_tir(
                cls.multiply18,
                (lv663, stage3_unit6_bn1_gamma),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv665 = R.call_tir(
                cls.expand_dims12,
                (lv664,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv666 = R.call_tir(
                cls.expand_dims13,
                (lv665,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv667 = R.call_tir(
                cls.layout_transform19,
                (lv666,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv668 = R.call_tir(
                cls.multiply19,
                (lv661, lv667),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv669 = R.call_tir(
                cls.negative5,
                (stage3_unit6_bn1_moving_mean,),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv670 = R.call_tir(
                cls.multiply18,
                (lv669, lv664),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv671 = R.call_tir(
                cls.add21,
                (lv670, stage3_unit6_bn1_beta),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv672 = R.call_tir(
                cls.expand_dims12,
                (lv671,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv673 = R.call_tir(
                cls.expand_dims13,
                (lv672,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv674 = R.call_tir(
                cls.layout_transform19,
                (lv673,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv675 = R.call_tir(
                cls.add22,
                (lv668, lv674),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv676 = R.call_tir(
                cls.relu6,
                (lv675,),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv677 = R.call_tir(
                cls.add8,
                (
                    stage3_unit6_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv678 = R.call_tir(
                cls.rsqrt2, (lv677,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv679 = R.call_tir(
                cls.multiply7,
                (lv678, stage3_unit6_bn2_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv680 = R.call_tir(
                cls.expand_dims4,
                (lv679,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv681 = R.call_tir(
                cls.squeeze3, (lv680,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv682 = R.call_tir(
                cls.expand_dims11,
                (lv681,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv683 = R.call_tir(
                cls.multiply20,
                (stage3_unit6_conv1_weight, lv682),
                out_sinfo=R.Tensor((256, 1024, 1, 1), dtype="float32"),
            )
            lv684 = R.call_tir(
                cls.layout_transform20,
                (lv683,),
                out_sinfo=R.Tensor((64, 256, 1, 1, 4, 4), dtype="float32"),
            )
            lv685 = R.call_tir(
                cls.contrib_conv2d_NCHWc14,
                (lv676, lv684),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv686 = R.call_tir(
                cls.negative2,
                (stage3_unit6_bn2_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv687 = R.call_tir(
                cls.multiply7,
                (lv686, lv679),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv688 = R.call_tir(
                cls.add9,
                (lv687, stage3_unit6_bn2_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv689 = R.call_tir(
                cls.expand_dims4,
                (lv688,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv690 = R.call_tir(
                cls.expand_dims5,
                (lv689,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv691 = R.call_tir(
                cls.layout_transform6,
                (lv690,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv692 = R.call_tir(
                cls.add18,
                (lv685, lv691),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv693 = R.call_tir(
                cls.relu5,
                (lv692,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv694 = R.call_tir(
                cls.add8,
                (
                    stage3_unit6_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv695 = R.call_tir(
                cls.rsqrt2, (lv694,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv696 = R.call_tir(
                cls.multiply7,
                (lv695, stage3_unit6_bn3_gamma),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv697 = R.call_tir(
                cls.expand_dims4,
                (lv696,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv698 = R.call_tir(
                cls.squeeze3, (lv697,), out_sinfo=R.Tensor((256,), dtype="float32")
            )
            lv699 = R.call_tir(
                cls.expand_dims11,
                (lv698,),
                out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
            )
            lv700 = R.call_tir(
                cls.multiply17,
                (stage3_unit6_conv2_weight, lv699),
                out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
            )
            lv701 = R.call_tir(
                cls.layout_transform16,
                (lv700,),
                out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
            )
            lv702 = R.call_tir(
                cls.contrib_conv2d_NCHWc11,
                (lv693, lv701),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv703 = R.call_tir(
                cls.negative2,
                (stage3_unit6_bn3_moving_mean,),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv704 = R.call_tir(
                cls.multiply7,
                (lv703, lv696),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv705 = R.call_tir(
                cls.add9,
                (lv704, stage3_unit6_bn3_beta),
                out_sinfo=R.Tensor((256,), dtype="float32"),
            )
            lv706 = R.call_tir(
                cls.expand_dims4,
                (lv705,),
                out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
            )
            lv707 = R.call_tir(
                cls.expand_dims5,
                (lv706,),
                out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
            )
            lv708 = R.call_tir(
                cls.layout_transform6,
                (lv707,),
                out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
            )
            lv709 = R.call_tir(
                cls.add18,
                (lv702, lv708),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv710 = R.call_tir(
                cls.relu5,
                (lv709,),
                out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
            )
            lv711 = R.call_tir(
                cls.layout_transform17,
                (stage3_unit6_conv3_weight,),
                out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
            )
            lv712 = R.call_tir(
                cls.contrib_conv2d_NCHWc12,
                (lv710, lv711),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv713 = R.call_tir(
                cls.add19,
                (lv712, lv661),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv714 = R.call_tir(
                cls.add20,
                (
                    stage4_unit1_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv715 = R.call_tir(
                cls.rsqrt5, (lv714,), out_sinfo=R.Tensor((1024,), dtype="float32")
            )
            lv716 = R.call_tir(
                cls.multiply18,
                (lv715, stage4_unit1_bn1_gamma),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv717 = R.call_tir(
                cls.expand_dims12,
                (lv716,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv718 = R.call_tir(
                cls.expand_dims13,
                (lv717,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv719 = R.call_tir(
                cls.layout_transform19,
                (lv718,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv720 = R.call_tir(
                cls.multiply19,
                (lv713, lv719),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv721 = R.call_tir(
                cls.negative5,
                (stage4_unit1_bn1_moving_mean,),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv722 = R.call_tir(
                cls.multiply18,
                (lv721, lv716),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv723 = R.call_tir(
                cls.add21,
                (lv722, stage4_unit1_bn1_beta),
                out_sinfo=R.Tensor((1024,), dtype="float32"),
            )
            lv724 = R.call_tir(
                cls.expand_dims12,
                (lv723,),
                out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
            )
            lv725 = R.call_tir(
                cls.expand_dims13,
                (lv724,),
                out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
            )
            lv726 = R.call_tir(
                cls.layout_transform19,
                (lv725,),
                out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
            )
            lv727 = R.call_tir(
                cls.add22,
                (lv720, lv726),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv728 = R.call_tir(
                cls.relu6,
                (lv727,),
                out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
            )
            lv729 = R.call_tir(
                cls.add15,
                (
                    stage4_unit1_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv730 = R.call_tir(
                cls.rsqrt4, (lv729,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv731 = R.call_tir(
                cls.multiply13,
                (lv730, stage4_unit1_bn2_gamma),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv732 = R.call_tir(
                cls.expand_dims9,
                (lv731,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv733 = R.call_tir(
                cls.squeeze4, (lv732,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv734 = R.call_tir(
                cls.expand_dims14,
                (lv733,),
                out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
            )
            lv735 = R.call_tir(
                cls.multiply21,
                (stage4_unit1_conv1_weight, lv734),
                out_sinfo=R.Tensor((512, 1024, 1, 1), dtype="float32"),
            )
            lv736 = R.call_tir(
                cls.layout_transform21,
                (lv735,),
                out_sinfo=R.Tensor((128, 256, 1, 1, 4, 4), dtype="float32"),
            )
            lv737 = R.call_tir(
                cls.contrib_conv2d_NCHWc15,
                (lv728, lv736),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv738 = R.call_tir(
                cls.negative4,
                (stage4_unit1_bn2_moving_mean,),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv739 = R.call_tir(
                cls.multiply13,
                (lv738, lv731),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv740 = R.call_tir(
                cls.add16,
                (lv739, stage4_unit1_bn2_beta),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv741 = R.call_tir(
                cls.expand_dims9,
                (lv740,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv742 = R.call_tir(
                cls.expand_dims10,
                (lv741,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv743 = R.call_tir(
                cls.layout_transform13,
                (lv742,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv744 = R.call_tir(
                cls.add23,
                (lv737, lv743),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv745 = R.call_tir(
                cls.relu7,
                (lv744,),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv746 = R.call_tir(
                cls.add15,
                (
                    stage4_unit1_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv747 = R.call_tir(
                cls.rsqrt4, (lv746,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv748 = R.call_tir(
                cls.multiply13,
                (lv747, stage4_unit1_bn3_gamma),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv749 = R.call_tir(
                cls.expand_dims9,
                (lv748,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv750 = R.call_tir(
                cls.squeeze4, (lv749,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv751 = R.call_tir(
                cls.expand_dims14,
                (lv750,),
                out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
            )
            lv752 = R.call_tir(
                cls.multiply22,
                (stage4_unit1_conv2_weight, lv751),
                out_sinfo=R.Tensor((512, 512, 3, 3), dtype="float32"),
            )
            lv753 = R.call_tir(
                cls.layout_transform22,
                (lv752,),
                out_sinfo=R.Tensor((128, 128, 3, 3, 4, 4), dtype="float32"),
            )
            lv754 = R.call_tir(
                cls.contrib_conv2d_NCHWc16,
                (lv745, lv753),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv755 = R.call_tir(
                cls.negative4,
                (stage4_unit1_bn3_moving_mean,),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv756 = R.call_tir(
                cls.multiply13,
                (lv755, lv748),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv757 = R.call_tir(
                cls.add16,
                (lv756, stage4_unit1_bn3_beta),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv758 = R.call_tir(
                cls.expand_dims9,
                (lv757,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv759 = R.call_tir(
                cls.expand_dims10,
                (lv758,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv760 = R.call_tir(
                cls.layout_transform13,
                (lv759,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv761 = R.call_tir(
                cls.add23,
                (lv754, lv760),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv762 = R.call_tir(
                cls.relu7,
                (lv761,),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv763 = R.call_tir(
                cls.layout_transform23,
                (stage4_unit1_conv3_weight,),
                out_sinfo=R.Tensor((512, 128, 1, 1, 4, 4), dtype="float32"),
            )
            lv764 = R.call_tir(
                cls.contrib_conv2d_NCHWc17,
                (lv762, lv763),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv765 = R.call_tir(
                cls.layout_transform24,
                (stage4_unit1_sc_weight,),
                out_sinfo=R.Tensor((512, 256, 1, 1, 4, 4), dtype="float32"),
            )
            lv766 = R.call_tir(
                cls.contrib_conv2d_NCHWc18,
                (lv728, lv765),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv767 = R.call_tir(
                cls.add24,
                (lv764, lv766),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv768 = R.call_tir(
                cls.add25,
                (
                    stage4_unit2_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv769 = R.call_tir(
                cls.rsqrt6, (lv768,), out_sinfo=R.Tensor((2048,), dtype="float32")
            )
            lv770 = R.call_tir(
                cls.multiply23,
                (lv769, stage4_unit2_bn1_gamma),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv771 = R.call_tir(
                cls.expand_dims15,
                (lv770,),
                out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
            )
            lv772 = R.call_tir(
                cls.expand_dims16,
                (lv771,),
                out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
            )
            lv773 = R.call_tir(
                cls.layout_transform25,
                (lv772,),
                out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
            )
            lv774 = R.call_tir(
                cls.multiply24,
                (lv767, lv773),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv775 = R.call_tir(
                cls.negative6,
                (stage4_unit2_bn1_moving_mean,),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv776 = R.call_tir(
                cls.multiply23,
                (lv775, lv770),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv777 = R.call_tir(
                cls.add26,
                (lv776, stage4_unit2_bn1_beta),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv778 = R.call_tir(
                cls.expand_dims15,
                (lv777,),
                out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
            )
            lv779 = R.call_tir(
                cls.expand_dims16,
                (lv778,),
                out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
            )
            lv780 = R.call_tir(
                cls.layout_transform25,
                (lv779,),
                out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
            )
            lv781 = R.call_tir(
                cls.add27,
                (lv774, lv780),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv782 = R.call_tir(
                cls.relu8,
                (lv781,),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv783 = R.call_tir(
                cls.add15,
                (
                    stage4_unit2_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv784 = R.call_tir(
                cls.rsqrt4, (lv783,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv785 = R.call_tir(
                cls.multiply13,
                (lv784, stage4_unit2_bn2_gamma),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv786 = R.call_tir(
                cls.expand_dims9,
                (lv785,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv787 = R.call_tir(
                cls.squeeze4, (lv786,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv788 = R.call_tir(
                cls.expand_dims14,
                (lv787,),
                out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
            )
            lv789 = R.call_tir(
                cls.multiply25,
                (stage4_unit2_conv1_weight, lv788),
                out_sinfo=R.Tensor((512, 2048, 1, 1), dtype="float32"),
            )
            lv790 = R.call_tir(
                cls.layout_transform26,
                (lv789,),
                out_sinfo=R.Tensor((128, 512, 1, 1, 4, 4), dtype="float32"),
            )
            lv791 = R.call_tir(
                cls.contrib_conv2d_NCHWc19,
                (lv782, lv790),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv792 = R.call_tir(
                cls.negative4,
                (stage4_unit2_bn2_moving_mean,),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv793 = R.call_tir(
                cls.multiply13,
                (lv792, lv785),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv794 = R.call_tir(
                cls.add16,
                (lv793, stage4_unit2_bn2_beta),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv795 = R.call_tir(
                cls.expand_dims9,
                (lv794,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv796 = R.call_tir(
                cls.expand_dims10,
                (lv795,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv797 = R.call_tir(
                cls.layout_transform13,
                (lv796,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv798 = R.call_tir(
                cls.add23,
                (lv791, lv797),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv799 = R.call_tir(
                cls.relu7,
                (lv798,),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv800 = R.call_tir(
                cls.add15,
                (
                    stage4_unit2_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv801 = R.call_tir(
                cls.rsqrt4, (lv800,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv802 = R.call_tir(
                cls.multiply13,
                (lv801, stage4_unit2_bn3_gamma),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv803 = R.call_tir(
                cls.expand_dims9,
                (lv802,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv804 = R.call_tir(
                cls.squeeze4, (lv803,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv805 = R.call_tir(
                cls.expand_dims14,
                (lv804,),
                out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
            )
            lv806 = R.call_tir(
                cls.multiply22,
                (stage4_unit2_conv2_weight, lv805),
                out_sinfo=R.Tensor((512, 512, 3, 3), dtype="float32"),
            )
            lv807 = R.call_tir(
                cls.layout_transform22,
                (lv806,),
                out_sinfo=R.Tensor((128, 128, 3, 3, 4, 4), dtype="float32"),
            )
            lv808 = R.call_tir(
                cls.contrib_conv2d_NCHWc16,
                (lv799, lv807),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv809 = R.call_tir(
                cls.negative4,
                (stage4_unit2_bn3_moving_mean,),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv810 = R.call_tir(
                cls.multiply13,
                (lv809, lv802),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv811 = R.call_tir(
                cls.add16,
                (lv810, stage4_unit2_bn3_beta),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv812 = R.call_tir(
                cls.expand_dims9,
                (lv811,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv813 = R.call_tir(
                cls.expand_dims10,
                (lv812,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv814 = R.call_tir(
                cls.layout_transform13,
                (lv813,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv815 = R.call_tir(
                cls.add23,
                (lv808, lv814),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv816 = R.call_tir(
                cls.relu7,
                (lv815,),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv817 = R.call_tir(
                cls.layout_transform23,
                (stage4_unit2_conv3_weight,),
                out_sinfo=R.Tensor((512, 128, 1, 1, 4, 4), dtype="float32"),
            )
            lv818 = R.call_tir(
                cls.contrib_conv2d_NCHWc17,
                (lv816, lv817),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv819 = R.call_tir(
                cls.add24,
                (lv818, lv767),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv820 = R.call_tir(
                cls.add25,
                (
                    stage4_unit3_bn1_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv821 = R.call_tir(
                cls.rsqrt6, (lv820,), out_sinfo=R.Tensor((2048,), dtype="float32")
            )
            lv822 = R.call_tir(
                cls.multiply23,
                (lv821, stage4_unit3_bn1_gamma),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv823 = R.call_tir(
                cls.expand_dims15,
                (lv822,),
                out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
            )
            lv824 = R.call_tir(
                cls.expand_dims16,
                (lv823,),
                out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
            )
            lv825 = R.call_tir(
                cls.layout_transform25,
                (lv824,),
                out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
            )
            lv826 = R.call_tir(
                cls.multiply24,
                (lv819, lv825),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv827 = R.call_tir(
                cls.negative6,
                (stage4_unit3_bn1_moving_mean,),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv828 = R.call_tir(
                cls.multiply23,
                (lv827, lv822),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv829 = R.call_tir(
                cls.add26,
                (lv828, stage4_unit3_bn1_beta),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv830 = R.call_tir(
                cls.expand_dims15,
                (lv829,),
                out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
            )
            lv831 = R.call_tir(
                cls.expand_dims16,
                (lv830,),
                out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
            )
            lv832 = R.call_tir(
                cls.layout_transform25,
                (lv831,),
                out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
            )
            lv833 = R.call_tir(
                cls.add27,
                (lv826, lv832),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv834 = R.call_tir(
                cls.relu8,
                (lv833,),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv835 = R.call_tir(
                cls.add15,
                (
                    stage4_unit3_bn2_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv836 = R.call_tir(
                cls.rsqrt4, (lv835,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv837 = R.call_tir(
                cls.multiply13,
                (lv836, stage4_unit3_bn2_gamma),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv838 = R.call_tir(
                cls.expand_dims9,
                (lv837,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv839 = R.call_tir(
                cls.squeeze4, (lv838,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv840 = R.call_tir(
                cls.expand_dims14,
                (lv839,),
                out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
            )
            lv841 = R.call_tir(
                cls.multiply25,
                (stage4_unit3_conv1_weight, lv840),
                out_sinfo=R.Tensor((512, 2048, 1, 1), dtype="float32"),
            )
            lv842 = R.call_tir(
                cls.layout_transform26,
                (lv841,),
                out_sinfo=R.Tensor((128, 512, 1, 1, 4, 4), dtype="float32"),
            )
            lv843 = R.call_tir(
                cls.contrib_conv2d_NCHWc19,
                (lv834, lv842),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv844 = R.call_tir(
                cls.negative4,
                (stage4_unit3_bn2_moving_mean,),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv845 = R.call_tir(
                cls.multiply13,
                (lv844, lv837),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv846 = R.call_tir(
                cls.add16,
                (lv845, stage4_unit3_bn2_beta),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv847 = R.call_tir(
                cls.expand_dims9,
                (lv846,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv848 = R.call_tir(
                cls.expand_dims10,
                (lv847,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv849 = R.call_tir(
                cls.layout_transform13,
                (lv848,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv850 = R.call_tir(
                cls.add23,
                (lv843, lv849),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv851 = R.call_tir(
                cls.relu7,
                (lv850,),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv852 = R.call_tir(
                cls.add15,
                (
                    stage4_unit3_bn3_moving_var,
                    R.const(1.9999999494757503e-05, "float32"),
                ),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv853 = R.call_tir(
                cls.rsqrt4, (lv852,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv854 = R.call_tir(
                cls.multiply13,
                (lv853, stage4_unit3_bn3_gamma),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv855 = R.call_tir(
                cls.expand_dims9,
                (lv854,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv856 = R.call_tir(
                cls.squeeze4, (lv855,), out_sinfo=R.Tensor((512,), dtype="float32")
            )
            lv857 = R.call_tir(
                cls.expand_dims14,
                (lv856,),
                out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
            )
            lv858 = R.call_tir(
                cls.multiply22,
                (stage4_unit3_conv2_weight, lv857),
                out_sinfo=R.Tensor((512, 512, 3, 3), dtype="float32"),
            )
            lv859 = R.call_tir(
                cls.layout_transform22,
                (lv858,),
                out_sinfo=R.Tensor((128, 128, 3, 3, 4, 4), dtype="float32"),
            )
            lv860 = R.call_tir(
                cls.contrib_conv2d_NCHWc16,
                (lv851, lv859),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv861 = R.call_tir(
                cls.negative4,
                (stage4_unit3_bn3_moving_mean,),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv862 = R.call_tir(
                cls.multiply13,
                (lv861, lv854),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv863 = R.call_tir(
                cls.add16,
                (lv862, stage4_unit3_bn3_beta),
                out_sinfo=R.Tensor((512,), dtype="float32"),
            )
            lv864 = R.call_tir(
                cls.expand_dims9,
                (lv863,),
                out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
            )
            lv865 = R.call_tir(
                cls.expand_dims10,
                (lv864,),
                out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
            )
            lv866 = R.call_tir(
                cls.layout_transform13,
                (lv865,),
                out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
            )
            lv867 = R.call_tir(
                cls.add23,
                (lv860, lv866),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv868 = R.call_tir(
                cls.relu7,
                (lv867,),
                out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
            )
            lv869 = R.call_tir(
                cls.layout_transform23,
                (stage4_unit3_conv3_weight,),
                out_sinfo=R.Tensor((512, 128, 1, 1, 4, 4), dtype="float32"),
            )
            lv870 = R.call_tir(
                cls.contrib_conv2d_NCHWc17,
                (lv868, lv869),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv871 = R.call_tir(
                cls.add24,
                (lv870, lv819),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv872 = R.call_tir(
                cls.add25,
                (bn1_moving_var, R.const(1.9999999494757503e-05, "float32")),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv873 = R.call_tir(
                cls.rsqrt6, (lv872,), out_sinfo=R.Tensor((2048,), dtype="float32")
            )
            lv874 = R.call_tir(
                cls.multiply23,
                (lv873, bn1_gamma),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv875 = R.call_tir(
                cls.expand_dims15,
                (lv874,),
                out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
            )
            lv876 = R.call_tir(
                cls.expand_dims16,
                (lv875,),
                out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
            )
            lv877 = R.call_tir(
                cls.layout_transform25,
                (lv876,),
                out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
            )
            lv878 = R.call_tir(
                cls.multiply24,
                (lv871, lv877),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv879 = R.call_tir(
                cls.negative6,
                (bn1_moving_mean,),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv880 = R.call_tir(
                cls.multiply23,
                (lv879, lv874),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv881 = R.call_tir(
                cls.add26,
                (lv880, bn1_beta),
                out_sinfo=R.Tensor((2048,), dtype="float32"),
            )
            lv882 = R.call_tir(
                cls.expand_dims15,
                (lv881,),
                out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
            )
            lv883 = R.call_tir(
                cls.expand_dims16,
                (lv882,),
                out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
            )
            lv884 = R.call_tir(
                cls.layout_transform25,
                (lv883,),
                out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
            )
            lv885 = R.call_tir(
                cls.add27,
                (lv878, lv884),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv886 = R.call_tir(
                cls.relu8,
                (lv885,),
                out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
            )
            lv887 = R.call_tir(
                cls.global_avg_pool2d,
                (lv886,),
                out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
            )
            lv888 = R.call_tir(
                cls.layout_transform27,
                (lv887,),
                out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
            )
            lv889 = R.call_tir(
                cls.batch_flatten,
                (lv888,),
                out_sinfo=R.Tensor((1, 2048), dtype="float32"),
            )
            lv890 = R.call_tir(
                cls.dense,
                (lv889, fc1_weight),
                out_sinfo=R.Tensor((1, 1000), dtype="float32"),
            )
            lv891 = R.call_tir(
                cls.expand_dims17,
                (fc1_bias,),
                out_sinfo=R.Tensor((1, 1000), dtype="float32"),
            )
            lv892 = R.call_tir(
                cls.add28,
                (lv890, lv891),
                out_sinfo=R.Tensor((1, 1000), dtype="float32"),
            )
            lv893 = R.call_tir(
                cls.softmax, (lv892,), out_sinfo=R.Tensor((1, 1000), dtype="float32")
            )
            gv: R.Tensor((1, 1000), dtype="float32") = lv893
            R.output(gv)
        return gv
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.