# from tvm.script import ir as I
# from tvm.script import tir as T
# from tvm.script import relax as R
@I.ir_module
class Module:
@T.prim_func
def add(
    A: T.Buffer((T.int64(3),), "float32"),
    B: T.Buffer((), "float32"),
    T_add: T.Buffer((T.int64(3),), "float32"),
):
    # Broadcast add: T_add[i] = A[i] + B, where B is a 0-d (scalar) buffer,
    # over a length-3 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(3)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(3), ax0)
            T.reads(A[v_ax0], B[()])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[()]
@T.prim_func
def add1(
    A: T.Buffer((T.int64(3),), "float32"),
    B: T.Buffer((T.int64(3),), "float32"),
    T_add: T.Buffer((T.int64(3),), "float32"),
):
    # Elementwise add: T_add[i] = A[i] + B[i] over a length-3 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(3)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(3), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[v_ax0]
@T.prim_func
def add10(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    # NCHWc add with B broadcast over the two spatial axes (H, W are size 1 in B).
    # Typical per-channel bias add in the NCHW4c layout: 64 chunks x 4 lanes.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
@T.prim_func
def add11(
    A: T.Buffer((T.int64(128),), "float32"),
    B: T.Buffer((), "float32"),
    T_add: T.Buffer((T.int64(128),), "float32"),
):
    # Broadcast add of a 0-d scalar B onto a length-128 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(128)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(128), ax0)
            T.reads(A[v_ax0], B[()])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[()]
@T.prim_func
def add12(
    A: T.Buffer((T.int64(128),), "float32"),
    B: T.Buffer((T.int64(128),), "float32"),
    T_add: T.Buffer((T.int64(128),), "float32"),
):
    # Elementwise add: T_add[i] = A[i] + B[i] over a length-128 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(128)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(128), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[v_ax0]
@T.prim_func
def add13(
    A: T.Buffer(
        (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(32), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    # NCHWc add with B broadcast over the spatial axes (per-channel bias add),
    # 32 chunks x 4 lanes on a 28x28 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
@T.prim_func
def add14(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    # Full elementwise add of two same-shape NCHWc tensors (e.g. a residual
    # shortcut add), 128 chunks x 4 lanes on a 28x28 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
            )
@T.prim_func
def add15(
    A: T.Buffer((T.int64(512),), "float32"),
    B: T.Buffer((), "float32"),
    T_add: T.Buffer((T.int64(512),), "float32"),
):
    # Broadcast add of a 0-d scalar B onto a length-512 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(512)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(512), ax0)
            T.reads(A[v_ax0], B[()])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[()]
@T.prim_func
def add16(
    A: T.Buffer((T.int64(512),), "float32"),
    B: T.Buffer((T.int64(512),), "float32"),
    T_add: T.Buffer((T.int64(512),), "float32"),
):
    # Elementwise add: T_add[i] = A[i] + B[i] over a length-512 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(512)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(512), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[v_ax0]
@T.prim_func
def add17(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    # NCHWc add with B broadcast over the spatial axes (per-channel bias add),
    # 128 chunks x 4 lanes on a 28x28 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
@T.prim_func
def add18(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    # NCHWc add with B broadcast over the spatial axes (per-channel bias add),
    # 64 chunks x 4 lanes on a 14x14 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
@T.prim_func
def add19(
    A: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    # Full elementwise add of two same-shape NCHWc tensors (e.g. a residual
    # shortcut add), 256 chunks x 4 lanes on a 14x14 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
            )
@T.prim_func
def add2(
    A: T.Buffer((T.int64(1), T.int64(3), T.int64(224), T.int64(224)), "float32"),
    B: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
    T_add: T.Buffer(
        (T.int64(1), T.int64(3), T.int64(224), T.int64(224)), "float32"
    ),
):
    # NCHW add with B of shape (3, 1, 1) broadcast over batch and both
    # spatial axes — B[c] is added to every pixel of channel c.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(1), T.int64(3), T.int64(224), T.int64(224)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax1, T.int64(0), T.int64(0)])
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3] + B[v_ax1, T.int64(0), T.int64(0)]
            )
@T.prim_func
def add20(
    A: T.Buffer((T.int64(1024),), "float32"),
    B: T.Buffer((), "float32"),
    T_add: T.Buffer((T.int64(1024),), "float32"),
):
    # Broadcast add of a 0-d scalar B onto a length-1024 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(1024)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(1024), ax0)
            T.reads(A[v_ax0], B[()])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[()]
@T.prim_func
def add21(
    A: T.Buffer((T.int64(1024),), "float32"),
    B: T.Buffer((T.int64(1024),), "float32"),
    T_add: T.Buffer((T.int64(1024),), "float32"),
):
    # Elementwise add: T_add[i] = A[i] + B[i] over a length-1024 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(1024)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(1024), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[v_ax0]
@T.prim_func
def add22(
    A: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    # NCHWc add with B broadcast over the spatial axes (per-channel bias add),
    # 256 chunks x 4 lanes on a 14x14 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
@T.prim_func
def add23(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    # NCHWc add with B broadcast over the spatial axes (per-channel bias add),
    # 128 chunks x 4 lanes on a 7x7 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
@T.prim_func
def add24(
    A: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    # Full elementwise add of two same-shape NCHWc tensors (e.g. a residual
    # shortcut add), 512 chunks x 4 lanes on a 7x7 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
            )
@T.prim_func
def add25(
    A: T.Buffer((T.int64(2048),), "float32"),
    B: T.Buffer((), "float32"),
    T_add: T.Buffer((T.int64(2048),), "float32"),
):
    # Broadcast add of a 0-d scalar B onto a length-2048 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(2048)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(2048), ax0)
            T.reads(A[v_ax0], B[()])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[()]
@T.prim_func
def add26(
    A: T.Buffer((T.int64(2048),), "float32"),
    B: T.Buffer((T.int64(2048),), "float32"),
    T_add: T.Buffer((T.int64(2048),), "float32"),
):
    # Elementwise add: T_add[i] = A[i] + B[i] over a length-2048 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(2048)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(2048), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[v_ax0]
@T.prim_func
def add27(
    A: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    # NCHWc add with B broadcast over the spatial axes (per-channel bias add),
    # 512 chunks x 4 lanes on a 7x7 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
@T.prim_func
def add28(
    A: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
    B: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
    T_add: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
):
    # Elementwise add of two (1, 1000) float32 tensors (e.g. logits + bias
    # for a 1000-class classifier head — presumably; verify against caller).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
        with T.block("T_add"):
            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
            T.writes(T_add[v_ax0, v_ax1])
            T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]
@T.prim_func
def add3(
    A: T.Buffer((T.int64(64),), "float32"),
    B: T.Buffer((), "float32"),
    T_add: T.Buffer((T.int64(64),), "float32"),
):
    # Broadcast add of a 0-d scalar B onto a length-64 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(64)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(64), ax0)
            T.reads(A[v_ax0], B[()])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[()]
@T.prim_func
def add4(
    A: T.Buffer((T.int64(64),), "float32"),
    B: T.Buffer((T.int64(64),), "float32"),
    T_add: T.Buffer((T.int64(64),), "float32"),
):
    # Elementwise add: T_add[i] = A[i] + B[i] over a length-64 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(64)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(64), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[v_ax0]
@T.prim_func
def add5(
    A: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
    ),
):
    # NCHWc add with B broadcast over the spatial axes (per-channel bias add),
    # 16 chunks x 4 lanes on a 112x112 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
@T.prim_func
def add6(
    A: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    # NCHWc add with B broadcast over the spatial axes (per-channel bias add),
    # 16 chunks x 4 lanes on a 56x56 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
@T.prim_func
def add7(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    T_add: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    # Full elementwise add of two same-shape NCHWc tensors (e.g. a residual
    # shortcut add), 64 chunks x 4 lanes on a 56x56 feature map.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)
    ):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
            )
            T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_add[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + B[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
            )
@T.prim_func
def add8(
    A: T.Buffer((T.int64(256),), "float32"),
    B: T.Buffer((), "float32"),
    T_add: T.Buffer((T.int64(256),), "float32"),
):
    # Broadcast add of a 0-d scalar B onto a length-256 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(256)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(256), ax0)
            T.reads(A[v_ax0], B[()])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[()]
@T.prim_func
def add9(
    A: T.Buffer((T.int64(256),), "float32"),
    B: T.Buffer((T.int64(256),), "float32"),
    T_add: T.Buffer((T.int64(256),), "float32"),
):
    # Elementwise add: T_add[i] = A[i] + B[i] over a length-256 float32 vector.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(256)):
        with T.block("T_add"):
            v_ax0 = T.axis.spatial(T.int64(256), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_add[v_ax0])
            T_add[v_ax0] = A[v_ax0] + B[v_ax0]
@T.prim_func
def batch_flatten(
    A: T.Buffer((T.int64(1), T.int64(2048), T.int64(1), T.int64(1)), "float32"),
    tensor: T.Buffer((T.int64(1), T.int64(2048)), "float32"),
):
    # Flatten a (1, 2048, 1, 1) tensor to (1, 2048) — a pure layout copy.
    # The `v_ax1 % 2048` is a no-op here (v_ax1 < 2048); it is the generic
    # flatten index formula left unsimplified by the printer.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1 in T.grid(T.int64(1), T.int64(2048)):
        with T.block("tensor"):
            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v_ax0, v_ax1 % T.int64(2048), T.int64(0), T.int64(0)])
            T.writes(tensor[v_ax0, v_ax1])
            tensor[v_ax0, v_ax1] = A[
                v_ax0, v_ax1 % T.int64(2048), T.int64(0), T.int64(0)
            ]
@T.prim_func
def contrib_conv2d_NCHWc(
    A: T.Buffer(
        (T.int64(1), T.int64(1), T.int64(224), T.int64(224), T.int64(3)), "float32"
    ),
    B: T.Buffer(
        (T.int64(16), T.int64(1), T.int64(7), T.int64(7), T.int64(3), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
    ),
):
    # 7x7 conv2d in NCHWc layout, stride 2, zero padding 3:
    # input (1, 1, 224, 224, 3c) * weight (16oc, 1ic, 7, 7, 3ic_b, 4oc_b)
    # -> output (1, 16, 112, 112, 4c). Output index arithmetic below implies
    # stride 2 (v_oh * 2 + v_kh) and padded size 230 = 224 + 2*3.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    # Padded input: 224x224 spatial extended to 230x230 with a 3-pixel
    # zero border on each side.
    data_pad = T.alloc_buffer(
        (T.int64(1), T.int64(1), T.int64(230), T.int64(230), T.int64(3))
    )
    for i0, i1, i2, i3, i4 in T.grid(
        T.int64(1), T.int64(1), T.int64(230), T.int64(230), T.int64(3)
    ):
        with T.block("data_pad"):
            v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap(
                "SSSSS", [i0, i1, i2, i3, i4]
            )
            T.reads(A[v_i0, v_i1, v_i2 - T.int64(3), v_i3 - T.int64(3), v_i4])
            T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
            # Copy from A shifted by the pad offset; zero-fill outside
            # the valid [3, 227) range on both spatial axes.
            data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(
                T.int64(3) <= v_i2
                and v_i2 < T.int64(227)
                and T.int64(3) <= v_i3
                and v_i3 < T.int64(227),
                A[v_i0, v_i1, v_i2 - T.int64(3), v_i3 - T.int64(3), v_i4],
                T.float32(0),
            )
    # Main reduction: ic (fused input channel, here 3), kh, kw are
    # reduction ("R") axes; the rest are spatial ("S").
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(16),
        T.int64(112),
        T.int64(112),
        T.int64(4),
        T.int64(3),
        T.int64(7),
        T.int64(7),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                data_pad[
                    v_n,
                    v_ic // T.int64(3),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(3),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(3),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(3),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            with T.init():
                # Zero the accumulator before the first reduction step.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            # Multiply-accumulate; v_ic is decomposed into (chunk, lane)
            # via // 3 and % 3 to index the packed layouts.
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + data_pad[
                    v_n,
                    v_ic // T.int64(3),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(3),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(3),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(3),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc1(
    A: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(16), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    # 1x1 conv2d in NCHWc layout, stride 1, no padding:
    # (1, 16, 56, 56, 4) * (16, 16, 1, 1, 4, 4) -> (1, 16, 56, 56, 4).
    # Fused input-channel axis ic runs over 64 = 16 chunks * 4 lanes.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(16),
        T.int64(56),
        T.int64(56),
        T.int64(4),
        T.int64(64),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            with T.init():
                # Zero the accumulator before the first reduction step.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            # Multiply-accumulate over ic (and the degenerate 1x1 kernel).
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc10(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(64), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    # 1x1 conv2d in NCHWc layout, stride 2 (downsampling 28x28 -> 14x14),
    # no padding. ic runs over 512 = 128 chunks * 4 lanes.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(64),
        T.int64(14),
        T.int64(14),
        T.int64(4),
        T.int64(512),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            with T.init():
                # Zero the accumulator before the first reduction step.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            # Multiply-accumulate; the * 2 in the spatial index is the stride.
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc11(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(64), T.int64(64), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    # 3x3 conv2d in NCHWc layout, stride 1, zero padding 1 (same-size output):
    # (1, 64, 14, 14, 4) * (64, 64, 3, 3, 4, 4) -> (1, 64, 14, 14, 4).
    # ic runs over 256 = 64 chunks * 4 lanes.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    # Padded input: 14x14 spatial extended to 16x16 with a 1-pixel border.
    data_pad = T.alloc_buffer(
        (T.int64(1), T.int64(64), T.int64(16), T.int64(16), T.int64(4))
    )
    for i0, i1, i2, i3, i4 in T.grid(
        T.int64(1), T.int64(64), T.int64(16), T.int64(16), T.int64(4)
    ):
        with T.block("data_pad"):
            v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap(
                "SSSSS", [i0, i1, i2, i3, i4]
            )
            T.reads(A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4])
            T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
            # Copy shifted by the pad offset; zero-fill outside [1, 15).
            data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(
                T.int64(1) <= v_i2
                and v_i2 < T.int64(15)
                and T.int64(1) <= v_i3
                and v_i3 < T.int64(15),
                A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4],
                T.float32(0),
            )
    # Main reduction: ic, kh, kw are reduction ("R") axes.
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(64),
        T.int64(14),
        T.int64(14),
        T.int64(4),
        T.int64(256),
        T.int64(3),
        T.int64(3),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                data_pad[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            with T.init():
                # Zero the accumulator before the first reduction step.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            # Multiply-accumulate over the padded input and the 3x3 kernel.
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + data_pad[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc12(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(256), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    # 1x1 conv2d in NCHWc layout, stride 1, no padding (channel expansion
    # 64 -> 256 chunks). ic runs over 256 = 64 chunks * 4 lanes.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(256),
        T.int64(14),
        T.int64(14),
        T.int64(4),
        T.int64(256),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            with T.init():
                # Zero the accumulator before the first reduction step.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            # Multiply-accumulate over ic (degenerate 1x1 kernel window).
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc13(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (
            T.int64(256),
            T.int64(128),
            T.int64(1),
            T.int64(1),
            T.int64(4),
            T.int64(4),
        ),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    # 1x1 conv2d in NCHWc layout, stride 2 (28x28 -> 14x14), no padding.
    # ic runs over 512 = 128 chunks * 4 lanes.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(256),
        T.int64(14),
        T.int64(14),
        T.int64(4),
        T.int64(512),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            with T.init():
                # Zero the accumulator before the first reduction step.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            # Multiply-accumulate; the * 2 in the spatial index is the stride.
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc14(
    A: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(64), T.int64(256), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    # 1x1 conv2d in NCHWc layout, stride 1, no padding (channel reduction
    # 256 -> 64 chunks). ic runs over 1024 = 256 chunks * 4 lanes.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(64),
        T.int64(14),
        T.int64(14),
        T.int64(4),
        T.int64(1024),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            with T.init():
                # Zero the accumulator before the first reduction step.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            # Multiply-accumulate over ic (degenerate 1x1 kernel window).
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc15(
    A: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (
            T.int64(128),
            T.int64(256),
            T.int64(1),
            T.int64(1),
            T.int64(4),
            T.int64(4),
        ),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    # 1x1 conv2d in NCHWc layout, stride 2 (14x14 -> 7x7), no padding.
    # ic runs over 1024 = 256 chunks * 4 lanes.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(128),
        T.int64(7),
        T.int64(7),
        T.int64(4),
        T.int64(1024),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            with T.init():
                # Zero the accumulator before the first reduction step.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            # Multiply-accumulate; the * 2 in the spatial index is the stride.
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc16(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (
            T.int64(128),
            T.int64(128),
            T.int64(3),
            T.int64(3),
            T.int64(4),
            T.int64(4),
        ),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    # 3x3 conv2d in NCHWc layout, stride 1, zero padding 1 (same-size output):
    # (1, 128, 7, 7, 4) * (128, 128, 3, 3, 4, 4) -> (1, 128, 7, 7, 4).
    # ic runs over 512 = 128 chunks * 4 lanes.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    # Padded input: 7x7 spatial extended to 9x9 with a 1-pixel border.
    data_pad = T.alloc_buffer(
        (T.int64(1), T.int64(128), T.int64(9), T.int64(9), T.int64(4))
    )
    for i0, i1, i2, i3, i4 in T.grid(
        T.int64(1), T.int64(128), T.int64(9), T.int64(9), T.int64(4)
    ):
        with T.block("data_pad"):
            v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap(
                "SSSSS", [i0, i1, i2, i3, i4]
            )
            T.reads(A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4])
            T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
            # Copy shifted by the pad offset; zero-fill outside [1, 8).
            data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(
                T.int64(1) <= v_i2
                and v_i2 < T.int64(8)
                and T.int64(1) <= v_i3
                and v_i3 < T.int64(8),
                A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4],
                T.float32(0),
            )
    # Main reduction: ic, kh, kw are reduction ("R") axes.
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(128),
        T.int64(7),
        T.int64(7),
        T.int64(4),
        T.int64(512),
        T.int64(3),
        T.int64(3),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                data_pad[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            with T.init():
                # Zero the accumulator before the first reduction step.
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            # Multiply-accumulate over the padded input and the 3x3 kernel.
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + data_pad[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc17(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (
            T.int64(512),
            T.int64(128),
            T.int64(1),
            T.int64(1),
            T.int64(4),
            T.int64(4),
        ),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    # Pointwise (1x1) convolution in packed NCHWc layout (c = 4), stride 1,
    # no padding: (1, 128, 7, 7, 4) -> (1, 512, 7, 7, 4).
    # Reduces over the flattened input-channel axis `ic`
    # (ic // 4 = channel chunk, ic % 4 = lane); kh/kw are degenerate (extent 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(512),
        T.int64(7),
        T.int64(7),
        T.int64(4),
        T.int64(512),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc18(
    A: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (
            T.int64(512),
            T.int64(256),
            T.int64(1),
            T.int64(1),
            T.int64(4),
            T.int64(4),
        ),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    # Pointwise (1x1) convolution in packed NCHWc layout (c = 4) with stride 2
    # (input index v_oh * 2 + v_kh): (1, 256, 14, 14, 4) -> (1, 512, 7, 7, 4).
    # Reduces over the flattened input-channel axis `ic` (chunk = ic // 4, lane = ic % 4).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(512),
        T.int64(7),
        T.int64(7),
        T.int64(4),
        T.int64(1024),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc19(
    A: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (
            T.int64(128),
            T.int64(512),
            T.int64(1),
            T.int64(1),
            T.int64(4),
            T.int64(4),
        ),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    # Pointwise (1x1) convolution in packed NCHWc layout (c = 4), stride 1,
    # no padding: (1, 512, 7, 7, 4) -> (1, 128, 7, 7, 4).
    # Reduces over the flattened input-channel axis `ic` (chunk = ic // 4, lane = ic % 4).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(128),
        T.int64(7),
        T.int64(7),
        T.int64(4),
        T.int64(2048),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc2(
    A: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(16), T.int64(16), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    # 3x3 convolution in packed NCHWc layout (c = 4), stride 1, zero padding 1:
    # (1, 16, 56, 56, 4) -> (1, 16, 56, 56, 4).
    # Weights B are OIHW4i4o: (oc_chunk, ic_chunk, kh, kw, ic_block, oc_block).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    data_pad = T.alloc_buffer(
        (T.int64(1), T.int64(16), T.int64(58), T.int64(58), T.int64(4))
    )
    # Stage 1: zero-pad H and W by 1 on each side (56x56 -> 58x58).
    for i0, i1, i2, i3, i4 in T.grid(
        T.int64(1), T.int64(16), T.int64(58), T.int64(58), T.int64(4)
    ):
        with T.block("data_pad"):
            v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap(
                "SSSSS", [i0, i1, i2, i3, i4]
            )
            T.reads(A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4])
            T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
            # Interior pixels copy A shifted by the pad offset; the 1-wide border is 0.
            data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(
                T.int64(1) <= v_i2
                and v_i2 < T.int64(57)
                and T.int64(1) <= v_i3
                and v_i3 < T.int64(57),
                A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4],
                T.float32(0),
            )
    # Stage 2: multiply-accumulate over the flattened input-channel axis `ic`
    # (ic // 4 = channel chunk, ic % 4 = lane) and the 3x3 kernel window.
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(16),
        T.int64(56),
        T.int64(56),
        T.int64(4),
        T.int64(64),
        T.int64(3),
        T.int64(3),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                data_pad[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + data_pad[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc3(
    A: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(64), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    # Pointwise (1x1) convolution in packed NCHWc layout (c = 4), stride 1,
    # no padding: (1, 16, 56, 56, 4) -> (1, 64, 56, 56, 4).
    # Reduces over the flattened input-channel axis `ic` (chunk = ic // 4, lane = ic % 4).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(64),
        T.int64(56),
        T.int64(56),
        T.int64(4),
        T.int64(64),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc4(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(16), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    # Pointwise (1x1) convolution in packed NCHWc layout (c = 4), stride 1,
    # no padding: (1, 64, 56, 56, 4) -> (1, 16, 56, 56, 4).
    # Reduces over the flattened input-channel axis `ic` (chunk = ic // 4, lane = ic % 4).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(16),
        T.int64(56),
        T.int64(56),
        T.int64(4),
        T.int64(256),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc5(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(32), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    # Pointwise (1x1) convolution in packed NCHWc layout (c = 4) with stride 2
    # (input index v_oh * 2 + v_kh): (1, 64, 56, 56, 4) -> (1, 32, 28, 28, 4).
    # Reduces over the flattened input-channel axis `ic` (chunk = ic // 4, lane = ic % 4).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(32),
        T.int64(28),
        T.int64(28),
        T.int64(4),
        T.int64(256),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc6(
    A: T.Buffer(
        (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(32), T.int64(32), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    # 3x3 convolution in packed NCHWc layout (c = 4), stride 1, zero padding 1:
    # (1, 32, 28, 28, 4) -> (1, 32, 28, 28, 4).
    # Weights B are OIHW4i4o: (oc_chunk, ic_chunk, kh, kw, ic_block, oc_block).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    data_pad = T.alloc_buffer(
        (T.int64(1), T.int64(32), T.int64(30), T.int64(30), T.int64(4))
    )
    # Stage 1: zero-pad H and W by 1 on each side (28x28 -> 30x30).
    for i0, i1, i2, i3, i4 in T.grid(
        T.int64(1), T.int64(32), T.int64(30), T.int64(30), T.int64(4)
    ):
        with T.block("data_pad"):
            v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap(
                "SSSSS", [i0, i1, i2, i3, i4]
            )
            T.reads(A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4])
            T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4])
            # Interior pixels copy A shifted by the pad offset; the 1-wide border is 0.
            data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(
                T.int64(1) <= v_i2
                and v_i2 < T.int64(29)
                and T.int64(1) <= v_i3
                and v_i3 < T.int64(29),
                A[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1), v_i4],
                T.float32(0),
            )
    # Stage 2: multiply-accumulate over the flattened input-channel axis `ic`
    # (ic // 4 = channel chunk, ic % 4 = lane) and the 3x3 kernel window.
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(32),
        T.int64(28),
        T.int64(28),
        T.int64(4),
        T.int64(128),
        T.int64(3),
        T.int64(3),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                data_pad[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + data_pad[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc7(
    A: T.Buffer(
        (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(128), T.int64(32), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    # Pointwise (1x1) convolution in packed NCHWc layout (c = 4), stride 1,
    # no padding: (1, 32, 28, 28, 4) -> (1, 128, 28, 28, 4).
    # Reduces over the flattened input-channel axis `ic` (chunk = ic // 4, lane = ic % 4).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(128),
        T.int64(28),
        T.int64(28),
        T.int64(4),
        T.int64(128),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc8(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(128), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    # Pointwise (1x1) convolution in packed NCHWc layout (c = 4) with stride 2
    # (input index v_oh * 2 + v_kh): (1, 64, 56, 56, 4) -> (1, 128, 28, 28, 4).
    # Reduces over the flattened input-channel axis `ic` (chunk = ic // 4, lane = ic % 4).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(128),
        T.int64(28),
        T.int64(28),
        T.int64(4),
        T.int64(256),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh * T.int64(2) + v_kh,
                    v_ow * T.int64(2) + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def contrib_conv2d_NCHWc9(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(32), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
    conv2d_NCHWc: T.Buffer(
        (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    # Pointwise (1x1) convolution in packed NCHWc layout (c = 4), stride 1,
    # no padding: (1, 128, 28, 28, 4) -> (1, 32, 28, 28, 4).
    # Reduces over the flattened input-channel axis `ic` (chunk = ic // 4, lane = ic % 4).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for n, oc_chunk, oh, ow, oc_block, ic, kh, kw in T.grid(
        T.int64(1),
        T.int64(32),
        T.int64(28),
        T.int64(28),
        T.int64(4),
        T.int64(512),
        T.int64(1),
        T.int64(1),
    ):
        with T.block("conv2d_NCHWc"):
            (
                v_n,
                v_oc_chunk,
                v_oh,
                v_ow,
                v_oc_block,
                v_ic,
                v_kh,
                v_kw,
            ) = T.axis.remap(
                "SSSSSRRR", [n, oc_chunk, oh, ow, oc_block, ic, kh, kw]
            )
            T.reads(
                A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ],
                B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ],
            )
            T.writes(conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = T.float32(0)
            conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block] = (
                conv2d_NCHWc[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block]
                + A[
                    v_n,
                    v_ic // T.int64(4),
                    v_oh + v_kh,
                    v_ow + v_kw,
                    v_ic % T.int64(4),
                ]
                * B[
                    v_oc_chunk,
                    v_ic // T.int64(4),
                    v_kh,
                    v_kw,
                    v_ic % T.int64(4),
                    v_oc_block,
                ]
            )
@T.prim_func
def dense(
    A: T.Buffer((T.int64(1), T.int64(2048)), "float32"),
    B: T.Buffer((T.int64(1000), T.int64(2048)), "float32"),
    T_matmul_NT: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
):
    # Dense (fully connected) layer as a matmul with transposed weights ("NT"):
    # out[i, j] = sum_k A[i, k] * B[j, k], with A (1, 2048) and B (1000, 2048).
    # "layout_free_buffers": [1] marks the weight buffer B (param index 1) as
    # re-layoutable by downstream scheduling passes.
    T.func_attr({"layout_free_buffers": [1], "tir.noalias": T.bool(True)})
    # with T.block("root"):
    for i, j, k in T.grid(T.int64(1), T.int64(1000), T.int64(2048)):
        with T.block("T_matmul_NT"):
            v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
            T.reads(A[v_i, v_k], B[v_j, v_k])
            T.writes(T_matmul_NT[v_i, v_j])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                T_matmul_NT[v_i, v_j] = T.float32(0)
            T_matmul_NT[v_i, v_j] = (
                T_matmul_NT[v_i, v_j] + A[v_i, v_k] * B[v_j, v_k]
            )
@T.prim_func
def divide(
    A: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
    B: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
    T_divide: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
):
    # Elementwise division of two (3, 1, 1) tensors: T_divide = A / B.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(1), T.int64(1)):
        with T.block("T_divide"):
            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
            T.reads(A[v_ax0, v_ax1, v_ax2], B[v_ax0, v_ax1, v_ax2])
            T.writes(T_divide[v_ax0, v_ax1, v_ax2])
            T_divide[v_ax0, v_ax1, v_ax2] = (
                A[v_ax0, v_ax1, v_ax2] / B[v_ax0, v_ax1, v_ax2]
            )
@T.prim_func
def expand_dims(
    A: T.Buffer((T.int64(3),), "float32"),
    T_expand_dims: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
):
    # Append two trailing singleton axes: (3,) -> (3, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(1), T.int64(1)):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
            T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]
@T.prim_func
def expand_dims1(
    A: T.Buffer((T.int64(64),), "float32"),
    T_expand_dims: T.Buffer((T.int64(64), T.int64(1), T.int64(1)), "float32"),
):
    # Append two trailing singleton axes: (64,) -> (64, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2 in T.grid(T.int64(64), T.int64(1), T.int64(1)):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
            T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]
@T.prim_func
def expand_dims10(
    A: T.Buffer((T.int64(512), T.int64(1), T.int64(1)), "float32"),
    T_expand_dims: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(1), T.int64(1)), "float32"
    ),
):
    # Prepend a leading singleton (batch) axis: (512, 1, 1) -> (1, 512, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(1), T.int64(512), T.int64(1), T.int64(1)
    ):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax1, v_ax2, v_ax3])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
            T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]
@T.prim_func
def expand_dims11(
    A: T.Buffer((T.int64(256),), "float32"),
    T_expand_dims: T.Buffer(
        (T.int64(256), T.int64(1), T.int64(1), T.int64(1)), "float32"
    ),
):
    # Append three trailing singleton axes: (256,) -> (256, 1, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(256), T.int64(1), T.int64(1), T.int64(1)
    ):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
            T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0]
@T.prim_func
def expand_dims12(
    A: T.Buffer((T.int64(1024),), "float32"),
    T_expand_dims: T.Buffer((T.int64(1024), T.int64(1), T.int64(1)), "float32"),
):
    # Append two trailing singleton axes: (1024,) -> (1024, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2 in T.grid(T.int64(1024), T.int64(1), T.int64(1)):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
            T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]
@T.prim_func
def expand_dims13(
    A: T.Buffer((T.int64(1024), T.int64(1), T.int64(1)), "float32"),
    T_expand_dims: T.Buffer(
        (T.int64(1), T.int64(1024), T.int64(1), T.int64(1)), "float32"
    ),
):
    # Prepend a leading singleton (batch) axis: (1024, 1, 1) -> (1, 1024, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(1), T.int64(1024), T.int64(1), T.int64(1)
    ):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax1, v_ax2, v_ax3])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
            T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]
@T.prim_func
def expand_dims14(
    A: T.Buffer((T.int64(512),), "float32"),
    T_expand_dims: T.Buffer(
        (T.int64(512), T.int64(1), T.int64(1), T.int64(1)), "float32"
    ),
):
    # Append three trailing singleton axes: (512,) -> (512, 1, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(512), T.int64(1), T.int64(1), T.int64(1)
    ):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
            T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0]
@T.prim_func
def expand_dims15(
    A: T.Buffer((T.int64(2048),), "float32"),
    T_expand_dims: T.Buffer((T.int64(2048), T.int64(1), T.int64(1)), "float32"),
):
    # Append two trailing singleton axes: (2048,) -> (2048, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2 in T.grid(T.int64(2048), T.int64(1), T.int64(1)):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
            T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]
@T.prim_func
def expand_dims16(
    A: T.Buffer((T.int64(2048), T.int64(1), T.int64(1)), "float32"),
    T_expand_dims: T.Buffer(
        (T.int64(1), T.int64(2048), T.int64(1), T.int64(1)), "float32"
    ),
):
    # Prepend a leading singleton (batch) axis: (2048, 1, 1) -> (1, 2048, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(1), T.int64(2048), T.int64(1), T.int64(1)
    ):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax1, v_ax2, v_ax3])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
            T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]
@T.prim_func
def expand_dims17(
    A: T.Buffer((T.int64(1000),), "float32"),
    T_expand_dims: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
):
    # Prepend a leading singleton (batch) axis: (1000,) -> (1, 1000).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
            T.reads(A[v_ax1])
            T.writes(T_expand_dims[v_ax0, v_ax1])
            T_expand_dims[v_ax0, v_ax1] = A[v_ax1]
@T.prim_func
def expand_dims2(
    A: T.Buffer((T.int64(64),), "float32"),
    T_expand_dims: T.Buffer(
        (T.int64(64), T.int64(1), T.int64(1), T.int64(1)), "float32"
    ),
):
    # Append three trailing singleton axes: (64,) -> (64, 1, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(64), T.int64(1), T.int64(1), T.int64(1)
    ):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
            T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0]
@T.prim_func
def expand_dims3(
    A: T.Buffer((T.int64(64), T.int64(1), T.int64(1)), "float32"),
    T_expand_dims: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(1), T.int64(1)), "float32"
    ),
):
    # Prepend a leading singleton (batch) axis: (64, 1, 1) -> (1, 64, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(1), T.int64(64), T.int64(1), T.int64(1)
    ):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax1, v_ax2, v_ax3])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
            T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]
@T.prim_func
def expand_dims4(
    A: T.Buffer((T.int64(256),), "float32"),
    T_expand_dims: T.Buffer((T.int64(256), T.int64(1), T.int64(1)), "float32"),
):
    # Append two trailing singleton axes: (256,) -> (256, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2 in T.grid(T.int64(256), T.int64(1), T.int64(1)):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
            T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]
@T.prim_func
def expand_dims5(
    A: T.Buffer((T.int64(256), T.int64(1), T.int64(1)), "float32"),
    T_expand_dims: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(1), T.int64(1)), "float32"
    ),
):
    # Prepend a leading singleton (batch) axis: (256, 1, 1) -> (1, 256, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(1), T.int64(256), T.int64(1), T.int64(1)
    ):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax1, v_ax2, v_ax3])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
            T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]
@T.prim_func
def expand_dims6(
    A: T.Buffer((T.int64(128),), "float32"),
    T_expand_dims: T.Buffer((T.int64(128), T.int64(1), T.int64(1)), "float32"),
):
    # Append two trailing singleton axes: (128,) -> (128, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(1), T.int64(1)):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
            T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]
@T.prim_func
def expand_dims7(
    A: T.Buffer((T.int64(128),), "float32"),
    T_expand_dims: T.Buffer(
        (T.int64(128), T.int64(1), T.int64(1), T.int64(1)), "float32"
    ),
):
    # Append three trailing singleton axes: (128,) -> (128, 1, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(128), T.int64(1), T.int64(1), T.int64(1)
    ):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
            T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0]
@T.prim_func
def expand_dims8(
    A: T.Buffer((T.int64(128), T.int64(1), T.int64(1)), "float32"),
    T_expand_dims: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(1), T.int64(1)), "float32"
    ),
):
    # Prepend a leading singleton (batch) axis: (128, 1, 1) -> (1, 128, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(1), T.int64(128), T.int64(1), T.int64(1)
    ):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(A[v_ax1, v_ax2, v_ax3])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3])
            T_expand_dims[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax1, v_ax2, v_ax3]
@T.prim_func
def expand_dims9(
    A: T.Buffer((T.int64(512),), "float32"),
    T_expand_dims: T.Buffer((T.int64(512), T.int64(1), T.int64(1)), "float32"),
):
    # Append two trailing singleton axes: (512,) -> (512, 1, 1).
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2 in T.grid(T.int64(512), T.int64(1), T.int64(1)):
        with T.block("T_expand_dims"):
            v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
            T.reads(A[v_ax0])
            T.writes(T_expand_dims[v_ax0, v_ax1, v_ax2])
            T_expand_dims[v_ax0, v_ax1, v_ax2] = A[v_ax0]
@T.prim_func
def global_avg_pool2d(
    A: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    adaptive_pool_avg: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
):
    # Global average pooling over the 7x7 spatial extent of a packed NCHWc
    # tensor: (1, 512, 7, 7, 4) -> (1, 512, 1, 1, 4).
    # Stage 1 sums each channel's window into adaptive_pool_sum; stage 2 scales
    # by 1/49 (the constant 0.0204081632...) to produce the mean.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    adaptive_pool_sum = T.alloc_buffer(
        (T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4))
    )
    # Stage 1: reduce over the 7x7 window (rv0, rv1 are reduction axes).
    for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(
        T.int64(1),
        T.int64(512),
        T.int64(1),
        T.int64(1),
        T.int64(4),
        T.int64(7),
        T.int64(7),
    ):
        with T.block("adaptive_pool_sum"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap(
                "SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1]
            )
            T.reads(
                A[
                    v_ax0,
                    v_ax1,
                    v_ax2 * T.int64(7) + v_rv0,
                    v_ax3 * T.int64(7) + v_rv1,
                    v_ax4,
                ]
            )
            T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            # Accumulator is zero-initialized on the first reduction iteration.
            with T.init():
                adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(0)
            adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                + A[
                    v_ax0,
                    v_ax1,
                    v_ax2 * T.int64(7) + v_rv0,
                    v_ax3 * T.int64(7) + v_rv1,
                    v_ax4,
                ]
            )
    # Stage 2: divide each sum by the window size (multiply by 1/49).
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)
    ):
        with T.block("adaptive_pool_avg"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
            adaptive_pool_avg[
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4
            ] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] * T.float32(
                0.020408163265306121
            )
@T.prim_func
def layout_transform(
    A: T.Buffer((T.int64(1), T.int64(3), T.int64(224), T.int64(224)), "float32"),
    T_layout_trans: T.Buffer(
        (T.int64(1), T.int64(1), T.int64(224), T.int64(224), T.int64(3)), "float32"
    ),
):
    # Repack NCHW (1, 3, 224, 224) into NCHW3c (1, 1, 224, 224, 3): the channel
    # axis is split into 1 chunk of 3 lanes, with c = ax1 * 3 + ax4.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(1), T.int64(224), T.int64(224), T.int64(3)
    ):
        with T.block("T_layout_trans"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(A[v_ax0, v_ax1 * T.int64(3) + v_ax4, v_ax2, v_ax3])
            T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.block_attr(
                {
                    "dst_layout": "NCHW3c",
                    "input_shape": [
                        T.int64(1),
                        T.int64(3),
                        T.int64(224),
                        T.int64(224),
                    ],
                    "schedule_rule": "None",
                    "src_layout": "NCHW",
                }
            )
            # Generated bounds guard against the source shape; it is always
            # satisfied here since the packed extents divide evenly, so the
            # zero branch is dead padding.
            T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
                v_ax0 < T.int64(1)
                and v_ax1 * T.int64(3) + v_ax4 < T.int64(3)
                and v_ax2 < T.int64(224)
                and v_ax3 < T.int64(224),
                A[v_ax0, v_ax1 * T.int64(3) + v_ax4, v_ax2, v_ax3],
                T.float32(0),
            )
@T.prim_func
def layout_transform1(
    A: T.Buffer((T.int64(64), T.int64(3), T.int64(7), T.int64(7)), "float32"),
    T_layout_trans: T.Buffer(
        (T.int64(16), T.int64(1), T.int64(7), T.int64(7), T.int64(3), T.int64(4)),
        "float32",
    ),
):
    # Repack OIHW weights (64, 3, 7, 7) into OIHW3i4o
    # (16, 1, 7, 7, 3, 4): O = ax0 * 4 + ax5, I = ax1 * 3 + ax4.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
        T.int64(16), T.int64(1), T.int64(7), T.int64(7), T.int64(3), T.int64(4)
    ):
        with T.block("T_layout_trans"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
            )
            T.reads(
                A[
                    v_ax0 * T.int64(4) + v_ax5,
                    v_ax1 * T.int64(3) + v_ax4,
                    v_ax2,
                    v_ax3,
                ]
            )
            T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
            T.block_attr(
                {
                    "dst_layout": "OIHW3i4o",
                    "input_shape": [
                        T.int64(64),
                        T.int64(3),
                        T.int64(7),
                        T.int64(7),
                    ],
                    "schedule_rule": "None",
                    "src_layout": "OIHW",
                }
            )
            # Generated bounds guard against the source shape; always satisfied
            # here since the packed extents divide evenly.
            T_layout_trans[
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
            ] = T.if_then_else(
                v_ax0 * T.int64(4) + v_ax5 < T.int64(64)
                and v_ax1 * T.int64(3) + v_ax4 < T.int64(3)
                and v_ax2 < T.int64(7)
                and v_ax3 < T.int64(7),
                A[
                    v_ax0 * T.int64(4) + v_ax5,
                    v_ax1 * T.int64(3) + v_ax4,
                    v_ax2,
                    v_ax3,
                ],
                T.float32(0),
            )
@T.prim_func
def layout_transform10(
    A: T.Buffer((T.int64(128), T.int64(128), T.int64(3), T.int64(3)), "float32"),
    T_layout_trans: T.Buffer(
        (T.int64(32), T.int64(32), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
        "float32",
    ),
):
    # Repack OIHW weights (128, 128, 3, 3) into OIHW4i4o
    # (32, 32, 3, 3, 4, 4): O = ax0 * 4 + ax5, I = ax1 * 4 + ax4.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
        T.int64(32), T.int64(32), T.int64(3), T.int64(3), T.int64(4), T.int64(4)
    ):
        with T.block("T_layout_trans"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
            )
            T.reads(
                A[
                    v_ax0 * T.int64(4) + v_ax5,
                    v_ax1 * T.int64(4) + v_ax4,
                    v_ax2,
                    v_ax3,
                ]
            )
            T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
            T.block_attr(
                {
                    "dst_layout": "OIHW4i4o",
                    "input_shape": [
                        T.int64(128),
                        T.int64(128),
                        T.int64(3),
                        T.int64(3),
                    ],
                    "schedule_rule": "None",
                    "src_layout": "OIHW",
                }
            )
            # Generated bounds guard against the source shape; always satisfied
            # here since the packed extents divide evenly.
            T_layout_trans[
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
            ] = T.if_then_else(
                v_ax0 * T.int64(4) + v_ax5 < T.int64(128)
                and v_ax1 * T.int64(4) + v_ax4 < T.int64(128)
                and v_ax2 < T.int64(3)
                and v_ax3 < T.int64(3),
                A[
                    v_ax0 * T.int64(4) + v_ax5,
                    v_ax1 * T.int64(4) + v_ax4,
                    v_ax2,
                    v_ax3,
                ],
                T.float32(0),
            )
@T.prim_func
def layout_transform11(
    A: T.Buffer((T.int64(512), T.int64(128), T.int64(1), T.int64(1)), "float32"),
    T_layout_trans: T.Buffer(
        (T.int64(128), T.int64(32), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
        "float32",
    ),
):
    # Repack OIHW weights (512, 128, 1, 1) into OIHW4i4o
    # (128, 32, 1, 1, 4, 4): O = ax0 * 4 + ax5, I = ax1 * 4 + ax4.
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
        T.int64(128), T.int64(32), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
    ):
        with T.block("T_layout_trans"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
                "SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
            )
            T.reads(
                A[
                    v_ax0 * T.int64(4) + v_ax5,
                    v_ax1 * T.int64(4) + v_ax4,
                    v_ax2,
                    v_ax3,
                ]
            )
            T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
            T.block_attr(
                {
                    "dst_layout": "OIHW4i4o",
                    "input_shape": [
                        T.int64(512),
                        T.int64(128),
                        T.int64(1),
                        T.int64(1),
                    ],
                    "schedule_rule": "None",
                    "src_layout": "OIHW",
                }
            )
            # Generated bounds guard against the source shape; always satisfied
            # here since the packed extents divide evenly.
            T_layout_trans[
                v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
            ] = T.if_then_else(
                v_ax0 * T.int64(4) + v_ax5 < T.int64(512)
                and v_ax1 * T.int64(4) + v_ax4 < T.int64(128)
                and v_ax2 < T.int64(1)
                and v_ax3 < T.int64(1),
                A[
                    v_ax0 * T.int64(4) + v_ax5,
                    v_ax1 * T.int64(4) + v_ax4,
                    v_ax2,
                    v_ax3,
                ],
                T.float32(0),
            )
@T.prim_func
def layout_transform12(
A: T.Buffer((T.int64(512), T.int64(256), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(128), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (512, 256, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(128), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(512),
T.int64(256),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 512 and 256 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(512)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform13(
A: T.Buffer((T.int64(1), T.int64(512), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(1), T.int64(128), T.int64(1), T.int64(1), T.int64(4)), "float32"
),
):
# Repack a (1, 512, 1, 1) NCHW tensor into blocked NCHW4c form:
# T_layout_trans[n, c, h, w, cc] = A[n, c*4 + cc, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4 in T.grid(
T.int64(1), T.int64(128), T.int64(1), T.int64(1), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
"SSSSS", [ax0, ax1, ax2, ax3, ax4]
)
T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T.block_attr(
{
"dst_layout": "NCHW4c",
"input_shape": [
T.int64(1),
T.int64(512),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "NCHW",
}
)
# Guard pads out-of-range reads with 0.0; 512 is an exact multiple
# of 4, so the predicate is always true here.
T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
v_ax0 < T.int64(1)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
T.float32(0),
)
@T.prim_func
def layout_transform14(
A: T.Buffer((T.int64(128), T.int64(512), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(32), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (128, 512, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(32), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(128),
T.int64(512),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 128 and 512 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(128)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform15(
A: T.Buffer((T.int64(256), T.int64(512), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(64), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (256, 512, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(64), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(256),
T.int64(512),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 256 and 512 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(256)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform16(
A: T.Buffer((T.int64(256), T.int64(256), T.int64(3), T.int64(3)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(64), T.int64(64), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (256, 256, 3, 3) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(64), T.int64(64), T.int64(3), T.int64(3), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(256),
T.int64(256),
T.int64(3),
T.int64(3),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; both channel dims (256)
# are exact multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(256)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
and v_ax2 < T.int64(3)
and v_ax3 < T.int64(3),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform17(
A: T.Buffer((T.int64(1024), T.int64(256), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(256), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (1024, 256, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(256), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(1024),
T.int64(256),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 1024 and 256 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(1024)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform18(
A: T.Buffer((T.int64(1024), T.int64(512), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(
T.int64(256),
T.int64(128),
T.int64(1),
T.int64(1),
T.int64(4),
T.int64(4),
),
"float32",
),
):
# Repack the (1024, 512, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(256), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(1024),
T.int64(512),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 1024 and 512 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(1024)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform19(
A: T.Buffer((T.int64(1), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(1), T.int64(256), T.int64(1), T.int64(1), T.int64(4)), "float32"
),
):
# Repack a (1, 1024, 1, 1) NCHW tensor into blocked NCHW4c form:
# T_layout_trans[n, c, h, w, cc] = A[n, c*4 + cc, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4 in T.grid(
T.int64(1), T.int64(256), T.int64(1), T.int64(1), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
"SSSSS", [ax0, ax1, ax2, ax3, ax4]
)
T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T.block_attr(
{
"dst_layout": "NCHW4c",
"input_shape": [
T.int64(1),
T.int64(1024),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "NCHW",
}
)
# Guard pads out-of-range reads with 0.0; 1024 is an exact multiple
# of 4, so the predicate is always true here.
T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
v_ax0 < T.int64(1)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(1024)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
T.float32(0),
)
@T.prim_func
def layout_transform2(
A: T.Buffer((T.int64(1), T.int64(64), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(4)), "float32"
),
):
# Repack a (1, 64, 1, 1) NCHW tensor into blocked NCHW4c form:
# T_layout_trans[n, c, h, w, cc] = A[n, c*4 + cc, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4 in T.grid(
T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
"SSSSS", [ax0, ax1, ax2, ax3, ax4]
)
T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T.block_attr(
{
"dst_layout": "NCHW4c",
"input_shape": [
T.int64(1),
T.int64(64),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "NCHW",
}
)
# Guard pads out-of-range reads with 0.0; 64 is an exact multiple
# of 4, so the predicate is always true here.
T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
v_ax0 < T.int64(1)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(64)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
T.float32(0),
)
@T.prim_func
def layout_transform20(
A: T.Buffer((T.int64(256), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(64), T.int64(256), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (256, 1024, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(64), T.int64(256), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(256),
T.int64(1024),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 256 and 1024 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(256)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(1024)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform21(
A: T.Buffer((T.int64(512), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(
T.int64(128),
T.int64(256),
T.int64(1),
T.int64(1),
T.int64(4),
T.int64(4),
),
"float32",
),
):
# Repack the (512, 1024, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(128), T.int64(256), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(512),
T.int64(1024),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 512 and 1024 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(512)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(1024)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform22(
A: T.Buffer((T.int64(512), T.int64(512), T.int64(3), T.int64(3)), "float32"),
T_layout_trans: T.Buffer(
(
T.int64(128),
T.int64(128),
T.int64(3),
T.int64(3),
T.int64(4),
T.int64(4),
),
"float32",
),
):
# Repack the (512, 512, 3, 3) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(128), T.int64(128), T.int64(3), T.int64(3), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(512),
T.int64(512),
T.int64(3),
T.int64(3),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; both channel dims (512)
# are exact multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(512)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
and v_ax2 < T.int64(3)
and v_ax3 < T.int64(3),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform23(
A: T.Buffer((T.int64(2048), T.int64(512), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(
T.int64(512),
T.int64(128),
T.int64(1),
T.int64(1),
T.int64(4),
T.int64(4),
),
"float32",
),
):
# Repack the (2048, 512, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(512), T.int64(128), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(2048),
T.int64(512),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 2048 and 512 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(2048)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(512)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform24(
A: T.Buffer((T.int64(2048), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(
T.int64(512),
T.int64(256),
T.int64(1),
T.int64(1),
T.int64(4),
T.int64(4),
),
"float32",
),
):
# Repack the (2048, 1024, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(512), T.int64(256), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(2048),
T.int64(1024),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 2048 and 1024 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(2048)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(1024)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform25(
A: T.Buffer((T.int64(1), T.int64(2048), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)), "float32"
),
):
# Repack a (1, 2048, 1, 1) NCHW tensor into blocked NCHW4c form:
# T_layout_trans[n, c, h, w, cc] = A[n, c*4 + cc, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4 in T.grid(
T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
"SSSSS", [ax0, ax1, ax2, ax3, ax4]
)
T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T.block_attr(
{
"dst_layout": "NCHW4c",
"input_shape": [
T.int64(1),
T.int64(2048),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "NCHW",
}
)
# Guard pads out-of-range reads with 0.0; 2048 is an exact multiple
# of 4, so the predicate is always true here.
T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
v_ax0 < T.int64(1)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(2048)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
T.float32(0),
)
@T.prim_func
def layout_transform26(
A: T.Buffer((T.int64(512), T.int64(2048), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(
T.int64(128),
T.int64(512),
T.int64(1),
T.int64(1),
T.int64(4),
T.int64(4),
),
"float32",
),
):
# Repack the (512, 2048, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(128), T.int64(512), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(512),
T.int64(2048),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 512 and 2048 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(512)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(2048)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform27(
A: T.Buffer(
(T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)), "float32"
),
T_layout_trans: T.Buffer(
(T.int64(1), T.int64(2048), T.int64(1), T.int64(1)), "float32"
),
):
# Inverse of the NCHW4c packing: unpack a (1, 512, 1, 1, 4) NCHW4c tensor
# back to flat (1, 2048, 1, 1) NCHW:
# T_layout_trans[n, c, h, w] = A[n, c // 4, h, w, c % 4].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(
T.int64(1), T.int64(2048), T.int64(1), T.int64(1)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1 // T.int64(4), v_ax2, v_ax3, v_ax1 % T.int64(4)])
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3])
T.block_attr(
{
"dst_layout": "NCHW",
"input_shape": [
T.int64(1),
T.int64(512),
T.int64(1),
T.int64(1),
T.int64(4),
],
"schedule_rule": "None",
"src_layout": "NCHW4c",
}
)
# Guard pads out-of-range reads with 0.0; every condition matches a
# loop extent exactly, so the predicate is always true here.
T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
v_ax0 < T.int64(1)
and v_ax1 < T.int64(2048)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[v_ax0, v_ax1 // T.int64(4), v_ax2, v_ax3, v_ax1 % T.int64(4)],
T.float32(0),
)
@T.prim_func
def layout_transform3(
A: T.Buffer((T.int64(64), T.int64(64), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(16), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (64, 64, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(16), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(64),
T.int64(64),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; both channel dims (64)
# are exact multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(64)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(64)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform4(
A: T.Buffer((T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(16), T.int64(16), T.int64(3), T.int64(3), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (64, 64, 3, 3) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(16), T.int64(16), T.int64(3), T.int64(3), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(64),
T.int64(64),
T.int64(3),
T.int64(3),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; both channel dims (64)
# are exact multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(64)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(64)
and v_ax2 < T.int64(3)
and v_ax3 < T.int64(3),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform5(
A: T.Buffer((T.int64(256), T.int64(64), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(64), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (256, 64, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(64), T.int64(16), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(256),
T.int64(64),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 256 and 64 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(256)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(64)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform6(
A: T.Buffer((T.int64(1), T.int64(256), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(1), T.int64(64), T.int64(1), T.int64(1), T.int64(4)), "float32"
),
):
# Repack a (1, 256, 1, 1) NCHW tensor into blocked NCHW4c form:
# T_layout_trans[n, c, h, w, cc] = A[n, c*4 + cc, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4 in T.grid(
T.int64(1), T.int64(64), T.int64(1), T.int64(1), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
"SSSSS", [ax0, ax1, ax2, ax3, ax4]
)
T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T.block_attr(
{
"dst_layout": "NCHW4c",
"input_shape": [
T.int64(1),
T.int64(256),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "NCHW",
}
)
# Guard pads out-of-range reads with 0.0; 256 is an exact multiple
# of 4, so the predicate is always true here.
T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
v_ax0 < T.int64(1)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
T.float32(0),
)
@T.prim_func
def layout_transform7(
A: T.Buffer((T.int64(64), T.int64(256), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(16), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (64, 256, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(16), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(64),
T.int64(256),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 64 and 256 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(64)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform8(
A: T.Buffer((T.int64(128), T.int64(256), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(32), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)),
"float32",
),
):
# Repack the (128, 256, 1, 1) OIHW conv weight into blocked OIHW4i4o form:
# T_layout_trans[o, i, h, w, ii, oo] = A[o*4 + oo, i*4 + ii, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(
T.int64(32), T.int64(64), T.int64(1), T.int64(1), T.int64(4), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5 = T.axis.remap(
"SSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5]
)
T.reads(
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
]
)
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5])
T.block_attr(
{
"dst_layout": "OIHW4i4o",
"input_shape": [
T.int64(128),
T.int64(256),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "OIHW",
}
)
# Guard pads out-of-range reads with 0.0; 128 and 256 are exact
# multiples of 4, so the predicate is always true here.
T_layout_trans[
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_ax5
] = T.if_then_else(
v_ax0 * T.int64(4) + v_ax5 < T.int64(128)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(256)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[
v_ax0 * T.int64(4) + v_ax5,
v_ax1 * T.int64(4) + v_ax4,
v_ax2,
v_ax3,
],
T.float32(0),
)
@T.prim_func
def layout_transform9(
A: T.Buffer((T.int64(1), T.int64(128), T.int64(1), T.int64(1)), "float32"),
T_layout_trans: T.Buffer(
(T.int64(1), T.int64(32), T.int64(1), T.int64(1), T.int64(4)), "float32"
),
):
# Repack a (1, 128, 1, 1) NCHW tensor into blocked NCHW4c form:
# T_layout_trans[n, c, h, w, cc] = A[n, c*4 + cc, h, w].
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3, ax4 in T.grid(
T.int64(1), T.int64(32), T.int64(1), T.int64(1), T.int64(4)
):
with T.block("T_layout_trans"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
"SSSSS", [ax0, ax1, ax2, ax3, ax4]
)
T.reads(A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3])
T.writes(T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T.block_attr(
{
"dst_layout": "NCHW4c",
"input_shape": [
T.int64(1),
T.int64(128),
T.int64(1),
T.int64(1),
],
"schedule_rule": "None",
"src_layout": "NCHW",
}
)
# Guard pads out-of-range reads with 0.0; 128 is an exact multiple
# of 4, so the predicate is always true here.
T_layout_trans[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
v_ax0 < T.int64(1)
and v_ax1 * T.int64(4) + v_ax4 < T.int64(128)
and v_ax2 < T.int64(1)
and v_ax3 < T.int64(1),
A[v_ax0, v_ax1 * T.int64(4) + v_ax4, v_ax2, v_ax3],
T.float32(0),
)
@T.prim_func
def max_pool2d(
A: T.Buffer(
(T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
),
pool_max: T.Buffer(
(T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
),
):
# 3x3 max pooling with stride 2 on an NCHW4c tensor: 112x112 spatial input
# is padded by 1 pixel on each side (to 114x114) and reduced to 56x56.
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
# Stage 1: pad the spatial dims with -FLT_MAX so padded positions never
# win the max reduction below.
pad_temp = T.alloc_buffer(
(T.int64(1), T.int64(16), T.int64(114), T.int64(114), T.int64(4))
)
for ax0, ax1, ax2, ax3, ax4 in T.grid(
T.int64(1), T.int64(16), T.int64(114), T.int64(114), T.int64(4)
):
with T.block("pad_temp"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
"SSSSS", [ax0, ax1, ax2, ax3, ax4]
)
T.reads(A[v_ax0, v_ax1, v_ax2 - T.int64(1), v_ax3 - T.int64(1), v_ax4])
T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
# Interior maps back to A with a -1 offset; the 1-pixel border
# gets -FLT_MAX (identity element for max).
pad_temp[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.if_then_else(
T.int64(1) <= v_ax2
and v_ax2 < T.int64(113)
and T.int64(1) <= v_ax3
and v_ax3 < T.int64(113),
A[v_ax0, v_ax1, v_ax2 - T.int64(1), v_ax3 - T.int64(1), v_ax4],
T.float32(-3.4028234663852886e38),
)
# Stage 2: reduce each 3x3 window (rv0, rv1 are reduction axes) at
# stride 2 over the padded buffer.
for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(
T.int64(1),
T.int64(16),
T.int64(56),
T.int64(56),
T.int64(4),
T.int64(3),
T.int64(3),
):
with T.block("pool_max"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = T.axis.remap(
"SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1]
)
T.reads(
pad_temp[
v_ax0,
v_ax1,
v_ax2 * T.int64(2) + v_rv0,
v_ax3 * T.int64(2) + v_rv1,
v_ax4,
]
)
T.writes(pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T.block_attr({"schedule_rule": "meta_schedule.pool_max"})
# Initialize the accumulator with -FLT_MAX before reducing.
with T.init():
pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.float32(
-3.4028234663852886e38
)
pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
pool_max[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
pad_temp[
v_ax0,
v_ax1,
v_ax2 * T.int64(2) + v_rv0,
v_ax3 * T.int64(2) + v_rv1,
v_ax4,
],
)
@T.prim_func
def multiply(
A: T.Buffer((T.int64(3),), "float32"),
B: T.Buffer((T.int64(3),), "float32"),
T_multiply: T.Buffer((T.int64(3),), "float32"),
):
# Elementwise product of two length-3 float32 vectors: T_multiply = A * B.
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(3)):
with T.block("T_multiply"):
v_ax0 = T.axis.spatial(T.int64(3), ax0)
T.reads(A[v_ax0], B[v_ax0])
T.writes(T_multiply[v_ax0])
T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]
@T.prim_func
def multiply1(
A: T.Buffer((T.int64(64),), "float32"),
B: T.Buffer((T.int64(64),), "float32"),
T_multiply: T.Buffer((T.int64(64),), "float32"),
):
# Elementwise product of two length-64 float32 vectors: T_multiply = A * B.
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(64)):
with T.block("T_multiply"):
v_ax0 = T.axis.spatial(T.int64(64), ax0)
T.reads(A[v_ax0], B[v_ax0])
T.writes(T_multiply[v_ax0])
T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]
@T.prim_func
def multiply10(
A: T.Buffer((T.int64(128),), "float32"),
B: T.Buffer((T.int64(128),), "float32"),
T_multiply: T.Buffer((T.int64(128),), "float32"),
):
# Elementwise product of two length-128 float32 vectors: T_multiply = A * B.
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0 in range(T.int64(128)):
with T.block("T_multiply"):
v_ax0 = T.axis.spatial(T.int64(128), ax0)
T.reads(A[v_ax0], B[v_ax0])
T.writes(T_multiply[v_ax0])
T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]
@T.prim_func
def multiply11(
A: T.Buffer((T.int64(128), T.int64(256), T.int64(1), T.int64(1)), "float32"),
B: T.Buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1)), "float32"),
T_multiply: T.Buffer(
(T.int64(128), T.int64(256), T.int64(1), T.int64(1)), "float32"
),
):
# Broadcast multiply: B's axis 1 has extent 1 and is broadcast across A's
# 256 columns, i.e. T_multiply[o, i, h, w] = A[o, i, h, w] * B[o, 0, h, w]
# (a per-row scale of A).
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, ax2, ax3 in T.grid(
T.int64(128), T.int64(256), T.int64(1), T.int64(1)
):
with T.block("T_multiply"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(
A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
)
T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
)
# Broadcast multiply over a (128, 128, 3, 3) tensor: every element of slice
# A[c, :, :, :] is scaled by the single scalar B[c, 0, 0, 0].
@T.prim_func
def multiply12(
    A: T.Buffer((T.int64(128), T.int64(128), T.int64(3), T.int64(3)), "float32"),
    B: T.Buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(128), T.int64(128), T.int64(3), T.int64(3)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(128), T.int64(128), T.int64(3), T.int64(3)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 on axes 1-3 — all three broadcast dimensions.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3],
                B[v_ax0, T.int64(0), T.int64(0), T.int64(0)],
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3]
                * B[v_ax0, T.int64(0), T.int64(0), T.int64(0)]
            )
# Elementwise multiply of two length-512 float32 vectors: T_multiply[i] = A[i] * B[i].
@T.prim_func
def multiply13(
    A: T.Buffer((T.int64(512),), "float32"),
    B: T.Buffer((T.int64(512),), "float32"),
    T_multiply: T.Buffer((T.int64(512),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(512)):
        with T.block("T_multiply"):
            v_ax0 = T.axis.spatial(T.int64(512), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_multiply[v_ax0])
            T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]
# Broadcast multiply on a 5-D (1, 128, 28, 28, 4) tensor: B has size-1 axes 2 and 3,
# so B[n, c, 0, 0, v] scales every spatial position of the matching A slice.
@T.prim_func
def multiply14(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_multiply: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            # B is read at index 0 on axes 2 and 3 — the broadcast dimensions.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                * B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
# Broadcast multiply over a (128, 512, 1, 1) tensor: each slice A[c, :, :, :] is
# scaled by B[c, 0, :, :]; B's axis 1 is size 1, so it broadcasts along axis 1.
@T.prim_func
def multiply15(
    A: T.Buffer((T.int64(128), T.int64(512), T.int64(1), T.int64(1)), "float32"),
    B: T.Buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(128), T.int64(512), T.int64(1), T.int64(1)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(128), T.int64(512), T.int64(1), T.int64(1)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 along axis 1 — the broadcast dimension.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
# Broadcast multiply over a (256, 512, 1, 1) tensor: each slice A[c, :, :, :] is
# scaled by B[c, 0, :, :]; B's axis 1 is size 1, so it broadcasts along axis 1.
@T.prim_func
def multiply16(
    A: T.Buffer((T.int64(256), T.int64(512), T.int64(1), T.int64(1)), "float32"),
    B: T.Buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(256), T.int64(512), T.int64(1), T.int64(1)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(256), T.int64(512), T.int64(1), T.int64(1)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 along axis 1 — the broadcast dimension.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
# Broadcast multiply over a (256, 256, 3, 3) tensor: every element of slice
# A[c, :, :, :] is scaled by the single scalar B[c, 0, 0, 0].
@T.prim_func
def multiply17(
    A: T.Buffer((T.int64(256), T.int64(256), T.int64(3), T.int64(3)), "float32"),
    B: T.Buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(256), T.int64(256), T.int64(3), T.int64(3)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(256), T.int64(256), T.int64(3), T.int64(3)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 on axes 1-3 — all three broadcast dimensions.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3],
                B[v_ax0, T.int64(0), T.int64(0), T.int64(0)],
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3]
                * B[v_ax0, T.int64(0), T.int64(0), T.int64(0)]
            )
# Elementwise multiply of two length-1024 float32 vectors: T_multiply[i] = A[i] * B[i].
@T.prim_func
def multiply18(
    A: T.Buffer((T.int64(1024),), "float32"),
    B: T.Buffer((T.int64(1024),), "float32"),
    T_multiply: T.Buffer((T.int64(1024),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(1024)):
        with T.block("T_multiply"):
            v_ax0 = T.axis.spatial(T.int64(1024), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_multiply[v_ax0])
            T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]
# Broadcast multiply on a 5-D (1, 256, 14, 14, 4) tensor: B has size-1 axes 2 and 3,
# so B[n, c, 0, 0, v] scales every spatial position of the matching A slice.
@T.prim_func
def multiply19(
    A: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_multiply: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            # B is read at index 0 on axes 2 and 3 — the broadcast dimensions.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                * B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
# Broadcast multiply over a (64, 3, 7, 7) tensor: every element of slice
# A[c, :, :, :] is scaled by the single scalar B[c, 0, 0, 0].
@T.prim_func
def multiply2(
    A: T.Buffer((T.int64(64), T.int64(3), T.int64(7), T.int64(7)), "float32"),
    B: T.Buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(64), T.int64(3), T.int64(7), T.int64(7)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(64), T.int64(3), T.int64(7), T.int64(7)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 on axes 1-3 — all three broadcast dimensions.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3],
                B[v_ax0, T.int64(0), T.int64(0), T.int64(0)],
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3]
                * B[v_ax0, T.int64(0), T.int64(0), T.int64(0)]
            )
# Broadcast multiply over a (256, 1024, 1, 1) tensor: each slice A[c, :, :, :] is
# scaled by B[c, 0, :, :]; B's axis 1 is size 1, so it broadcasts along axis 1.
@T.prim_func
def multiply20(
    A: T.Buffer((T.int64(256), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
    B: T.Buffer((T.int64(256), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(256), T.int64(1024), T.int64(1), T.int64(1)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(256), T.int64(1024), T.int64(1), T.int64(1)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 along axis 1 — the broadcast dimension.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
# Broadcast multiply over a (512, 1024, 1, 1) tensor: each slice A[c, :, :, :] is
# scaled by B[c, 0, :, :]; B's axis 1 is size 1, so it broadcasts along axis 1.
@T.prim_func
def multiply21(
    A: T.Buffer((T.int64(512), T.int64(1024), T.int64(1), T.int64(1)), "float32"),
    B: T.Buffer((T.int64(512), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(512), T.int64(1024), T.int64(1), T.int64(1)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(512), T.int64(1024), T.int64(1), T.int64(1)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 along axis 1 — the broadcast dimension.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
# Broadcast multiply over a (512, 512, 3, 3) tensor: every element of slice
# A[c, :, :, :] is scaled by the single scalar B[c, 0, 0, 0].
@T.prim_func
def multiply22(
    A: T.Buffer((T.int64(512), T.int64(512), T.int64(3), T.int64(3)), "float32"),
    B: T.Buffer((T.int64(512), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(512), T.int64(512), T.int64(3), T.int64(3)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(512), T.int64(512), T.int64(3), T.int64(3)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 on axes 1-3 — all three broadcast dimensions.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3],
                B[v_ax0, T.int64(0), T.int64(0), T.int64(0)],
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3]
                * B[v_ax0, T.int64(0), T.int64(0), T.int64(0)]
            )
# Elementwise multiply of two length-2048 float32 vectors: T_multiply[i] = A[i] * B[i].
@T.prim_func
def multiply23(
    A: T.Buffer((T.int64(2048),), "float32"),
    B: T.Buffer((T.int64(2048),), "float32"),
    T_multiply: T.Buffer((T.int64(2048),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(2048)):
        with T.block("T_multiply"):
            v_ax0 = T.axis.spatial(T.int64(2048), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_multiply[v_ax0])
            T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]
# Broadcast multiply on a 5-D (1, 512, 7, 7, 4) tensor: B has size-1 axes 2 and 3,
# so B[n, c, 0, 0, v] scales every spatial position of the matching A slice.
@T.prim_func
def multiply24(
    A: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_multiply: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            # B is read at index 0 on axes 2 and 3 — the broadcast dimensions.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                * B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
# Broadcast multiply over a (512, 2048, 1, 1) tensor: each slice A[c, :, :, :] is
# scaled by B[c, 0, :, :]; B's axis 1 is size 1, so it broadcasts along axis 1.
@T.prim_func
def multiply25(
    A: T.Buffer((T.int64(512), T.int64(2048), T.int64(1), T.int64(1)), "float32"),
    B: T.Buffer((T.int64(512), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(512), T.int64(2048), T.int64(1), T.int64(1)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(512), T.int64(2048), T.int64(1), T.int64(1)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 along axis 1 — the broadcast dimension.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
# Broadcast multiply over a (64, 3, 7, 7) tensor with a 3-D scale: B has shape
# (3, 1, 1) and is indexed by axis 1 of A, i.e. each channel-1 slice is scaled
# by B[c1, 0, 0] (broadcast over axes 0, 2 and 3).
@T.prim_func
def multiply3(
    A: T.Buffer((T.int64(64), T.int64(3), T.int64(7), T.int64(7)), "float32"),
    B: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(64), T.int64(3), T.int64(7), T.int64(7)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(64), T.int64(3), T.int64(7), T.int64(7)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # Note: B is indexed by v_ax1 (A's second axis), not v_ax0.
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax1, T.int64(0), T.int64(0)])
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax1, T.int64(0), T.int64(0)]
            )
# Broadcast multiply on a 5-D (1, 16, 56, 56, 4) tensor: B has size-1 axes 2 and 3,
# so B[n, c, 0, 0, v] scales every spatial position of the matching A slice.
@T.prim_func
def multiply4(
    A: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_multiply: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            # B is read at index 0 on axes 2 and 3 — the broadcast dimensions.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                * B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
# Broadcast multiply over a (64, 64, 1, 1) tensor: each slice A[c, :, :, :] is
# scaled by B[c, 0, :, :]; B's axis 1 is size 1, so it broadcasts along axis 1.
@T.prim_func
def multiply5(
    A: T.Buffer((T.int64(64), T.int64(64), T.int64(1), T.int64(1)), "float32"),
    B: T.Buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(64), T.int64(64), T.int64(1), T.int64(1)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(64), T.int64(64), T.int64(1), T.int64(1)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 along axis 1 — the broadcast dimension.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
# Broadcast multiply over a (64, 64, 3, 3) tensor: every element of slice
# A[c, :, :, :] is scaled by the single scalar B[c, 0, 0, 0].
@T.prim_func
def multiply6(
    A: T.Buffer((T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"),
    B: T.Buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(64), T.int64(64), T.int64(3), T.int64(3)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 on axes 1-3 — all three broadcast dimensions.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3],
                B[v_ax0, T.int64(0), T.int64(0), T.int64(0)],
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3]
                * B[v_ax0, T.int64(0), T.int64(0), T.int64(0)]
            )
# Elementwise multiply of two length-256 float32 vectors: T_multiply[i] = A[i] * B[i].
@T.prim_func
def multiply7(
    A: T.Buffer((T.int64(256),), "float32"),
    B: T.Buffer((T.int64(256),), "float32"),
    T_multiply: T.Buffer((T.int64(256),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(256)):
        with T.block("T_multiply"):
            v_ax0 = T.axis.spatial(T.int64(256), ax0)
            T.reads(A[v_ax0], B[v_ax0])
            T.writes(T_multiply[v_ax0])
            T_multiply[v_ax0] = A[v_ax0] * B[v_ax0]
# Broadcast multiply on a 5-D (1, 64, 56, 56, 4) tensor: B has size-1 axes 2 and 3,
# so B[n, c, 0, 0, v] scales every spatial position of the matching A slice.
@T.prim_func
def multiply8(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    B: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(1), T.int64(1), T.int64(4)), "float32"
    ),
    T_multiply: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            # B is read at index 0 on axes 2 and 3 — the broadcast dimensions.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
                B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4],
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4]
                * B[v_ax0, v_ax1, T.int64(0), T.int64(0), v_ax4]
            )
# Broadcast multiply over a (64, 256, 1, 1) tensor: each slice A[c, :, :, :] is
# scaled by B[c, 0, :, :]; B's axis 1 is size 1, so it broadcasts along axis 1.
@T.prim_func
def multiply9(
    A: T.Buffer((T.int64(64), T.int64(256), T.int64(1), T.int64(1)), "float32"),
    B: T.Buffer((T.int64(64), T.int64(1), T.int64(1), T.int64(1)), "float32"),
    T_multiply: T.Buffer(
        (T.int64(64), T.int64(256), T.int64(1), T.int64(1)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3 in T.grid(
        T.int64(64), T.int64(256), T.int64(1), T.int64(1)
    ):
        with T.block("T_multiply"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            # B is read at index 0 along axis 1 — the broadcast dimension.
            T.reads(
                A[v_ax0, v_ax1, v_ax2, v_ax3], B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
            T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
            T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
                A[v_ax0, v_ax1, v_ax2, v_ax3] * B[v_ax0, T.int64(0), v_ax2, v_ax3]
            )
# Elementwise negation of a length-3 float32 vector, computed as 0 - A[i].
@T.prim_func
def negative(
    A: T.Buffer((T.int64(3),), "float32"),
    T_negative: T.Buffer((T.int64(3),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(3)):
        with T.block("T_negative"):
            v_ax0 = T.axis.spatial(T.int64(3), ax0)
            T.reads(A[v_ax0])
            T.writes(T_negative[v_ax0])
            T_negative[v_ax0] = T.float32(0) - A[v_ax0]
# Elementwise negation of a length-64 float32 vector, computed as 0 - A[i].
@T.prim_func
def negative1(
    A: T.Buffer((T.int64(64),), "float32"),
    T_negative: T.Buffer((T.int64(64),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(64)):
        with T.block("T_negative"):
            v_ax0 = T.axis.spatial(T.int64(64), ax0)
            T.reads(A[v_ax0])
            T.writes(T_negative[v_ax0])
            T_negative[v_ax0] = T.float32(0) - A[v_ax0]
# Elementwise negation of a length-256 float32 vector, computed as 0 - A[i].
@T.prim_func
def negative2(
    A: T.Buffer((T.int64(256),), "float32"),
    T_negative: T.Buffer((T.int64(256),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(256)):
        with T.block("T_negative"):
            v_ax0 = T.axis.spatial(T.int64(256), ax0)
            T.reads(A[v_ax0])
            T.writes(T_negative[v_ax0])
            T_negative[v_ax0] = T.float32(0) - A[v_ax0]
# Elementwise negation of a length-128 float32 vector, computed as 0 - A[i].
@T.prim_func
def negative3(
    A: T.Buffer((T.int64(128),), "float32"),
    T_negative: T.Buffer((T.int64(128),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(128)):
        with T.block("T_negative"):
            v_ax0 = T.axis.spatial(T.int64(128), ax0)
            T.reads(A[v_ax0])
            T.writes(T_negative[v_ax0])
            T_negative[v_ax0] = T.float32(0) - A[v_ax0]
# Elementwise negation of a length-512 float32 vector, computed as 0 - A[i].
@T.prim_func
def negative4(
    A: T.Buffer((T.int64(512),), "float32"),
    T_negative: T.Buffer((T.int64(512),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(512)):
        with T.block("T_negative"):
            v_ax0 = T.axis.spatial(T.int64(512), ax0)
            T.reads(A[v_ax0])
            T.writes(T_negative[v_ax0])
            T_negative[v_ax0] = T.float32(0) - A[v_ax0]
# Elementwise negation of a length-1024 float32 vector, computed as 0 - A[i].
@T.prim_func
def negative5(
    A: T.Buffer((T.int64(1024),), "float32"),
    T_negative: T.Buffer((T.int64(1024),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(1024)):
        with T.block("T_negative"):
            v_ax0 = T.axis.spatial(T.int64(1024), ax0)
            T.reads(A[v_ax0])
            T.writes(T_negative[v_ax0])
            T_negative[v_ax0] = T.float32(0) - A[v_ax0]
# Elementwise negation of a length-2048 float32 vector, computed as 0 - A[i].
@T.prim_func
def negative6(
    A: T.Buffer((T.int64(2048),), "float32"),
    T_negative: T.Buffer((T.int64(2048),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(2048)):
        with T.block("T_negative"):
            v_ax0 = T.axis.spatial(T.int64(2048), ax0)
            T.reads(A[v_ax0])
            T.writes(T_negative[v_ax0])
            T_negative[v_ax0] = T.float32(0) - A[v_ax0]
# Elementwise ReLU over a 5-D (1, 16, 112, 112, 4) float32 tensor: max(x, 0).
@T.prim_func
def relu(
    A: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
    ),
    T_relu: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(16), T.int64(112), T.int64(112), T.int64(4)
    ):
        with T.block("T_relu"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
            )
# Elementwise ReLU over a 5-D (1, 16, 56, 56, 4) float32 tensor: max(x, 0).
@T.prim_func
def relu1(
    A: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    T_relu: T.Buffer(
        (T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(16), T.int64(56), T.int64(56), T.int64(4)
    ):
        with T.block("T_relu"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
            )
# Elementwise ReLU over a 5-D (1, 64, 56, 56, 4) float32 tensor: max(x, 0).
@T.prim_func
def relu2(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
    T_relu: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(4)
    ):
        with T.block("T_relu"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
            )
# Elementwise ReLU over a 5-D (1, 32, 28, 28, 4) float32 tensor: max(x, 0).
@T.prim_func
def relu3(
    A: T.Buffer(
        (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    T_relu: T.Buffer(
        (T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(32), T.int64(28), T.int64(28), T.int64(4)
    ):
        with T.block("T_relu"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
            )
# Elementwise ReLU over a 5-D (1, 128, 28, 28, 4) float32 tensor: max(x, 0).
@T.prim_func
def relu4(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
    T_relu: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(128), T.int64(28), T.int64(28), T.int64(4)
    ):
        with T.block("T_relu"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
            )
# Elementwise ReLU over a 5-D (1, 64, 14, 14, 4) float32 tensor: max(x, 0).
@T.prim_func
def relu5(
    A: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    T_relu: T.Buffer(
        (T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(64), T.int64(14), T.int64(14), T.int64(4)
    ):
        with T.block("T_relu"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
            )
# Elementwise ReLU over a 5-D (1, 256, 14, 14, 4) float32 tensor: max(x, 0).
@T.prim_func
def relu6(
    A: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
    T_relu: T.Buffer(
        (T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(256), T.int64(14), T.int64(14), T.int64(4)
    ):
        with T.block("T_relu"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
            )
# Elementwise ReLU over a 5-D (1, 128, 7, 7, 4) float32 tensor: max(x, 0).
@T.prim_func
def relu7(
    A: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    T_relu: T.Buffer(
        (T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(128), T.int64(7), T.int64(7), T.int64(4)
    ):
        with T.block("T_relu"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
            )
# Elementwise ReLU over a 5-D (1, 512, 7, 7, 4) float32 tensor: max(x, 0).
@T.prim_func
def relu8(
    A: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
    T_relu: T.Buffer(
        (T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)), "float32"
    ),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0, ax1, ax2, ax3, ax4 in T.grid(
        T.int64(1), T.int64(512), T.int64(7), T.int64(7), T.int64(4)
    ):
        with T.block("T_relu"):
            v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap(
                "SSSSS", [ax0, ax1, ax2, ax3, ax4]
            )
            T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T.writes(T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
            T_relu[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T.max(
                A[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], T.float32(0)
            )
# Elementwise reciprocal square root over 3 float32 lanes: tensor[i] = 1 / sqrt(A[i]).
@T.prim_func
def rsqrt(
    A: T.Buffer((T.int64(3),), "float32"),
    tensor: T.Buffer((T.int64(3),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(3)):
        with T.block("tensor"):
            v_ax0 = T.axis.spatial(T.int64(3), ax0)
            T.reads(A[v_ax0])
            T.writes(tensor[v_ax0])
            tensor[v_ax0] = T.float32(1) / T.sqrt(A[v_ax0])
# Elementwise reciprocal square root over 64 float32 lanes: tensor[i] = 1 / sqrt(A[i]).
@T.prim_func
def rsqrt1(
    A: T.Buffer((T.int64(64),), "float32"),
    tensor: T.Buffer((T.int64(64),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(64)):
        with T.block("tensor"):
            v_ax0 = T.axis.spatial(T.int64(64), ax0)
            T.reads(A[v_ax0])
            T.writes(tensor[v_ax0])
            tensor[v_ax0] = T.float32(1) / T.sqrt(A[v_ax0])
# Elementwise reciprocal square root over 256 float32 lanes: tensor[i] = 1 / sqrt(A[i]).
@T.prim_func
def rsqrt2(
    A: T.Buffer((T.int64(256),), "float32"),
    tensor: T.Buffer((T.int64(256),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(256)):
        with T.block("tensor"):
            v_ax0 = T.axis.spatial(T.int64(256), ax0)
            T.reads(A[v_ax0])
            T.writes(tensor[v_ax0])
            tensor[v_ax0] = T.float32(1) / T.sqrt(A[v_ax0])
# Elementwise reciprocal square root over 128 float32 lanes: tensor[i] = 1 / sqrt(A[i]).
@T.prim_func
def rsqrt3(
    A: T.Buffer((T.int64(128),), "float32"),
    tensor: T.Buffer((T.int64(128),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(128)):
        with T.block("tensor"):
            v_ax0 = T.axis.spatial(T.int64(128), ax0)
            T.reads(A[v_ax0])
            T.writes(tensor[v_ax0])
            tensor[v_ax0] = T.float32(1) / T.sqrt(A[v_ax0])
# Elementwise reciprocal square root over 512 float32 lanes: tensor[i] = 1 / sqrt(A[i]).
@T.prim_func
def rsqrt4(
    A: T.Buffer((T.int64(512),), "float32"),
    tensor: T.Buffer((T.int64(512),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(512)):
        with T.block("tensor"):
            v_ax0 = T.axis.spatial(T.int64(512), ax0)
            T.reads(A[v_ax0])
            T.writes(tensor[v_ax0])
            tensor[v_ax0] = T.float32(1) / T.sqrt(A[v_ax0])
# Elementwise reciprocal square root over 1024 float32 lanes: tensor[i] = 1 / sqrt(A[i]).
@T.prim_func
def rsqrt5(
    A: T.Buffer((T.int64(1024),), "float32"),
    tensor: T.Buffer((T.int64(1024),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(1024)):
        with T.block("tensor"):
            v_ax0 = T.axis.spatial(T.int64(1024), ax0)
            T.reads(A[v_ax0])
            T.writes(tensor[v_ax0])
            tensor[v_ax0] = T.float32(1) / T.sqrt(A[v_ax0])
# Elementwise reciprocal square root over 2048 float32 lanes: tensor[i] = 1 / sqrt(A[i]).
@T.prim_func
def rsqrt6(
    A: T.Buffer((T.int64(2048),), "float32"),
    tensor: T.Buffer((T.int64(2048),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(2048)):
        with T.block("tensor"):
            v_ax0 = T.axis.spatial(T.int64(2048), ax0)
            T.reads(A[v_ax0])
            T.writes(tensor[v_ax0])
            tensor[v_ax0] = T.float32(1) / T.sqrt(A[v_ax0])
# Numerically stable softmax over the last axis of a (1, 1000) float32 tensor,
# computed in four stages: row max, shifted exp, exp sum, normalize.
@T.prim_func
def softmax(
    A: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
    T_softmax_norm: T.Buffer((T.int64(1), T.int64(1000)), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    # Intermediate buffers: per-row max, shifted exponentials, per-row exp sum.
    T_softmax_maxelem = T.alloc_buffer((T.int64(1),))
    T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1000)))
    T_softmax_expsum = T.alloc_buffer((T.int64(1),))
    # Stage 1: reduce each row to its maximum ("R" marks k as a reduction axis).
    for i0, k in T.grid(T.int64(1), T.int64(1000)):
        with T.block("T_softmax_maxelem"):
            v_i0, v_k = T.axis.remap("SR", [i0, k])
            T.reads(A[v_i0, v_k])
            T.writes(T_softmax_maxelem[v_i0])
            with T.init():
                # Initialize with the most negative finite float32 (-FLT_MAX).
                T_softmax_maxelem[v_i0] = T.float32(-3.4028234663852886e38)
            T_softmax_maxelem[v_i0] = T.max(T_softmax_maxelem[v_i0], A[v_i0, v_k])
    # Stage 2: exponentiate with the row max subtracted to avoid overflow.
    for i0, i1 in T.grid(T.int64(1), T.int64(1000)):
        with T.block("T_softmax_exp"):
            v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
            T.reads(A[v_i0, v_i1], T_softmax_maxelem[v_i0])
            T.writes(T_softmax_exp[v_i0, v_i1])
            T_softmax_exp[v_i0, v_i1] = T.exp(
                A[v_i0, v_i1] - T_softmax_maxelem[v_i0]
            )
    # Stage 3: sum the exponentials along each row.
    for i0, k in T.grid(T.int64(1), T.int64(1000)):
        with T.block("T_softmax_expsum"):
            v_i0, v_k = T.axis.remap("SR", [i0, k])
            T.reads(T_softmax_exp[v_i0, v_k])
            T.writes(T_softmax_expsum[v_i0])
            with T.init():
                T_softmax_expsum[v_i0] = T.float32(0)
            T_softmax_expsum[v_i0] = (
                T_softmax_expsum[v_i0] + T_softmax_exp[v_i0, v_k]
            )
    # Stage 4: divide each exponential by its row sum to normalize.
    for i0, i1 in T.grid(T.int64(1), T.int64(1000)):
        with T.block("T_softmax_norm"):
            v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
            T.reads(T_softmax_exp[v_i0, v_i1], T_softmax_expsum[v_i0])
            T.writes(T_softmax_norm[v_i0, v_i1])
            # Block attr records that softmax was taken along axis 1.
            T.block_attr({"axis": 1})
            T_softmax_norm[v_i0, v_i1] = (
                T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0]
            )
# Squeeze a (3, 1, 1) float32 tensor to shape (3,) by dropping the two unit axes.
@T.prim_func
def squeeze(
    A: T.Buffer((T.int64(3), T.int64(1), T.int64(1)), "float32"),
    T_squeeze: T.Buffer((T.int64(3),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(3)):
        with T.block("T_squeeze"):
            v_ax0 = T.axis.spatial(T.int64(3), ax0)
            T.reads(A[v_ax0, T.int64(0), T.int64(0)])
            T.writes(T_squeeze[v_ax0])
            T_squeeze[v_ax0] = A[v_ax0, T.int64(0), T.int64(0)]
# Squeeze a (64, 1, 1) float32 tensor to shape (64,) by dropping the two unit axes.
@T.prim_func
def squeeze1(
    A: T.Buffer((T.int64(64), T.int64(1), T.int64(1)), "float32"),
    T_squeeze: T.Buffer((T.int64(64),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(64)):
        with T.block("T_squeeze"):
            v_ax0 = T.axis.spatial(T.int64(64), ax0)
            T.reads(A[v_ax0, T.int64(0), T.int64(0)])
            T.writes(T_squeeze[v_ax0])
            T_squeeze[v_ax0] = A[v_ax0, T.int64(0), T.int64(0)]
# Squeeze a (128, 1, 1) float32 tensor to shape (128,) by dropping the two unit axes.
@T.prim_func
def squeeze2(
    A: T.Buffer((T.int64(128), T.int64(1), T.int64(1)), "float32"),
    T_squeeze: T.Buffer((T.int64(128),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(128)):
        with T.block("T_squeeze"):
            v_ax0 = T.axis.spatial(T.int64(128), ax0)
            T.reads(A[v_ax0, T.int64(0), T.int64(0)])
            T.writes(T_squeeze[v_ax0])
            T_squeeze[v_ax0] = A[v_ax0, T.int64(0), T.int64(0)]
# Squeeze a (256, 1, 1) float32 tensor to shape (256,) by dropping the two unit axes.
@T.prim_func
def squeeze3(
    A: T.Buffer((T.int64(256), T.int64(1), T.int64(1)), "float32"),
    T_squeeze: T.Buffer((T.int64(256),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(256)):
        with T.block("T_squeeze"):
            v_ax0 = T.axis.spatial(T.int64(256), ax0)
            T.reads(A[v_ax0, T.int64(0), T.int64(0)])
            T.writes(T_squeeze[v_ax0])
            T_squeeze[v_ax0] = A[v_ax0, T.int64(0), T.int64(0)]
# Squeeze a (512, 1, 1) float32 tensor to shape (512,) by dropping the two unit axes.
@T.prim_func
def squeeze4(
    A: T.Buffer((T.int64(512), T.int64(1), T.int64(1)), "float32"),
    T_squeeze: T.Buffer((T.int64(512),), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    for ax0 in range(T.int64(512)):
        with T.block("T_squeeze"):
            v_ax0 = T.axis.spatial(T.int64(512), ax0)
            T.reads(A[v_ax0, T.int64(0), T.int64(0)])
            T.writes(T_squeeze[v_ax0])
            T_squeeze[v_ax0] = A[v_ax0, T.int64(0), T.int64(0)]
@R.function
def main(
data: R.Tensor((1, 3, 224, 224), dtype="float32"),
bn_data_gamma: R.Tensor((3,), dtype="float32"),
bn_data_beta: R.Tensor((3,), dtype="float32"),
bn_data_moving_mean: R.Tensor((3,), dtype="float32"),
bn_data_moving_var: R.Tensor((3,), dtype="float32"),
conv0_weight: R.Tensor((64, 3, 7, 7), dtype="float32"),
bn0_gamma: R.Tensor((64,), dtype="float32"),
bn0_beta: R.Tensor((64,), dtype="float32"),
bn0_moving_mean: R.Tensor((64,), dtype="float32"),
bn0_moving_var: R.Tensor((64,), dtype="float32"),
stage1_unit1_bn1_gamma: R.Tensor((64,), dtype="float32"),
stage1_unit1_bn1_beta: R.Tensor((64,), dtype="float32"),
stage1_unit1_bn1_moving_mean: R.Tensor((64,), dtype="float32"),
stage1_unit1_bn1_moving_var: R.Tensor((64,), dtype="float32"),
stage1_unit1_conv1_weight: R.Tensor((64, 64, 1, 1), dtype="float32"),
stage1_unit1_bn2_gamma: R.Tensor((64,), dtype="float32"),
stage1_unit1_bn2_beta: R.Tensor((64,), dtype="float32"),
stage1_unit1_bn2_moving_mean: R.Tensor((64,), dtype="float32"),
stage1_unit1_bn2_moving_var: R.Tensor((64,), dtype="float32"),
stage1_unit1_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
stage1_unit1_bn3_gamma: R.Tensor((64,), dtype="float32"),
stage1_unit1_bn3_beta: R.Tensor((64,), dtype="float32"),
stage1_unit1_bn3_moving_mean: R.Tensor((64,), dtype="float32"),
stage1_unit1_bn3_moving_var: R.Tensor((64,), dtype="float32"),
stage1_unit1_conv3_weight: R.Tensor((256, 64, 1, 1), dtype="float32"),
stage1_unit1_sc_weight: R.Tensor((256, 64, 1, 1), dtype="float32"),
stage1_unit2_bn1_gamma: R.Tensor((256,), dtype="float32"),
stage1_unit2_bn1_beta: R.Tensor((256,), dtype="float32"),
stage1_unit2_bn1_moving_mean: R.Tensor((256,), dtype="float32"),
stage1_unit2_bn1_moving_var: R.Tensor((256,), dtype="float32"),
stage1_unit2_conv1_weight: R.Tensor((64, 256, 1, 1), dtype="float32"),
stage1_unit2_bn2_gamma: R.Tensor((64,), dtype="float32"),
stage1_unit2_bn2_beta: R.Tensor((64,), dtype="float32"),
stage1_unit2_bn2_moving_mean: R.Tensor((64,), dtype="float32"),
stage1_unit2_bn2_moving_var: R.Tensor((64,), dtype="float32"),
stage1_unit2_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
stage1_unit2_bn3_gamma: R.Tensor((64,), dtype="float32"),
stage1_unit2_bn3_beta: R.Tensor((64,), dtype="float32"),
stage1_unit2_bn3_moving_mean: R.Tensor((64,), dtype="float32"),
stage1_unit2_bn3_moving_var: R.Tensor((64,), dtype="float32"),
stage1_unit2_conv3_weight: R.Tensor((256, 64, 1, 1), dtype="float32"),
stage1_unit3_bn1_gamma: R.Tensor((256,), dtype="float32"),
stage1_unit3_bn1_beta: R.Tensor((256,), dtype="float32"),
stage1_unit3_bn1_moving_mean: R.Tensor((256,), dtype="float32"),
stage1_unit3_bn1_moving_var: R.Tensor((256,), dtype="float32"),
stage1_unit3_conv1_weight: R.Tensor((64, 256, 1, 1), dtype="float32"),
stage1_unit3_bn2_gamma: R.Tensor((64,), dtype="float32"),
stage1_unit3_bn2_beta: R.Tensor((64,), dtype="float32"),
stage1_unit3_bn2_moving_mean: R.Tensor((64,), dtype="float32"),
stage1_unit3_bn2_moving_var: R.Tensor((64,), dtype="float32"),
stage1_unit3_conv2_weight: R.Tensor((64, 64, 3, 3), dtype="float32"),
stage1_unit3_bn3_gamma: R.Tensor((64,), dtype="float32"),
stage1_unit3_bn3_beta: R.Tensor((64,), dtype="float32"),
stage1_unit3_bn3_moving_mean: R.Tensor((64,), dtype="float32"),
stage1_unit3_bn3_moving_var: R.Tensor((64,), dtype="float32"),
stage1_unit3_conv3_weight: R.Tensor((256, 64, 1, 1), dtype="float32"),
stage2_unit1_bn1_gamma: R.Tensor((256,), dtype="float32"),
stage2_unit1_bn1_beta: R.Tensor((256,), dtype="float32"),
stage2_unit1_bn1_moving_mean: R.Tensor((256,), dtype="float32"),
stage2_unit1_bn1_moving_var: R.Tensor((256,), dtype="float32"),
stage2_unit1_conv1_weight: R.Tensor((128, 256, 1, 1), dtype="float32"),
stage2_unit1_bn2_gamma: R.Tensor((128,), dtype="float32"),
stage2_unit1_bn2_beta: R.Tensor((128,), dtype="float32"),
stage2_unit1_bn2_moving_mean: R.Tensor((128,), dtype="float32"),
stage2_unit1_bn2_moving_var: R.Tensor((128,), dtype="float32"),
stage2_unit1_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"),
stage2_unit1_bn3_gamma: R.Tensor((128,), dtype="float32"),
stage2_unit1_bn3_beta: R.Tensor((128,), dtype="float32"),
stage2_unit1_bn3_moving_mean: R.Tensor((128,), dtype="float32"),
stage2_unit1_bn3_moving_var: R.Tensor((128,), dtype="float32"),
stage2_unit1_conv3_weight: R.Tensor((512, 128, 1, 1), dtype="float32"),
stage2_unit1_sc_weight: R.Tensor((512, 256, 1, 1), dtype="float32"),
stage2_unit2_bn1_gamma: R.Tensor((512,), dtype="float32"),
stage2_unit2_bn1_beta: R.Tensor((512,), dtype="float32"),
stage2_unit2_bn1_moving_mean: R.Tensor((512,), dtype="float32"),
stage2_unit2_bn1_moving_var: R.Tensor((512,), dtype="float32"),
stage2_unit2_conv1_weight: R.Tensor((128, 512, 1, 1), dtype="float32"),
stage2_unit2_bn2_gamma: R.Tensor((128,), dtype="float32"),
stage2_unit2_bn2_beta: R.Tensor((128,), dtype="float32"),
stage2_unit2_bn2_moving_mean: R.Tensor((128,), dtype="float32"),
stage2_unit2_bn2_moving_var: R.Tensor((128,), dtype="float32"),
stage2_unit2_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"),
stage2_unit2_bn3_gamma: R.Tensor((128,), dtype="float32"),
stage2_unit2_bn3_beta: R.Tensor((128,), dtype="float32"),
stage2_unit2_bn3_moving_mean: R.Tensor((128,), dtype="float32"),
stage2_unit2_bn3_moving_var: R.Tensor((128,), dtype="float32"),
stage2_unit2_conv3_weight: R.Tensor((512, 128, 1, 1), dtype="float32"),
stage2_unit3_bn1_gamma: R.Tensor((512,), dtype="float32"),
stage2_unit3_bn1_beta: R.Tensor((512,), dtype="float32"),
stage2_unit3_bn1_moving_mean: R.Tensor((512,), dtype="float32"),
stage2_unit3_bn1_moving_var: R.Tensor((512,), dtype="float32"),
stage2_unit3_conv1_weight: R.Tensor((128, 512, 1, 1), dtype="float32"),
stage2_unit3_bn2_gamma: R.Tensor((128,), dtype="float32"),
stage2_unit3_bn2_beta: R.Tensor((128,), dtype="float32"),
stage2_unit3_bn2_moving_mean: R.Tensor((128,), dtype="float32"),
stage2_unit3_bn2_moving_var: R.Tensor((128,), dtype="float32"),
stage2_unit3_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"),
stage2_unit3_bn3_gamma: R.Tensor((128,), dtype="float32"),
stage2_unit3_bn3_beta: R.Tensor((128,), dtype="float32"),
stage2_unit3_bn3_moving_mean: R.Tensor((128,), dtype="float32"),
stage2_unit3_bn3_moving_var: R.Tensor((128,), dtype="float32"),
stage2_unit3_conv3_weight: R.Tensor((512, 128, 1, 1), dtype="float32"),
stage2_unit4_bn1_gamma: R.Tensor((512,), dtype="float32"),
stage2_unit4_bn1_beta: R.Tensor((512,), dtype="float32"),
stage2_unit4_bn1_moving_mean: R.Tensor((512,), dtype="float32"),
stage2_unit4_bn1_moving_var: R.Tensor((512,), dtype="float32"),
stage2_unit4_conv1_weight: R.Tensor((128, 512, 1, 1), dtype="float32"),
stage2_unit4_bn2_gamma: R.Tensor((128,), dtype="float32"),
stage2_unit4_bn2_beta: R.Tensor((128,), dtype="float32"),
stage2_unit4_bn2_moving_mean: R.Tensor((128,), dtype="float32"),
stage2_unit4_bn2_moving_var: R.Tensor((128,), dtype="float32"),
stage2_unit4_conv2_weight: R.Tensor((128, 128, 3, 3), dtype="float32"),
stage2_unit4_bn3_gamma: R.Tensor((128,), dtype="float32"),
stage2_unit4_bn3_beta: R.Tensor((128,), dtype="float32"),
stage2_unit4_bn3_moving_mean: R.Tensor((128,), dtype="float32"),
stage2_unit4_bn3_moving_var: R.Tensor((128,), dtype="float32"),
stage2_unit4_conv3_weight: R.Tensor((512, 128, 1, 1), dtype="float32"),
stage3_unit1_bn1_gamma: R.Tensor((512,), dtype="float32"),
stage3_unit1_bn1_beta: R.Tensor((512,), dtype="float32"),
stage3_unit1_bn1_moving_mean: R.Tensor((512,), dtype="float32"),
stage3_unit1_bn1_moving_var: R.Tensor((512,), dtype="float32"),
stage3_unit1_conv1_weight: R.Tensor((256, 512, 1, 1), dtype="float32"),
stage3_unit1_bn2_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit1_bn2_beta: R.Tensor((256,), dtype="float32"),
stage3_unit1_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit1_bn2_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit1_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
stage3_unit1_bn3_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit1_bn3_beta: R.Tensor((256,), dtype="float32"),
stage3_unit1_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit1_bn3_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit1_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
stage3_unit1_sc_weight: R.Tensor((1024, 512, 1, 1), dtype="float32"),
stage3_unit2_bn1_gamma: R.Tensor((1024,), dtype="float32"),
stage3_unit2_bn1_beta: R.Tensor((1024,), dtype="float32"),
stage3_unit2_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
stage3_unit2_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
stage3_unit2_conv1_weight: R.Tensor((256, 1024, 1, 1), dtype="float32"),
stage3_unit2_bn2_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit2_bn2_beta: R.Tensor((256,), dtype="float32"),
stage3_unit2_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit2_bn2_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit2_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
stage3_unit2_bn3_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit2_bn3_beta: R.Tensor((256,), dtype="float32"),
stage3_unit2_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit2_bn3_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit2_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
stage3_unit3_bn1_gamma: R.Tensor((1024,), dtype="float32"),
stage3_unit3_bn1_beta: R.Tensor((1024,), dtype="float32"),
stage3_unit3_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
stage3_unit3_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
stage3_unit3_conv1_weight: R.Tensor((256, 1024, 1, 1), dtype="float32"),
stage3_unit3_bn2_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit3_bn2_beta: R.Tensor((256,), dtype="float32"),
stage3_unit3_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit3_bn2_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit3_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
stage3_unit3_bn3_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit3_bn3_beta: R.Tensor((256,), dtype="float32"),
stage3_unit3_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit3_bn3_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit3_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
stage3_unit4_bn1_gamma: R.Tensor((1024,), dtype="float32"),
stage3_unit4_bn1_beta: R.Tensor((1024,), dtype="float32"),
stage3_unit4_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
stage3_unit4_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
stage3_unit4_conv1_weight: R.Tensor((256, 1024, 1, 1), dtype="float32"),
stage3_unit4_bn2_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit4_bn2_beta: R.Tensor((256,), dtype="float32"),
stage3_unit4_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit4_bn2_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit4_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
stage3_unit4_bn3_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit4_bn3_beta: R.Tensor((256,), dtype="float32"),
stage3_unit4_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit4_bn3_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit4_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
stage3_unit5_bn1_gamma: R.Tensor((1024,), dtype="float32"),
stage3_unit5_bn1_beta: R.Tensor((1024,), dtype="float32"),
stage3_unit5_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
stage3_unit5_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
stage3_unit5_conv1_weight: R.Tensor((256, 1024, 1, 1), dtype="float32"),
stage3_unit5_bn2_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit5_bn2_beta: R.Tensor((256,), dtype="float32"),
stage3_unit5_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit5_bn2_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit5_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
stage3_unit5_bn3_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit5_bn3_beta: R.Tensor((256,), dtype="float32"),
stage3_unit5_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit5_bn3_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit5_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
stage3_unit6_bn1_gamma: R.Tensor((1024,), dtype="float32"),
stage3_unit6_bn1_beta: R.Tensor((1024,), dtype="float32"),
stage3_unit6_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
stage3_unit6_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
stage3_unit6_conv1_weight: R.Tensor((256, 1024, 1, 1), dtype="float32"),
stage3_unit6_bn2_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit6_bn2_beta: R.Tensor((256,), dtype="float32"),
stage3_unit6_bn2_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit6_bn2_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit6_conv2_weight: R.Tensor((256, 256, 3, 3), dtype="float32"),
stage3_unit6_bn3_gamma: R.Tensor((256,), dtype="float32"),
stage3_unit6_bn3_beta: R.Tensor((256,), dtype="float32"),
stage3_unit6_bn3_moving_mean: R.Tensor((256,), dtype="float32"),
stage3_unit6_bn3_moving_var: R.Tensor((256,), dtype="float32"),
stage3_unit6_conv3_weight: R.Tensor((1024, 256, 1, 1), dtype="float32"),
stage4_unit1_bn1_gamma: R.Tensor((1024,), dtype="float32"),
stage4_unit1_bn1_beta: R.Tensor((1024,), dtype="float32"),
stage4_unit1_bn1_moving_mean: R.Tensor((1024,), dtype="float32"),
stage4_unit1_bn1_moving_var: R.Tensor((1024,), dtype="float32"),
stage4_unit1_conv1_weight: R.Tensor((512, 1024, 1, 1), dtype="float32"),
stage4_unit1_bn2_gamma: R.Tensor((512,), dtype="float32"),
stage4_unit1_bn2_beta: R.Tensor((512,), dtype="float32"),
stage4_unit1_bn2_moving_mean: R.Tensor((512,), dtype="float32"),
stage4_unit1_bn2_moving_var: R.Tensor((512,), dtype="float32"),
stage4_unit1_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"),
stage4_unit1_bn3_gamma: R.Tensor((512,), dtype="float32"),
stage4_unit1_bn3_beta: R.Tensor((512,), dtype="float32"),
stage4_unit1_bn3_moving_mean: R.Tensor((512,), dtype="float32"),
stage4_unit1_bn3_moving_var: R.Tensor((512,), dtype="float32"),
stage4_unit1_conv3_weight: R.Tensor((2048, 512, 1, 1), dtype="float32"),
stage4_unit1_sc_weight: R.Tensor((2048, 1024, 1, 1), dtype="float32"),
stage4_unit2_bn1_gamma: R.Tensor((2048,), dtype="float32"),
stage4_unit2_bn1_beta: R.Tensor((2048,), dtype="float32"),
stage4_unit2_bn1_moving_mean: R.Tensor((2048,), dtype="float32"),
stage4_unit2_bn1_moving_var: R.Tensor((2048,), dtype="float32"),
stage4_unit2_conv1_weight: R.Tensor((512, 2048, 1, 1), dtype="float32"),
stage4_unit2_bn2_gamma: R.Tensor((512,), dtype="float32"),
stage4_unit2_bn2_beta: R.Tensor((512,), dtype="float32"),
stage4_unit2_bn2_moving_mean: R.Tensor((512,), dtype="float32"),
stage4_unit2_bn2_moving_var: R.Tensor((512,), dtype="float32"),
stage4_unit2_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"),
stage4_unit2_bn3_gamma: R.Tensor((512,), dtype="float32"),
stage4_unit2_bn3_beta: R.Tensor((512,), dtype="float32"),
stage4_unit2_bn3_moving_mean: R.Tensor((512,), dtype="float32"),
stage4_unit2_bn3_moving_var: R.Tensor((512,), dtype="float32"),
stage4_unit2_conv3_weight: R.Tensor((2048, 512, 1, 1), dtype="float32"),
stage4_unit3_bn1_gamma: R.Tensor((2048,), dtype="float32"),
stage4_unit3_bn1_beta: R.Tensor((2048,), dtype="float32"),
stage4_unit3_bn1_moving_mean: R.Tensor((2048,), dtype="float32"),
stage4_unit3_bn1_moving_var: R.Tensor((2048,), dtype="float32"),
stage4_unit3_conv1_weight: R.Tensor((512, 2048, 1, 1), dtype="float32"),
stage4_unit3_bn2_gamma: R.Tensor((512,), dtype="float32"),
stage4_unit3_bn2_beta: R.Tensor((512,), dtype="float32"),
stage4_unit3_bn2_moving_mean: R.Tensor((512,), dtype="float32"),
stage4_unit3_bn2_moving_var: R.Tensor((512,), dtype="float32"),
stage4_unit3_conv2_weight: R.Tensor((512, 512, 3, 3), dtype="float32"),
stage4_unit3_bn3_gamma: R.Tensor((512,), dtype="float32"),
stage4_unit3_bn3_beta: R.Tensor((512,), dtype="float32"),
stage4_unit3_bn3_moving_mean: R.Tensor((512,), dtype="float32"),
stage4_unit3_bn3_moving_var: R.Tensor((512,), dtype="float32"),
stage4_unit3_conv3_weight: R.Tensor((2048, 512, 1, 1), dtype="float32"),
bn1_gamma: R.Tensor((2048,), dtype="float32"),
bn1_beta: R.Tensor((2048,), dtype="float32"),
bn1_moving_mean: R.Tensor((2048,), dtype="float32"),
bn1_moving_var: R.Tensor((2048,), dtype="float32"),
fc1_weight: R.Tensor((1000, 2048), dtype="float32"),
fc1_bias: R.Tensor((1000,), dtype="float32"),
) -> R.Tensor((1, 1000), dtype="float32"):
cls = Module
with R.dataflow():
lv = R.call_tir(
cls.negative,
(bn_data_moving_mean,),
out_sinfo=R.Tensor((3,), dtype="float32"),
)
lv1 = R.call_tir(
cls.add,
(bn_data_moving_var, R.const(1.9999999494757503e-05, "float32")),
out_sinfo=R.Tensor((3,), dtype="float32"),
)
lv2 = R.call_tir(
cls.rsqrt, (lv1,), out_sinfo=R.Tensor((3,), dtype="float32")
)
lv3 = R.call_tir(
cls.multiply, (lv, lv2), out_sinfo=R.Tensor((3,), dtype="float32")
)
lv4 = R.call_tir(
cls.add1, (lv3, bn_data_beta), out_sinfo=R.Tensor((3,), dtype="float32")
)
lv5 = R.call_tir(
cls.expand_dims, (lv4,), out_sinfo=R.Tensor((3, 1, 1), dtype="float32")
)
lv6 = R.call_tir(
cls.expand_dims, (lv2,), out_sinfo=R.Tensor((3, 1, 1), dtype="float32")
)
lv7 = R.call_tir(
cls.squeeze, (lv6,), out_sinfo=R.Tensor((3,), dtype="float32")
)
lv8 = R.call_tir(
cls.expand_dims, (lv7,), out_sinfo=R.Tensor((3, 1, 1), dtype="float32")
)
lv9 = R.call_tir(
cls.divide, (lv5, lv8), out_sinfo=R.Tensor((3, 1, 1), dtype="float32")
)
lv10 = R.call_tir(
cls.add2,
(data, lv9),
out_sinfo=R.Tensor((1, 3, 224, 224), dtype="float32"),
)
lv11 = R.call_tir(
cls.layout_transform,
(lv10,),
out_sinfo=R.Tensor((1, 1, 224, 224, 3), dtype="float32"),
)
lv12 = R.call_tir(
cls.add3,
(bn0_moving_var, R.const(1.9999999494757503e-05, "float32")),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv13 = R.call_tir(
cls.rsqrt1, (lv12,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv14 = R.call_tir(
cls.multiply1,
(lv13, bn0_gamma),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv15 = R.call_tir(
cls.expand_dims1,
(lv14,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv16 = R.call_tir(
cls.squeeze1, (lv15,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv17 = R.call_tir(
cls.expand_dims2,
(lv16,),
out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
)
lv18 = R.call_tir(
cls.multiply2,
(conv0_weight, lv17),
out_sinfo=R.Tensor((64, 3, 7, 7), dtype="float32"),
)
lv19 = R.call_tir(
cls.expand_dims, (lv7,), out_sinfo=R.Tensor((3, 1, 1), dtype="float32")
)
lv20 = R.call_tir(
cls.multiply3,
(lv18, lv19),
out_sinfo=R.Tensor((64, 3, 7, 7), dtype="float32"),
)
lv21 = R.call_tir(
cls.layout_transform1,
(lv20,),
out_sinfo=R.Tensor((16, 1, 7, 7, 3, 4), dtype="float32"),
)
lv22 = R.call_tir(
cls.contrib_conv2d_NCHWc,
(lv11, lv21),
out_sinfo=R.Tensor((1, 16, 112, 112, 4), dtype="float32"),
)
lv23 = R.call_tir(
cls.negative1,
(bn0_moving_mean,),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv24 = R.call_tir(
cls.multiply1, (lv23, lv14), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv25 = R.call_tir(
cls.add4, (lv24, bn0_beta), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv26 = R.call_tir(
cls.expand_dims1,
(lv25,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv27 = R.call_tir(
cls.expand_dims3,
(lv26,),
out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
)
lv28 = R.call_tir(
cls.layout_transform2,
(lv27,),
out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
)
lv29 = R.call_tir(
cls.add5,
(lv22, lv28),
out_sinfo=R.Tensor((1, 16, 112, 112, 4), dtype="float32"),
)
lv30 = R.call_tir(
cls.relu,
(lv29,),
out_sinfo=R.Tensor((1, 16, 112, 112, 4), dtype="float32"),
)
lv31 = R.call_tir(
cls.max_pool2d,
(lv30,),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv32 = R.call_tir(
cls.add3,
(
stage1_unit1_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv33 = R.call_tir(
cls.rsqrt1, (lv32,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv34 = R.call_tir(
cls.multiply1,
(lv33, stage1_unit1_bn1_gamma),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv35 = R.call_tir(
cls.expand_dims1,
(lv34,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv36 = R.call_tir(
cls.expand_dims3,
(lv35,),
out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
)
lv37 = R.call_tir(
cls.layout_transform2,
(lv36,),
out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
)
lv38 = R.call_tir(
cls.multiply4,
(lv31, lv37),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv39 = R.call_tir(
cls.negative1,
(stage1_unit1_bn1_moving_mean,),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv40 = R.call_tir(
cls.multiply1, (lv39, lv34), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv41 = R.call_tir(
cls.add4,
(lv40, stage1_unit1_bn1_beta),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv42 = R.call_tir(
cls.expand_dims1,
(lv41,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv43 = R.call_tir(
cls.expand_dims3,
(lv42,),
out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
)
lv44 = R.call_tir(
cls.layout_transform2,
(lv43,),
out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
)
lv45 = R.call_tir(
cls.add6,
(lv38, lv44),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv46 = R.call_tir(
cls.relu1,
(lv45,),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv47 = R.call_tir(
cls.add3,
(
stage1_unit1_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv48 = R.call_tir(
cls.rsqrt1, (lv47,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv49 = R.call_tir(
cls.multiply1,
(lv48, stage1_unit1_bn2_gamma),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv50 = R.call_tir(
cls.expand_dims1,
(lv49,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv51 = R.call_tir(
cls.squeeze1, (lv50,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv52 = R.call_tir(
cls.expand_dims2,
(lv51,),
out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
)
lv53 = R.call_tir(
cls.multiply5,
(stage1_unit1_conv1_weight, lv52),
out_sinfo=R.Tensor((64, 64, 1, 1), dtype="float32"),
)
lv54 = R.call_tir(
cls.layout_transform3,
(lv53,),
out_sinfo=R.Tensor((16, 16, 1, 1, 4, 4), dtype="float32"),
)
lv55 = R.call_tir(
cls.contrib_conv2d_NCHWc1,
(lv46, lv54),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv56 = R.call_tir(
cls.negative1,
(stage1_unit1_bn2_moving_mean,),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv57 = R.call_tir(
cls.multiply1, (lv56, lv49), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv58 = R.call_tir(
cls.add4,
(lv57, stage1_unit1_bn2_beta),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv59 = R.call_tir(
cls.expand_dims1,
(lv58,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv60 = R.call_tir(
cls.expand_dims3,
(lv59,),
out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
)
lv61 = R.call_tir(
cls.layout_transform2,
(lv60,),
out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
)
lv62 = R.call_tir(
cls.add6,
(lv55, lv61),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv63 = R.call_tir(
cls.relu1,
(lv62,),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv64 = R.call_tir(
cls.add3,
(
stage1_unit1_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv65 = R.call_tir(
cls.rsqrt1, (lv64,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv66 = R.call_tir(
cls.multiply1,
(lv65, stage1_unit1_bn3_gamma),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv67 = R.call_tir(
cls.expand_dims1,
(lv66,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv68 = R.call_tir(
cls.squeeze1, (lv67,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv69 = R.call_tir(
cls.expand_dims2,
(lv68,),
out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
)
lv70 = R.call_tir(
cls.multiply6,
(stage1_unit1_conv2_weight, lv69),
out_sinfo=R.Tensor((64, 64, 3, 3), dtype="float32"),
)
lv71 = R.call_tir(
cls.layout_transform4,
(lv70,),
out_sinfo=R.Tensor((16, 16, 3, 3, 4, 4), dtype="float32"),
)
lv72 = R.call_tir(
cls.contrib_conv2d_NCHWc2,
(lv63, lv71),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv73 = R.call_tir(
cls.negative1,
(stage1_unit1_bn3_moving_mean,),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv74 = R.call_tir(
cls.multiply1, (lv73, lv66), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv75 = R.call_tir(
cls.add4,
(lv74, stage1_unit1_bn3_beta),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv76 = R.call_tir(
cls.expand_dims1,
(lv75,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv77 = R.call_tir(
cls.expand_dims3,
(lv76,),
out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
)
lv78 = R.call_tir(
cls.layout_transform2,
(lv77,),
out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
)
lv79 = R.call_tir(
cls.add6,
(lv72, lv78),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv80 = R.call_tir(
cls.relu1,
(lv79,),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv81 = R.call_tir(
cls.layout_transform5,
(stage1_unit1_conv3_weight,),
out_sinfo=R.Tensor((64, 16, 1, 1, 4, 4), dtype="float32"),
)
lv82 = R.call_tir(
cls.contrib_conv2d_NCHWc3,
(lv80, lv81),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv83 = R.call_tir(
cls.layout_transform5,
(stage1_unit1_sc_weight,),
out_sinfo=R.Tensor((64, 16, 1, 1, 4, 4), dtype="float32"),
)
lv84 = R.call_tir(
cls.contrib_conv2d_NCHWc3,
(lv46, lv83),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv85 = R.call_tir(
cls.add7,
(lv82, lv84),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv86 = R.call_tir(
cls.add8,
(
stage1_unit2_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv87 = R.call_tir(
cls.rsqrt2, (lv86,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv88 = R.call_tir(
cls.multiply7,
(lv87, stage1_unit2_bn1_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv89 = R.call_tir(
cls.expand_dims4,
(lv88,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv90 = R.call_tir(
cls.expand_dims5,
(lv89,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv91 = R.call_tir(
cls.layout_transform6,
(lv90,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv92 = R.call_tir(
cls.multiply8,
(lv85, lv91),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv93 = R.call_tir(
cls.negative2,
(stage1_unit2_bn1_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv94 = R.call_tir(
cls.multiply7, (lv93, lv88), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv95 = R.call_tir(
cls.add9,
(lv94, stage1_unit2_bn1_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv96 = R.call_tir(
cls.expand_dims4,
(lv95,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv97 = R.call_tir(
cls.expand_dims5,
(lv96,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv98 = R.call_tir(
cls.layout_transform6,
(lv97,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv99 = R.call_tir(
cls.add10,
(lv92, lv98),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv100 = R.call_tir(
cls.relu2,
(lv99,),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv101 = R.call_tir(
cls.add3,
(
stage1_unit2_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv102 = R.call_tir(
cls.rsqrt1, (lv101,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv103 = R.call_tir(
cls.multiply1,
(lv102, stage1_unit2_bn2_gamma),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv104 = R.call_tir(
cls.expand_dims1,
(lv103,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv105 = R.call_tir(
cls.squeeze1, (lv104,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv106 = R.call_tir(
cls.expand_dims2,
(lv105,),
out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
)
lv107 = R.call_tir(
cls.multiply9,
(stage1_unit2_conv1_weight, lv106),
out_sinfo=R.Tensor((64, 256, 1, 1), dtype="float32"),
)
lv108 = R.call_tir(
cls.layout_transform7,
(lv107,),
out_sinfo=R.Tensor((16, 64, 1, 1, 4, 4), dtype="float32"),
)
lv109 = R.call_tir(
cls.contrib_conv2d_NCHWc4,
(lv100, lv108),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv110 = R.call_tir(
cls.negative1,
(stage1_unit2_bn2_moving_mean,),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv111 = R.call_tir(
cls.multiply1,
(lv110, lv103),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv112 = R.call_tir(
cls.add4,
(lv111, stage1_unit2_bn2_beta),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv113 = R.call_tir(
cls.expand_dims1,
(lv112,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv114 = R.call_tir(
cls.expand_dims3,
(lv113,),
out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
)
lv115 = R.call_tir(
cls.layout_transform2,
(lv114,),
out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
)
lv116 = R.call_tir(
cls.add6,
(lv109, lv115),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv117 = R.call_tir(
cls.relu1,
(lv116,),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv118 = R.call_tir(
cls.add3,
(
stage1_unit2_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv119 = R.call_tir(
cls.rsqrt1, (lv118,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv120 = R.call_tir(
cls.multiply1,
(lv119, stage1_unit2_bn3_gamma),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv121 = R.call_tir(
cls.expand_dims1,
(lv120,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv122 = R.call_tir(
cls.squeeze1, (lv121,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv123 = R.call_tir(
cls.expand_dims2,
(lv122,),
out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
)
lv124 = R.call_tir(
cls.multiply6,
(stage1_unit2_conv2_weight, lv123),
out_sinfo=R.Tensor((64, 64, 3, 3), dtype="float32"),
)
lv125 = R.call_tir(
cls.layout_transform4,
(lv124,),
out_sinfo=R.Tensor((16, 16, 3, 3, 4, 4), dtype="float32"),
)
lv126 = R.call_tir(
cls.contrib_conv2d_NCHWc2,
(lv117, lv125),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv127 = R.call_tir(
cls.negative1,
(stage1_unit2_bn3_moving_mean,),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv128 = R.call_tir(
cls.multiply1,
(lv127, lv120),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv129 = R.call_tir(
cls.add4,
(lv128, stage1_unit2_bn3_beta),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv130 = R.call_tir(
cls.expand_dims1,
(lv129,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv131 = R.call_tir(
cls.expand_dims3,
(lv130,),
out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
)
lv132 = R.call_tir(
cls.layout_transform2,
(lv131,),
out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
)
lv133 = R.call_tir(
cls.add6,
(lv126, lv132),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv134 = R.call_tir(
cls.relu1,
(lv133,),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv135 = R.call_tir(
cls.layout_transform5,
(stage1_unit2_conv3_weight,),
out_sinfo=R.Tensor((64, 16, 1, 1, 4, 4), dtype="float32"),
)
lv136 = R.call_tir(
cls.contrib_conv2d_NCHWc3,
(lv134, lv135),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv137 = R.call_tir(
cls.add7,
(lv136, lv85),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv138 = R.call_tir(
cls.add8,
(
stage1_unit3_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv139 = R.call_tir(
cls.rsqrt2, (lv138,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv140 = R.call_tir(
cls.multiply7,
(lv139, stage1_unit3_bn1_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv141 = R.call_tir(
cls.expand_dims4,
(lv140,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv142 = R.call_tir(
cls.expand_dims5,
(lv141,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv143 = R.call_tir(
cls.layout_transform6,
(lv142,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv144 = R.call_tir(
cls.multiply8,
(lv137, lv143),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv145 = R.call_tir(
cls.negative2,
(stage1_unit3_bn1_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv146 = R.call_tir(
cls.multiply7,
(lv145, lv140),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv147 = R.call_tir(
cls.add9,
(lv146, stage1_unit3_bn1_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv148 = R.call_tir(
cls.expand_dims4,
(lv147,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv149 = R.call_tir(
cls.expand_dims5,
(lv148,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv150 = R.call_tir(
cls.layout_transform6,
(lv149,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv151 = R.call_tir(
cls.add10,
(lv144, lv150),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv152 = R.call_tir(
cls.relu2,
(lv151,),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv153 = R.call_tir(
cls.add3,
(
stage1_unit3_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv154 = R.call_tir(
cls.rsqrt1, (lv153,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv155 = R.call_tir(
cls.multiply1,
(lv154, stage1_unit3_bn2_gamma),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv156 = R.call_tir(
cls.expand_dims1,
(lv155,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv157 = R.call_tir(
cls.squeeze1, (lv156,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv158 = R.call_tir(
cls.expand_dims2,
(lv157,),
out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
)
lv159 = R.call_tir(
cls.multiply9,
(stage1_unit3_conv1_weight, lv158),
out_sinfo=R.Tensor((64, 256, 1, 1), dtype="float32"),
)
lv160 = R.call_tir(
cls.layout_transform7,
(lv159,),
out_sinfo=R.Tensor((16, 64, 1, 1, 4, 4), dtype="float32"),
)
lv161 = R.call_tir(
cls.contrib_conv2d_NCHWc4,
(lv152, lv160),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv162 = R.call_tir(
cls.negative1,
(stage1_unit3_bn2_moving_mean,),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv163 = R.call_tir(
cls.multiply1,
(lv162, lv155),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv164 = R.call_tir(
cls.add4,
(lv163, stage1_unit3_bn2_beta),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv165 = R.call_tir(
cls.expand_dims1,
(lv164,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv166 = R.call_tir(
cls.expand_dims3,
(lv165,),
out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
)
lv167 = R.call_tir(
cls.layout_transform2,
(lv166,),
out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
)
lv168 = R.call_tir(
cls.add6,
(lv161, lv167),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv169 = R.call_tir(
cls.relu1,
(lv168,),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv170 = R.call_tir(
cls.add3,
(
stage1_unit3_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv171 = R.call_tir(
cls.rsqrt1, (lv170,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv172 = R.call_tir(
cls.multiply1,
(lv171, stage1_unit3_bn3_gamma),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv173 = R.call_tir(
cls.expand_dims1,
(lv172,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv174 = R.call_tir(
cls.squeeze1, (lv173,), out_sinfo=R.Tensor((64,), dtype="float32")
)
lv175 = R.call_tir(
cls.expand_dims2,
(lv174,),
out_sinfo=R.Tensor((64, 1, 1, 1), dtype="float32"),
)
lv176 = R.call_tir(
cls.multiply6,
(stage1_unit3_conv2_weight, lv175),
out_sinfo=R.Tensor((64, 64, 3, 3), dtype="float32"),
)
lv177 = R.call_tir(
cls.layout_transform4,
(lv176,),
out_sinfo=R.Tensor((16, 16, 3, 3, 4, 4), dtype="float32"),
)
lv178 = R.call_tir(
cls.contrib_conv2d_NCHWc2,
(lv169, lv177),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv179 = R.call_tir(
cls.negative1,
(stage1_unit3_bn3_moving_mean,),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv180 = R.call_tir(
cls.multiply1,
(lv179, lv172),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv181 = R.call_tir(
cls.add4,
(lv180, stage1_unit3_bn3_beta),
out_sinfo=R.Tensor((64,), dtype="float32"),
)
lv182 = R.call_tir(
cls.expand_dims1,
(lv181,),
out_sinfo=R.Tensor((64, 1, 1), dtype="float32"),
)
lv183 = R.call_tir(
cls.expand_dims3,
(lv182,),
out_sinfo=R.Tensor((1, 64, 1, 1), dtype="float32"),
)
lv184 = R.call_tir(
cls.layout_transform2,
(lv183,),
out_sinfo=R.Tensor((1, 16, 1, 1, 4), dtype="float32"),
)
lv185 = R.call_tir(
cls.add6,
(lv178, lv184),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv186 = R.call_tir(
cls.relu1,
(lv185,),
out_sinfo=R.Tensor((1, 16, 56, 56, 4), dtype="float32"),
)
lv187 = R.call_tir(
cls.layout_transform5,
(stage1_unit3_conv3_weight,),
out_sinfo=R.Tensor((64, 16, 1, 1, 4, 4), dtype="float32"),
)
lv188 = R.call_tir(
cls.contrib_conv2d_NCHWc3,
(lv186, lv187),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv189 = R.call_tir(
cls.add7,
(lv188, lv137),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv190 = R.call_tir(
cls.add8,
(
stage2_unit1_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv191 = R.call_tir(
cls.rsqrt2, (lv190,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv192 = R.call_tir(
cls.multiply7,
(lv191, stage2_unit1_bn1_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv193 = R.call_tir(
cls.expand_dims4,
(lv192,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv194 = R.call_tir(
cls.expand_dims5,
(lv193,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv195 = R.call_tir(
cls.layout_transform6,
(lv194,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv196 = R.call_tir(
cls.multiply8,
(lv189, lv195),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv197 = R.call_tir(
cls.negative2,
(stage2_unit1_bn1_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv198 = R.call_tir(
cls.multiply7,
(lv197, lv192),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv199 = R.call_tir(
cls.add9,
(lv198, stage2_unit1_bn1_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv200 = R.call_tir(
cls.expand_dims4,
(lv199,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv201 = R.call_tir(
cls.expand_dims5,
(lv200,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv202 = R.call_tir(
cls.layout_transform6,
(lv201,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv203 = R.call_tir(
cls.add10,
(lv196, lv202),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv204 = R.call_tir(
cls.relu2,
(lv203,),
out_sinfo=R.Tensor((1, 64, 56, 56, 4), dtype="float32"),
)
lv205 = R.call_tir(
cls.add11,
(
stage2_unit1_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv206 = R.call_tir(
cls.rsqrt3, (lv205,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv207 = R.call_tir(
cls.multiply10,
(lv206, stage2_unit1_bn2_gamma),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv208 = R.call_tir(
cls.expand_dims6,
(lv207,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv209 = R.call_tir(
cls.squeeze2, (lv208,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv210 = R.call_tir(
cls.expand_dims7,
(lv209,),
out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
)
lv211 = R.call_tir(
cls.multiply11,
(stage2_unit1_conv1_weight, lv210),
out_sinfo=R.Tensor((128, 256, 1, 1), dtype="float32"),
)
lv212 = R.call_tir(
cls.layout_transform8,
(lv211,),
out_sinfo=R.Tensor((32, 64, 1, 1, 4, 4), dtype="float32"),
)
lv213 = R.call_tir(
cls.contrib_conv2d_NCHWc5,
(lv204, lv212),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv214 = R.call_tir(
cls.negative3,
(stage2_unit1_bn2_moving_mean,),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv215 = R.call_tir(
cls.multiply10,
(lv214, lv207),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv216 = R.call_tir(
cls.add12,
(lv215, stage2_unit1_bn2_beta),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv217 = R.call_tir(
cls.expand_dims6,
(lv216,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv218 = R.call_tir(
cls.expand_dims8,
(lv217,),
out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
)
lv219 = R.call_tir(
cls.layout_transform9,
(lv218,),
out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
)
lv220 = R.call_tir(
cls.add13,
(lv213, lv219),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv221 = R.call_tir(
cls.relu3,
(lv220,),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv222 = R.call_tir(
cls.add11,
(
stage2_unit1_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv223 = R.call_tir(
cls.rsqrt3, (lv222,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv224 = R.call_tir(
cls.multiply10,
(lv223, stage2_unit1_bn3_gamma),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv225 = R.call_tir(
cls.expand_dims6,
(lv224,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv226 = R.call_tir(
cls.squeeze2, (lv225,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv227 = R.call_tir(
cls.expand_dims7,
(lv226,),
out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
)
lv228 = R.call_tir(
cls.multiply12,
(stage2_unit1_conv2_weight, lv227),
out_sinfo=R.Tensor((128, 128, 3, 3), dtype="float32"),
)
lv229 = R.call_tir(
cls.layout_transform10,
(lv228,),
out_sinfo=R.Tensor((32, 32, 3, 3, 4, 4), dtype="float32"),
)
lv230 = R.call_tir(
cls.contrib_conv2d_NCHWc6,
(lv221, lv229),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv231 = R.call_tir(
cls.negative3,
(stage2_unit1_bn3_moving_mean,),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv232 = R.call_tir(
cls.multiply10,
(lv231, lv224),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv233 = R.call_tir(
cls.add12,
(lv232, stage2_unit1_bn3_beta),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv234 = R.call_tir(
cls.expand_dims6,
(lv233,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv235 = R.call_tir(
cls.expand_dims8,
(lv234,),
out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
)
lv236 = R.call_tir(
cls.layout_transform9,
(lv235,),
out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
)
lv237 = R.call_tir(
cls.add13,
(lv230, lv236),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv238 = R.call_tir(
cls.relu3,
(lv237,),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv239 = R.call_tir(
cls.layout_transform11,
(stage2_unit1_conv3_weight,),
out_sinfo=R.Tensor((128, 32, 1, 1, 4, 4), dtype="float32"),
)
lv240 = R.call_tir(
cls.contrib_conv2d_NCHWc7,
(lv238, lv239),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv241 = R.call_tir(
cls.layout_transform12,
(stage2_unit1_sc_weight,),
out_sinfo=R.Tensor((128, 64, 1, 1, 4, 4), dtype="float32"),
)
lv242 = R.call_tir(
cls.contrib_conv2d_NCHWc8,
(lv204, lv241),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv243 = R.call_tir(
cls.add14,
(lv240, lv242),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv244 = R.call_tir(
cls.add15,
(
stage2_unit2_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv245 = R.call_tir(
cls.rsqrt4, (lv244,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv246 = R.call_tir(
cls.multiply13,
(lv245, stage2_unit2_bn1_gamma),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv247 = R.call_tir(
cls.expand_dims9,
(lv246,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv248 = R.call_tir(
cls.expand_dims10,
(lv247,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv249 = R.call_tir(
cls.layout_transform13,
(lv248,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv250 = R.call_tir(
cls.multiply14,
(lv243, lv249),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv251 = R.call_tir(
cls.negative4,
(stage2_unit2_bn1_moving_mean,),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv252 = R.call_tir(
cls.multiply13,
(lv251, lv246),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv253 = R.call_tir(
cls.add16,
(lv252, stage2_unit2_bn1_beta),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv254 = R.call_tir(
cls.expand_dims9,
(lv253,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv255 = R.call_tir(
cls.expand_dims10,
(lv254,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv256 = R.call_tir(
cls.layout_transform13,
(lv255,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv257 = R.call_tir(
cls.add17,
(lv250, lv256),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv258 = R.call_tir(
cls.relu4,
(lv257,),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv259 = R.call_tir(
cls.add11,
(
stage2_unit2_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv260 = R.call_tir(
cls.rsqrt3, (lv259,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv261 = R.call_tir(
cls.multiply10,
(lv260, stage2_unit2_bn2_gamma),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv262 = R.call_tir(
cls.expand_dims6,
(lv261,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv263 = R.call_tir(
cls.squeeze2, (lv262,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv264 = R.call_tir(
cls.expand_dims7,
(lv263,),
out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
)
lv265 = R.call_tir(
cls.multiply15,
(stage2_unit2_conv1_weight, lv264),
out_sinfo=R.Tensor((128, 512, 1, 1), dtype="float32"),
)
lv266 = R.call_tir(
cls.layout_transform14,
(lv265,),
out_sinfo=R.Tensor((32, 128, 1, 1, 4, 4), dtype="float32"),
)
lv267 = R.call_tir(
cls.contrib_conv2d_NCHWc9,
(lv258, lv266),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv268 = R.call_tir(
cls.negative3,
(stage2_unit2_bn2_moving_mean,),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv269 = R.call_tir(
cls.multiply10,
(lv268, lv261),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv270 = R.call_tir(
cls.add12,
(lv269, stage2_unit2_bn2_beta),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv271 = R.call_tir(
cls.expand_dims6,
(lv270,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv272 = R.call_tir(
cls.expand_dims8,
(lv271,),
out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
)
lv273 = R.call_tir(
cls.layout_transform9,
(lv272,),
out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
)
lv274 = R.call_tir(
cls.add13,
(lv267, lv273),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv275 = R.call_tir(
cls.relu3,
(lv274,),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv276 = R.call_tir(
cls.add11,
(
stage2_unit2_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv277 = R.call_tir(
cls.rsqrt3, (lv276,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv278 = R.call_tir(
cls.multiply10,
(lv277, stage2_unit2_bn3_gamma),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv279 = R.call_tir(
cls.expand_dims6,
(lv278,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv280 = R.call_tir(
cls.squeeze2, (lv279,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv281 = R.call_tir(
cls.expand_dims7,
(lv280,),
out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
)
lv282 = R.call_tir(
cls.multiply12,
(stage2_unit2_conv2_weight, lv281),
out_sinfo=R.Tensor((128, 128, 3, 3), dtype="float32"),
)
lv283 = R.call_tir(
cls.layout_transform10,
(lv282,),
out_sinfo=R.Tensor((32, 32, 3, 3, 4, 4), dtype="float32"),
)
lv284 = R.call_tir(
cls.contrib_conv2d_NCHWc6,
(lv275, lv283),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv285 = R.call_tir(
cls.negative3,
(stage2_unit2_bn3_moving_mean,),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv286 = R.call_tir(
cls.multiply10,
(lv285, lv278),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv287 = R.call_tir(
cls.add12,
(lv286, stage2_unit2_bn3_beta),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv288 = R.call_tir(
cls.expand_dims6,
(lv287,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv289 = R.call_tir(
cls.expand_dims8,
(lv288,),
out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
)
lv290 = R.call_tir(
cls.layout_transform9,
(lv289,),
out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
)
lv291 = R.call_tir(
cls.add13,
(lv284, lv290),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv292 = R.call_tir(
cls.relu3,
(lv291,),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv293 = R.call_tir(
cls.layout_transform11,
(stage2_unit2_conv3_weight,),
out_sinfo=R.Tensor((128, 32, 1, 1, 4, 4), dtype="float32"),
)
lv294 = R.call_tir(
cls.contrib_conv2d_NCHWc7,
(lv292, lv293),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv295 = R.call_tir(
cls.add14,
(lv294, lv243),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv296 = R.call_tir(
cls.add15,
(
stage2_unit3_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv297 = R.call_tir(
cls.rsqrt4, (lv296,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv298 = R.call_tir(
cls.multiply13,
(lv297, stage2_unit3_bn1_gamma),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv299 = R.call_tir(
cls.expand_dims9,
(lv298,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv300 = R.call_tir(
cls.expand_dims10,
(lv299,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv301 = R.call_tir(
cls.layout_transform13,
(lv300,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv302 = R.call_tir(
cls.multiply14,
(lv295, lv301),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv303 = R.call_tir(
cls.negative4,
(stage2_unit3_bn1_moving_mean,),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv304 = R.call_tir(
cls.multiply13,
(lv303, lv298),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv305 = R.call_tir(
cls.add16,
(lv304, stage2_unit3_bn1_beta),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv306 = R.call_tir(
cls.expand_dims9,
(lv305,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv307 = R.call_tir(
cls.expand_dims10,
(lv306,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv308 = R.call_tir(
cls.layout_transform13,
(lv307,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv309 = R.call_tir(
cls.add17,
(lv302, lv308),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv310 = R.call_tir(
cls.relu4,
(lv309,),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv311 = R.call_tir(
cls.add11,
(
stage2_unit3_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv312 = R.call_tir(
cls.rsqrt3, (lv311,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv313 = R.call_tir(
cls.multiply10,
(lv312, stage2_unit3_bn2_gamma),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv314 = R.call_tir(
cls.expand_dims6,
(lv313,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv315 = R.call_tir(
cls.squeeze2, (lv314,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv316 = R.call_tir(
cls.expand_dims7,
(lv315,),
out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
)
lv317 = R.call_tir(
cls.multiply15,
(stage2_unit3_conv1_weight, lv316),
out_sinfo=R.Tensor((128, 512, 1, 1), dtype="float32"),
)
lv318 = R.call_tir(
cls.layout_transform14,
(lv317,),
out_sinfo=R.Tensor((32, 128, 1, 1, 4, 4), dtype="float32"),
)
lv319 = R.call_tir(
cls.contrib_conv2d_NCHWc9,
(lv310, lv318),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv320 = R.call_tir(
cls.negative3,
(stage2_unit3_bn2_moving_mean,),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv321 = R.call_tir(
cls.multiply10,
(lv320, lv313),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv322 = R.call_tir(
cls.add12,
(lv321, stage2_unit3_bn2_beta),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv323 = R.call_tir(
cls.expand_dims6,
(lv322,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv324 = R.call_tir(
cls.expand_dims8,
(lv323,),
out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
)
lv325 = R.call_tir(
cls.layout_transform9,
(lv324,),
out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
)
lv326 = R.call_tir(
cls.add13,
(lv319, lv325),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv327 = R.call_tir(
cls.relu3,
(lv326,),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv328 = R.call_tir(
cls.add11,
(
stage2_unit3_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv329 = R.call_tir(
cls.rsqrt3, (lv328,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv330 = R.call_tir(
cls.multiply10,
(lv329, stage2_unit3_bn3_gamma),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv331 = R.call_tir(
cls.expand_dims6,
(lv330,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv332 = R.call_tir(
cls.squeeze2, (lv331,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv333 = R.call_tir(
cls.expand_dims7,
(lv332,),
out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
)
lv334 = R.call_tir(
cls.multiply12,
(stage2_unit3_conv2_weight, lv333),
out_sinfo=R.Tensor((128, 128, 3, 3), dtype="float32"),
)
lv335 = R.call_tir(
cls.layout_transform10,
(lv334,),
out_sinfo=R.Tensor((32, 32, 3, 3, 4, 4), dtype="float32"),
)
lv336 = R.call_tir(
cls.contrib_conv2d_NCHWc6,
(lv327, lv335),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv337 = R.call_tir(
cls.negative3,
(stage2_unit3_bn3_moving_mean,),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv338 = R.call_tir(
cls.multiply10,
(lv337, lv330),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv339 = R.call_tir(
cls.add12,
(lv338, stage2_unit3_bn3_beta),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv340 = R.call_tir(
cls.expand_dims6,
(lv339,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv341 = R.call_tir(
cls.expand_dims8,
(lv340,),
out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
)
lv342 = R.call_tir(
cls.layout_transform9,
(lv341,),
out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
)
lv343 = R.call_tir(
cls.add13,
(lv336, lv342),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv344 = R.call_tir(
cls.relu3,
(lv343,),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv345 = R.call_tir(
cls.layout_transform11,
(stage2_unit3_conv3_weight,),
out_sinfo=R.Tensor((128, 32, 1, 1, 4, 4), dtype="float32"),
)
lv346 = R.call_tir(
cls.contrib_conv2d_NCHWc7,
(lv344, lv345),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv347 = R.call_tir(
cls.add14,
(lv346, lv295),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv348 = R.call_tir(
cls.add15,
(
stage2_unit4_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv349 = R.call_tir(
cls.rsqrt4, (lv348,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv350 = R.call_tir(
cls.multiply13,
(lv349, stage2_unit4_bn1_gamma),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv351 = R.call_tir(
cls.expand_dims9,
(lv350,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv352 = R.call_tir(
cls.expand_dims10,
(lv351,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv353 = R.call_tir(
cls.layout_transform13,
(lv352,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv354 = R.call_tir(
cls.multiply14,
(lv347, lv353),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv355 = R.call_tir(
cls.negative4,
(stage2_unit4_bn1_moving_mean,),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv356 = R.call_tir(
cls.multiply13,
(lv355, lv350),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv357 = R.call_tir(
cls.add16,
(lv356, stage2_unit4_bn1_beta),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv358 = R.call_tir(
cls.expand_dims9,
(lv357,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv359 = R.call_tir(
cls.expand_dims10,
(lv358,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv360 = R.call_tir(
cls.layout_transform13,
(lv359,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv361 = R.call_tir(
cls.add17,
(lv354, lv360),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv362 = R.call_tir(
cls.relu4,
(lv361,),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv363 = R.call_tir(
cls.add11,
(
stage2_unit4_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv364 = R.call_tir(
cls.rsqrt3, (lv363,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv365 = R.call_tir(
cls.multiply10,
(lv364, stage2_unit4_bn2_gamma),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv366 = R.call_tir(
cls.expand_dims6,
(lv365,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv367 = R.call_tir(
cls.squeeze2, (lv366,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv368 = R.call_tir(
cls.expand_dims7,
(lv367,),
out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
)
lv369 = R.call_tir(
cls.multiply15,
(stage2_unit4_conv1_weight, lv368),
out_sinfo=R.Tensor((128, 512, 1, 1), dtype="float32"),
)
lv370 = R.call_tir(
cls.layout_transform14,
(lv369,),
out_sinfo=R.Tensor((32, 128, 1, 1, 4, 4), dtype="float32"),
)
lv371 = R.call_tir(
cls.contrib_conv2d_NCHWc9,
(lv362, lv370),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv372 = R.call_tir(
cls.negative3,
(stage2_unit4_bn2_moving_mean,),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv373 = R.call_tir(
cls.multiply10,
(lv372, lv365),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv374 = R.call_tir(
cls.add12,
(lv373, stage2_unit4_bn2_beta),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv375 = R.call_tir(
cls.expand_dims6,
(lv374,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv376 = R.call_tir(
cls.expand_dims8,
(lv375,),
out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
)
lv377 = R.call_tir(
cls.layout_transform9,
(lv376,),
out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
)
lv378 = R.call_tir(
cls.add13,
(lv371, lv377),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv379 = R.call_tir(
cls.relu3,
(lv378,),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv380 = R.call_tir(
cls.add11,
(
stage2_unit4_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv381 = R.call_tir(
cls.rsqrt3, (lv380,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv382 = R.call_tir(
cls.multiply10,
(lv381, stage2_unit4_bn3_gamma),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv383 = R.call_tir(
cls.expand_dims6,
(lv382,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv384 = R.call_tir(
cls.squeeze2, (lv383,), out_sinfo=R.Tensor((128,), dtype="float32")
)
lv385 = R.call_tir(
cls.expand_dims7,
(lv384,),
out_sinfo=R.Tensor((128, 1, 1, 1), dtype="float32"),
)
lv386 = R.call_tir(
cls.multiply12,
(stage2_unit4_conv2_weight, lv385),
out_sinfo=R.Tensor((128, 128, 3, 3), dtype="float32"),
)
lv387 = R.call_tir(
cls.layout_transform10,
(lv386,),
out_sinfo=R.Tensor((32, 32, 3, 3, 4, 4), dtype="float32"),
)
lv388 = R.call_tir(
cls.contrib_conv2d_NCHWc6,
(lv379, lv387),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv389 = R.call_tir(
cls.negative3,
(stage2_unit4_bn3_moving_mean,),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv390 = R.call_tir(
cls.multiply10,
(lv389, lv382),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv391 = R.call_tir(
cls.add12,
(lv390, stage2_unit4_bn3_beta),
out_sinfo=R.Tensor((128,), dtype="float32"),
)
lv392 = R.call_tir(
cls.expand_dims6,
(lv391,),
out_sinfo=R.Tensor((128, 1, 1), dtype="float32"),
)
lv393 = R.call_tir(
cls.expand_dims8,
(lv392,),
out_sinfo=R.Tensor((1, 128, 1, 1), dtype="float32"),
)
lv394 = R.call_tir(
cls.layout_transform9,
(lv393,),
out_sinfo=R.Tensor((1, 32, 1, 1, 4), dtype="float32"),
)
lv395 = R.call_tir(
cls.add13,
(lv388, lv394),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv396 = R.call_tir(
cls.relu3,
(lv395,),
out_sinfo=R.Tensor((1, 32, 28, 28, 4), dtype="float32"),
)
lv397 = R.call_tir(
cls.layout_transform11,
(stage2_unit4_conv3_weight,),
out_sinfo=R.Tensor((128, 32, 1, 1, 4, 4), dtype="float32"),
)
lv398 = R.call_tir(
cls.contrib_conv2d_NCHWc7,
(lv396, lv397),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv399 = R.call_tir(
cls.add14,
(lv398, lv347),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv400 = R.call_tir(
cls.add15,
(
stage3_unit1_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv401 = R.call_tir(
cls.rsqrt4, (lv400,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv402 = R.call_tir(
cls.multiply13,
(lv401, stage3_unit1_bn1_gamma),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv403 = R.call_tir(
cls.expand_dims9,
(lv402,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv404 = R.call_tir(
cls.expand_dims10,
(lv403,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv405 = R.call_tir(
cls.layout_transform13,
(lv404,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv406 = R.call_tir(
cls.multiply14,
(lv399, lv405),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv407 = R.call_tir(
cls.negative4,
(stage3_unit1_bn1_moving_mean,),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv408 = R.call_tir(
cls.multiply13,
(lv407, lv402),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv409 = R.call_tir(
cls.add16,
(lv408, stage3_unit1_bn1_beta),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv410 = R.call_tir(
cls.expand_dims9,
(lv409,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv411 = R.call_tir(
cls.expand_dims10,
(lv410,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv412 = R.call_tir(
cls.layout_transform13,
(lv411,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv413 = R.call_tir(
cls.add17,
(lv406, lv412),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv414 = R.call_tir(
cls.relu4,
(lv413,),
out_sinfo=R.Tensor((1, 128, 28, 28, 4), dtype="float32"),
)
lv415 = R.call_tir(
cls.add8,
(
stage3_unit1_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv416 = R.call_tir(
cls.rsqrt2, (lv415,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv417 = R.call_tir(
cls.multiply7,
(lv416, stage3_unit1_bn2_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv418 = R.call_tir(
cls.expand_dims4,
(lv417,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv419 = R.call_tir(
cls.squeeze3, (lv418,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv420 = R.call_tir(
cls.expand_dims11,
(lv419,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv421 = R.call_tir(
cls.multiply16,
(stage3_unit1_conv1_weight, lv420),
out_sinfo=R.Tensor((256, 512, 1, 1), dtype="float32"),
)
lv422 = R.call_tir(
cls.layout_transform15,
(lv421,),
out_sinfo=R.Tensor((64, 128, 1, 1, 4, 4), dtype="float32"),
)
lv423 = R.call_tir(
cls.contrib_conv2d_NCHWc10,
(lv414, lv422),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv424 = R.call_tir(
cls.negative2,
(stage3_unit1_bn2_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv425 = R.call_tir(
cls.multiply7,
(lv424, lv417),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv426 = R.call_tir(
cls.add9,
(lv425, stage3_unit1_bn2_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv427 = R.call_tir(
cls.expand_dims4,
(lv426,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv428 = R.call_tir(
cls.expand_dims5,
(lv427,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv429 = R.call_tir(
cls.layout_transform6,
(lv428,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv430 = R.call_tir(
cls.add18,
(lv423, lv429),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv431 = R.call_tir(
cls.relu5,
(lv430,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv432 = R.call_tir(
cls.add8,
(
stage3_unit1_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv433 = R.call_tir(
cls.rsqrt2, (lv432,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv434 = R.call_tir(
cls.multiply7,
(lv433, stage3_unit1_bn3_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv435 = R.call_tir(
cls.expand_dims4,
(lv434,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv436 = R.call_tir(
cls.squeeze3, (lv435,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv437 = R.call_tir(
cls.expand_dims11,
(lv436,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv438 = R.call_tir(
cls.multiply17,
(stage3_unit1_conv2_weight, lv437),
out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
)
lv439 = R.call_tir(
cls.layout_transform16,
(lv438,),
out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
)
lv440 = R.call_tir(
cls.contrib_conv2d_NCHWc11,
(lv431, lv439),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv441 = R.call_tir(
cls.negative2,
(stage3_unit1_bn3_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv442 = R.call_tir(
cls.multiply7,
(lv441, lv434),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv443 = R.call_tir(
cls.add9,
(lv442, stage3_unit1_bn3_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv444 = R.call_tir(
cls.expand_dims4,
(lv443,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv445 = R.call_tir(
cls.expand_dims5,
(lv444,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv446 = R.call_tir(
cls.layout_transform6,
(lv445,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv447 = R.call_tir(
cls.add18,
(lv440, lv446),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv448 = R.call_tir(
cls.relu5,
(lv447,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv449 = R.call_tir(
cls.layout_transform17,
(stage3_unit1_conv3_weight,),
out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
)
lv450 = R.call_tir(
cls.contrib_conv2d_NCHWc12,
(lv448, lv449),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv451 = R.call_tir(
cls.layout_transform18,
(stage3_unit1_sc_weight,),
out_sinfo=R.Tensor((256, 128, 1, 1, 4, 4), dtype="float32"),
)
lv452 = R.call_tir(
cls.contrib_conv2d_NCHWc13,
(lv414, lv451),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv453 = R.call_tir(
cls.add19,
(lv450, lv452),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv454 = R.call_tir(
cls.add20,
(
stage3_unit2_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv455 = R.call_tir(
cls.rsqrt5, (lv454,), out_sinfo=R.Tensor((1024,), dtype="float32")
)
lv456 = R.call_tir(
cls.multiply18,
(lv455, stage3_unit2_bn1_gamma),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv457 = R.call_tir(
cls.expand_dims12,
(lv456,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv458 = R.call_tir(
cls.expand_dims13,
(lv457,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv459 = R.call_tir(
cls.layout_transform19,
(lv458,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv460 = R.call_tir(
cls.multiply19,
(lv453, lv459),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv461 = R.call_tir(
cls.negative5,
(stage3_unit2_bn1_moving_mean,),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv462 = R.call_tir(
cls.multiply18,
(lv461, lv456),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv463 = R.call_tir(
cls.add21,
(lv462, stage3_unit2_bn1_beta),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv464 = R.call_tir(
cls.expand_dims12,
(lv463,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv465 = R.call_tir(
cls.expand_dims13,
(lv464,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv466 = R.call_tir(
cls.layout_transform19,
(lv465,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv467 = R.call_tir(
cls.add22,
(lv460, lv466),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv468 = R.call_tir(
cls.relu6,
(lv467,),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv469 = R.call_tir(
cls.add8,
(
stage3_unit2_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv470 = R.call_tir(
cls.rsqrt2, (lv469,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv471 = R.call_tir(
cls.multiply7,
(lv470, stage3_unit2_bn2_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv472 = R.call_tir(
cls.expand_dims4,
(lv471,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv473 = R.call_tir(
cls.squeeze3, (lv472,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv474 = R.call_tir(
cls.expand_dims11,
(lv473,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv475 = R.call_tir(
cls.multiply20,
(stage3_unit2_conv1_weight, lv474),
out_sinfo=R.Tensor((256, 1024, 1, 1), dtype="float32"),
)
lv476 = R.call_tir(
cls.layout_transform20,
(lv475,),
out_sinfo=R.Tensor((64, 256, 1, 1, 4, 4), dtype="float32"),
)
lv477 = R.call_tir(
cls.contrib_conv2d_NCHWc14,
(lv468, lv476),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv478 = R.call_tir(
cls.negative2,
(stage3_unit2_bn2_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv479 = R.call_tir(
cls.multiply7,
(lv478, lv471),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv480 = R.call_tir(
cls.add9,
(lv479, stage3_unit2_bn2_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv481 = R.call_tir(
cls.expand_dims4,
(lv480,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv482 = R.call_tir(
cls.expand_dims5,
(lv481,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv483 = R.call_tir(
cls.layout_transform6,
(lv482,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv484 = R.call_tir(
cls.add18,
(lv477, lv483),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv485 = R.call_tir(
cls.relu5,
(lv484,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv486 = R.call_tir(
cls.add8,
(
stage3_unit2_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv487 = R.call_tir(
cls.rsqrt2, (lv486,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv488 = R.call_tir(
cls.multiply7,
(lv487, stage3_unit2_bn3_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv489 = R.call_tir(
cls.expand_dims4,
(lv488,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv490 = R.call_tir(
cls.squeeze3, (lv489,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv491 = R.call_tir(
cls.expand_dims11,
(lv490,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv492 = R.call_tir(
cls.multiply17,
(stage3_unit2_conv2_weight, lv491),
out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
)
lv493 = R.call_tir(
cls.layout_transform16,
(lv492,),
out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
)
lv494 = R.call_tir(
cls.contrib_conv2d_NCHWc11,
(lv485, lv493),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv495 = R.call_tir(
cls.negative2,
(stage3_unit2_bn3_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv496 = R.call_tir(
cls.multiply7,
(lv495, lv488),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv497 = R.call_tir(
cls.add9,
(lv496, stage3_unit2_bn3_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv498 = R.call_tir(
cls.expand_dims4,
(lv497,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv499 = R.call_tir(
cls.expand_dims5,
(lv498,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv500 = R.call_tir(
cls.layout_transform6,
(lv499,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv501 = R.call_tir(
cls.add18,
(lv494, lv500),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv502 = R.call_tir(
cls.relu5,
(lv501,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv503 = R.call_tir(
cls.layout_transform17,
(stage3_unit2_conv3_weight,),
out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
)
lv504 = R.call_tir(
cls.contrib_conv2d_NCHWc12,
(lv502, lv503),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv505 = R.call_tir(
cls.add19,
(lv504, lv453),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv506 = R.call_tir(
cls.add20,
(
stage3_unit3_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv507 = R.call_tir(
cls.rsqrt5, (lv506,), out_sinfo=R.Tensor((1024,), dtype="float32")
)
lv508 = R.call_tir(
cls.multiply18,
(lv507, stage3_unit3_bn1_gamma),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv509 = R.call_tir(
cls.expand_dims12,
(lv508,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv510 = R.call_tir(
cls.expand_dims13,
(lv509,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv511 = R.call_tir(
cls.layout_transform19,
(lv510,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv512 = R.call_tir(
cls.multiply19,
(lv505, lv511),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv513 = R.call_tir(
cls.negative5,
(stage3_unit3_bn1_moving_mean,),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv514 = R.call_tir(
cls.multiply18,
(lv513, lv508),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv515 = R.call_tir(
cls.add21,
(lv514, stage3_unit3_bn1_beta),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv516 = R.call_tir(
cls.expand_dims12,
(lv515,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv517 = R.call_tir(
cls.expand_dims13,
(lv516,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv518 = R.call_tir(
cls.layout_transform19,
(lv517,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv519 = R.call_tir(
cls.add22,
(lv512, lv518),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv520 = R.call_tir(
cls.relu6,
(lv519,),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv521 = R.call_tir(
cls.add8,
(
stage3_unit3_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv522 = R.call_tir(
cls.rsqrt2, (lv521,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv523 = R.call_tir(
cls.multiply7,
(lv522, stage3_unit3_bn2_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv524 = R.call_tir(
cls.expand_dims4,
(lv523,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv525 = R.call_tir(
cls.squeeze3, (lv524,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv526 = R.call_tir(
cls.expand_dims11,
(lv525,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv527 = R.call_tir(
cls.multiply20,
(stage3_unit3_conv1_weight, lv526),
out_sinfo=R.Tensor((256, 1024, 1, 1), dtype="float32"),
)
lv528 = R.call_tir(
cls.layout_transform20,
(lv527,),
out_sinfo=R.Tensor((64, 256, 1, 1, 4, 4), dtype="float32"),
)
lv529 = R.call_tir(
cls.contrib_conv2d_NCHWc14,
(lv520, lv528),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv530 = R.call_tir(
cls.negative2,
(stage3_unit3_bn2_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv531 = R.call_tir(
cls.multiply7,
(lv530, lv523),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv532 = R.call_tir(
cls.add9,
(lv531, stage3_unit3_bn2_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv533 = R.call_tir(
cls.expand_dims4,
(lv532,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv534 = R.call_tir(
cls.expand_dims5,
(lv533,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv535 = R.call_tir(
cls.layout_transform6,
(lv534,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv536 = R.call_tir(
cls.add18,
(lv529, lv535),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv537 = R.call_tir(
cls.relu5,
(lv536,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv538 = R.call_tir(
cls.add8,
(
stage3_unit3_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv539 = R.call_tir(
cls.rsqrt2, (lv538,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv540 = R.call_tir(
cls.multiply7,
(lv539, stage3_unit3_bn3_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv541 = R.call_tir(
cls.expand_dims4,
(lv540,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv542 = R.call_tir(
cls.squeeze3, (lv541,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv543 = R.call_tir(
cls.expand_dims11,
(lv542,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv544 = R.call_tir(
cls.multiply17,
(stage3_unit3_conv2_weight, lv543),
out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
)
lv545 = R.call_tir(
cls.layout_transform16,
(lv544,),
out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
)
lv546 = R.call_tir(
cls.contrib_conv2d_NCHWc11,
(lv537, lv545),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv547 = R.call_tir(
cls.negative2,
(stage3_unit3_bn3_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv548 = R.call_tir(
cls.multiply7,
(lv547, lv540),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv549 = R.call_tir(
cls.add9,
(lv548, stage3_unit3_bn3_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv550 = R.call_tir(
cls.expand_dims4,
(lv549,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv551 = R.call_tir(
cls.expand_dims5,
(lv550,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv552 = R.call_tir(
cls.layout_transform6,
(lv551,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv553 = R.call_tir(
cls.add18,
(lv546, lv552),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv554 = R.call_tir(
cls.relu5,
(lv553,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv555 = R.call_tir(
cls.layout_transform17,
(stage3_unit3_conv3_weight,),
out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
)
lv556 = R.call_tir(
cls.contrib_conv2d_NCHWc12,
(lv554, lv555),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv557 = R.call_tir(
cls.add19,
(lv556, lv505),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv558 = R.call_tir(
cls.add20,
(
stage3_unit4_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv559 = R.call_tir(
cls.rsqrt5, (lv558,), out_sinfo=R.Tensor((1024,), dtype="float32")
)
lv560 = R.call_tir(
cls.multiply18,
(lv559, stage3_unit4_bn1_gamma),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv561 = R.call_tir(
cls.expand_dims12,
(lv560,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv562 = R.call_tir(
cls.expand_dims13,
(lv561,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv563 = R.call_tir(
cls.layout_transform19,
(lv562,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv564 = R.call_tir(
cls.multiply19,
(lv557, lv563),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv565 = R.call_tir(
cls.negative5,
(stage3_unit4_bn1_moving_mean,),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv566 = R.call_tir(
cls.multiply18,
(lv565, lv560),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv567 = R.call_tir(
cls.add21,
(lv566, stage3_unit4_bn1_beta),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv568 = R.call_tir(
cls.expand_dims12,
(lv567,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv569 = R.call_tir(
cls.expand_dims13,
(lv568,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv570 = R.call_tir(
cls.layout_transform19,
(lv569,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv571 = R.call_tir(
cls.add22,
(lv564, lv570),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv572 = R.call_tir(
cls.relu6,
(lv571,),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv573 = R.call_tir(
cls.add8,
(
stage3_unit4_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv574 = R.call_tir(
cls.rsqrt2, (lv573,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv575 = R.call_tir(
cls.multiply7,
(lv574, stage3_unit4_bn2_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv576 = R.call_tir(
cls.expand_dims4,
(lv575,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv577 = R.call_tir(
cls.squeeze3, (lv576,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv578 = R.call_tir(
cls.expand_dims11,
(lv577,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv579 = R.call_tir(
cls.multiply20,
(stage3_unit4_conv1_weight, lv578),
out_sinfo=R.Tensor((256, 1024, 1, 1), dtype="float32"),
)
lv580 = R.call_tir(
cls.layout_transform20,
(lv579,),
out_sinfo=R.Tensor((64, 256, 1, 1, 4, 4), dtype="float32"),
)
lv581 = R.call_tir(
cls.contrib_conv2d_NCHWc14,
(lv572, lv580),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv582 = R.call_tir(
cls.negative2,
(stage3_unit4_bn2_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv583 = R.call_tir(
cls.multiply7,
(lv582, lv575),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv584 = R.call_tir(
cls.add9,
(lv583, stage3_unit4_bn2_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv585 = R.call_tir(
cls.expand_dims4,
(lv584,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv586 = R.call_tir(
cls.expand_dims5,
(lv585,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv587 = R.call_tir(
cls.layout_transform6,
(lv586,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv588 = R.call_tir(
cls.add18,
(lv581, lv587),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv589 = R.call_tir(
cls.relu5,
(lv588,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv590 = R.call_tir(
cls.add8,
(
stage3_unit4_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv591 = R.call_tir(
cls.rsqrt2, (lv590,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv592 = R.call_tir(
cls.multiply7,
(lv591, stage3_unit4_bn3_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv593 = R.call_tir(
cls.expand_dims4,
(lv592,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv594 = R.call_tir(
cls.squeeze3, (lv593,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv595 = R.call_tir(
cls.expand_dims11,
(lv594,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv596 = R.call_tir(
cls.multiply17,
(stage3_unit4_conv2_weight, lv595),
out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
)
lv597 = R.call_tir(
cls.layout_transform16,
(lv596,),
out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
)
lv598 = R.call_tir(
cls.contrib_conv2d_NCHWc11,
(lv589, lv597),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv599 = R.call_tir(
cls.negative2,
(stage3_unit4_bn3_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv600 = R.call_tir(
cls.multiply7,
(lv599, lv592),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv601 = R.call_tir(
cls.add9,
(lv600, stage3_unit4_bn3_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv602 = R.call_tir(
cls.expand_dims4,
(lv601,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv603 = R.call_tir(
cls.expand_dims5,
(lv602,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv604 = R.call_tir(
cls.layout_transform6,
(lv603,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv605 = R.call_tir(
cls.add18,
(lv598, lv604),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv606 = R.call_tir(
cls.relu5,
(lv605,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv607 = R.call_tir(
cls.layout_transform17,
(stage3_unit4_conv3_weight,),
out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
)
lv608 = R.call_tir(
cls.contrib_conv2d_NCHWc12,
(lv606, lv607),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv609 = R.call_tir(
cls.add19,
(lv608, lv557),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv610 = R.call_tir(
cls.add20,
(
stage3_unit5_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv611 = R.call_tir(
cls.rsqrt5, (lv610,), out_sinfo=R.Tensor((1024,), dtype="float32")
)
lv612 = R.call_tir(
cls.multiply18,
(lv611, stage3_unit5_bn1_gamma),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv613 = R.call_tir(
cls.expand_dims12,
(lv612,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv614 = R.call_tir(
cls.expand_dims13,
(lv613,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv615 = R.call_tir(
cls.layout_transform19,
(lv614,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv616 = R.call_tir(
cls.multiply19,
(lv609, lv615),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv617 = R.call_tir(
cls.negative5,
(stage3_unit5_bn1_moving_mean,),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv618 = R.call_tir(
cls.multiply18,
(lv617, lv612),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv619 = R.call_tir(
cls.add21,
(lv618, stage3_unit5_bn1_beta),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv620 = R.call_tir(
cls.expand_dims12,
(lv619,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv621 = R.call_tir(
cls.expand_dims13,
(lv620,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv622 = R.call_tir(
cls.layout_transform19,
(lv621,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv623 = R.call_tir(
cls.add22,
(lv616, lv622),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv624 = R.call_tir(
cls.relu6,
(lv623,),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv625 = R.call_tir(
cls.add8,
(
stage3_unit5_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv626 = R.call_tir(
cls.rsqrt2, (lv625,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv627 = R.call_tir(
cls.multiply7,
(lv626, stage3_unit5_bn2_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv628 = R.call_tir(
cls.expand_dims4,
(lv627,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv629 = R.call_tir(
cls.squeeze3, (lv628,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv630 = R.call_tir(
cls.expand_dims11,
(lv629,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv631 = R.call_tir(
cls.multiply20,
(stage3_unit5_conv1_weight, lv630),
out_sinfo=R.Tensor((256, 1024, 1, 1), dtype="float32"),
)
lv632 = R.call_tir(
cls.layout_transform20,
(lv631,),
out_sinfo=R.Tensor((64, 256, 1, 1, 4, 4), dtype="float32"),
)
lv633 = R.call_tir(
cls.contrib_conv2d_NCHWc14,
(lv624, lv632),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv634 = R.call_tir(
cls.negative2,
(stage3_unit5_bn2_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv635 = R.call_tir(
cls.multiply7,
(lv634, lv627),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv636 = R.call_tir(
cls.add9,
(lv635, stage3_unit5_bn2_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv637 = R.call_tir(
cls.expand_dims4,
(lv636,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv638 = R.call_tir(
cls.expand_dims5,
(lv637,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv639 = R.call_tir(
cls.layout_transform6,
(lv638,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv640 = R.call_tir(
cls.add18,
(lv633, lv639),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv641 = R.call_tir(
cls.relu5,
(lv640,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv642 = R.call_tir(
cls.add8,
(
stage3_unit5_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv643 = R.call_tir(
cls.rsqrt2, (lv642,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv644 = R.call_tir(
cls.multiply7,
(lv643, stage3_unit5_bn3_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv645 = R.call_tir(
cls.expand_dims4,
(lv644,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv646 = R.call_tir(
cls.squeeze3, (lv645,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv647 = R.call_tir(
cls.expand_dims11,
(lv646,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv648 = R.call_tir(
cls.multiply17,
(stage3_unit5_conv2_weight, lv647),
out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
)
lv649 = R.call_tir(
cls.layout_transform16,
(lv648,),
out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
)
lv650 = R.call_tir(
cls.contrib_conv2d_NCHWc11,
(lv641, lv649),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv651 = R.call_tir(
cls.negative2,
(stage3_unit5_bn3_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv652 = R.call_tir(
cls.multiply7,
(lv651, lv644),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv653 = R.call_tir(
cls.add9,
(lv652, stage3_unit5_bn3_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv654 = R.call_tir(
cls.expand_dims4,
(lv653,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv655 = R.call_tir(
cls.expand_dims5,
(lv654,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv656 = R.call_tir(
cls.layout_transform6,
(lv655,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv657 = R.call_tir(
cls.add18,
(lv650, lv656),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv658 = R.call_tir(
cls.relu5,
(lv657,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv659 = R.call_tir(
cls.layout_transform17,
(stage3_unit5_conv3_weight,),
out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
)
lv660 = R.call_tir(
cls.contrib_conv2d_NCHWc12,
(lv658, lv659),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv661 = R.call_tir(
cls.add19,
(lv660, lv609),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv662 = R.call_tir(
cls.add20,
(
stage3_unit6_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv663 = R.call_tir(
cls.rsqrt5, (lv662,), out_sinfo=R.Tensor((1024,), dtype="float32")
)
lv664 = R.call_tir(
cls.multiply18,
(lv663, stage3_unit6_bn1_gamma),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv665 = R.call_tir(
cls.expand_dims12,
(lv664,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv666 = R.call_tir(
cls.expand_dims13,
(lv665,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv667 = R.call_tir(
cls.layout_transform19,
(lv666,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv668 = R.call_tir(
cls.multiply19,
(lv661, lv667),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv669 = R.call_tir(
cls.negative5,
(stage3_unit6_bn1_moving_mean,),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv670 = R.call_tir(
cls.multiply18,
(lv669, lv664),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv671 = R.call_tir(
cls.add21,
(lv670, stage3_unit6_bn1_beta),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv672 = R.call_tir(
cls.expand_dims12,
(lv671,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv673 = R.call_tir(
cls.expand_dims13,
(lv672,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv674 = R.call_tir(
cls.layout_transform19,
(lv673,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv675 = R.call_tir(
cls.add22,
(lv668, lv674),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv676 = R.call_tir(
cls.relu6,
(lv675,),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv677 = R.call_tir(
cls.add8,
(
stage3_unit6_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv678 = R.call_tir(
cls.rsqrt2, (lv677,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv679 = R.call_tir(
cls.multiply7,
(lv678, stage3_unit6_bn2_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv680 = R.call_tir(
cls.expand_dims4,
(lv679,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv681 = R.call_tir(
cls.squeeze3, (lv680,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv682 = R.call_tir(
cls.expand_dims11,
(lv681,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv683 = R.call_tir(
cls.multiply20,
(stage3_unit6_conv1_weight, lv682),
out_sinfo=R.Tensor((256, 1024, 1, 1), dtype="float32"),
)
lv684 = R.call_tir(
cls.layout_transform20,
(lv683,),
out_sinfo=R.Tensor((64, 256, 1, 1, 4, 4), dtype="float32"),
)
lv685 = R.call_tir(
cls.contrib_conv2d_NCHWc14,
(lv676, lv684),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv686 = R.call_tir(
cls.negative2,
(stage3_unit6_bn2_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv687 = R.call_tir(
cls.multiply7,
(lv686, lv679),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv688 = R.call_tir(
cls.add9,
(lv687, stage3_unit6_bn2_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv689 = R.call_tir(
cls.expand_dims4,
(lv688,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv690 = R.call_tir(
cls.expand_dims5,
(lv689,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv691 = R.call_tir(
cls.layout_transform6,
(lv690,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv692 = R.call_tir(
cls.add18,
(lv685, lv691),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv693 = R.call_tir(
cls.relu5,
(lv692,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv694 = R.call_tir(
cls.add8,
(
stage3_unit6_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv695 = R.call_tir(
cls.rsqrt2, (lv694,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv696 = R.call_tir(
cls.multiply7,
(lv695, stage3_unit6_bn3_gamma),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv697 = R.call_tir(
cls.expand_dims4,
(lv696,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv698 = R.call_tir(
cls.squeeze3, (lv697,), out_sinfo=R.Tensor((256,), dtype="float32")
)
lv699 = R.call_tir(
cls.expand_dims11,
(lv698,),
out_sinfo=R.Tensor((256, 1, 1, 1), dtype="float32"),
)
lv700 = R.call_tir(
cls.multiply17,
(stage3_unit6_conv2_weight, lv699),
out_sinfo=R.Tensor((256, 256, 3, 3), dtype="float32"),
)
lv701 = R.call_tir(
cls.layout_transform16,
(lv700,),
out_sinfo=R.Tensor((64, 64, 3, 3, 4, 4), dtype="float32"),
)
lv702 = R.call_tir(
cls.contrib_conv2d_NCHWc11,
(lv693, lv701),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv703 = R.call_tir(
cls.negative2,
(stage3_unit6_bn3_moving_mean,),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv704 = R.call_tir(
cls.multiply7,
(lv703, lv696),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv705 = R.call_tir(
cls.add9,
(lv704, stage3_unit6_bn3_beta),
out_sinfo=R.Tensor((256,), dtype="float32"),
)
lv706 = R.call_tir(
cls.expand_dims4,
(lv705,),
out_sinfo=R.Tensor((256, 1, 1), dtype="float32"),
)
lv707 = R.call_tir(
cls.expand_dims5,
(lv706,),
out_sinfo=R.Tensor((1, 256, 1, 1), dtype="float32"),
)
lv708 = R.call_tir(
cls.layout_transform6,
(lv707,),
out_sinfo=R.Tensor((1, 64, 1, 1, 4), dtype="float32"),
)
lv709 = R.call_tir(
cls.add18,
(lv702, lv708),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv710 = R.call_tir(
cls.relu5,
(lv709,),
out_sinfo=R.Tensor((1, 64, 14, 14, 4), dtype="float32"),
)
lv711 = R.call_tir(
cls.layout_transform17,
(stage3_unit6_conv3_weight,),
out_sinfo=R.Tensor((256, 64, 1, 1, 4, 4), dtype="float32"),
)
lv712 = R.call_tir(
cls.contrib_conv2d_NCHWc12,
(lv710, lv711),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv713 = R.call_tir(
cls.add19,
(lv712, lv661),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv714 = R.call_tir(
cls.add20,
(
stage4_unit1_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv715 = R.call_tir(
cls.rsqrt5, (lv714,), out_sinfo=R.Tensor((1024,), dtype="float32")
)
lv716 = R.call_tir(
cls.multiply18,
(lv715, stage4_unit1_bn1_gamma),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv717 = R.call_tir(
cls.expand_dims12,
(lv716,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv718 = R.call_tir(
cls.expand_dims13,
(lv717,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv719 = R.call_tir(
cls.layout_transform19,
(lv718,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv720 = R.call_tir(
cls.multiply19,
(lv713, lv719),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv721 = R.call_tir(
cls.negative5,
(stage4_unit1_bn1_moving_mean,),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv722 = R.call_tir(
cls.multiply18,
(lv721, lv716),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv723 = R.call_tir(
cls.add21,
(lv722, stage4_unit1_bn1_beta),
out_sinfo=R.Tensor((1024,), dtype="float32"),
)
lv724 = R.call_tir(
cls.expand_dims12,
(lv723,),
out_sinfo=R.Tensor((1024, 1, 1), dtype="float32"),
)
lv725 = R.call_tir(
cls.expand_dims13,
(lv724,),
out_sinfo=R.Tensor((1, 1024, 1, 1), dtype="float32"),
)
lv726 = R.call_tir(
cls.layout_transform19,
(lv725,),
out_sinfo=R.Tensor((1, 256, 1, 1, 4), dtype="float32"),
)
lv727 = R.call_tir(
cls.add22,
(lv720, lv726),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv728 = R.call_tir(
cls.relu6,
(lv727,),
out_sinfo=R.Tensor((1, 256, 14, 14, 4), dtype="float32"),
)
lv729 = R.call_tir(
cls.add15,
(
stage4_unit1_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv730 = R.call_tir(
cls.rsqrt4, (lv729,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv731 = R.call_tir(
cls.multiply13,
(lv730, stage4_unit1_bn2_gamma),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv732 = R.call_tir(
cls.expand_dims9,
(lv731,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv733 = R.call_tir(
cls.squeeze4, (lv732,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv734 = R.call_tir(
cls.expand_dims14,
(lv733,),
out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
)
lv735 = R.call_tir(
cls.multiply21,
(stage4_unit1_conv1_weight, lv734),
out_sinfo=R.Tensor((512, 1024, 1, 1), dtype="float32"),
)
lv736 = R.call_tir(
cls.layout_transform21,
(lv735,),
out_sinfo=R.Tensor((128, 256, 1, 1, 4, 4), dtype="float32"),
)
lv737 = R.call_tir(
cls.contrib_conv2d_NCHWc15,
(lv728, lv736),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv738 = R.call_tir(
cls.negative4,
(stage4_unit1_bn2_moving_mean,),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv739 = R.call_tir(
cls.multiply13,
(lv738, lv731),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv740 = R.call_tir(
cls.add16,
(lv739, stage4_unit1_bn2_beta),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv741 = R.call_tir(
cls.expand_dims9,
(lv740,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv742 = R.call_tir(
cls.expand_dims10,
(lv741,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv743 = R.call_tir(
cls.layout_transform13,
(lv742,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv744 = R.call_tir(
cls.add23,
(lv737, lv743),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv745 = R.call_tir(
cls.relu7,
(lv744,),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv746 = R.call_tir(
cls.add15,
(
stage4_unit1_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv747 = R.call_tir(
cls.rsqrt4, (lv746,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv748 = R.call_tir(
cls.multiply13,
(lv747, stage4_unit1_bn3_gamma),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv749 = R.call_tir(
cls.expand_dims9,
(lv748,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv750 = R.call_tir(
cls.squeeze4, (lv749,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv751 = R.call_tir(
cls.expand_dims14,
(lv750,),
out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
)
lv752 = R.call_tir(
cls.multiply22,
(stage4_unit1_conv2_weight, lv751),
out_sinfo=R.Tensor((512, 512, 3, 3), dtype="float32"),
)
lv753 = R.call_tir(
cls.layout_transform22,
(lv752,),
out_sinfo=R.Tensor((128, 128, 3, 3, 4, 4), dtype="float32"),
)
lv754 = R.call_tir(
cls.contrib_conv2d_NCHWc16,
(lv745, lv753),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv755 = R.call_tir(
cls.negative4,
(stage4_unit1_bn3_moving_mean,),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv756 = R.call_tir(
cls.multiply13,
(lv755, lv748),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv757 = R.call_tir(
cls.add16,
(lv756, stage4_unit1_bn3_beta),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv758 = R.call_tir(
cls.expand_dims9,
(lv757,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv759 = R.call_tir(
cls.expand_dims10,
(lv758,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv760 = R.call_tir(
cls.layout_transform13,
(lv759,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv761 = R.call_tir(
cls.add23,
(lv754, lv760),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv762 = R.call_tir(
cls.relu7,
(lv761,),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv763 = R.call_tir(
cls.layout_transform23,
(stage4_unit1_conv3_weight,),
out_sinfo=R.Tensor((512, 128, 1, 1, 4, 4), dtype="float32"),
)
lv764 = R.call_tir(
cls.contrib_conv2d_NCHWc17,
(lv762, lv763),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv765 = R.call_tir(
cls.layout_transform24,
(stage4_unit1_sc_weight,),
out_sinfo=R.Tensor((512, 256, 1, 1, 4, 4), dtype="float32"),
)
lv766 = R.call_tir(
cls.contrib_conv2d_NCHWc18,
(lv728, lv765),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv767 = R.call_tir(
cls.add24,
(lv764, lv766),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv768 = R.call_tir(
cls.add25,
(
stage4_unit2_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv769 = R.call_tir(
cls.rsqrt6, (lv768,), out_sinfo=R.Tensor((2048,), dtype="float32")
)
lv770 = R.call_tir(
cls.multiply23,
(lv769, stage4_unit2_bn1_gamma),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv771 = R.call_tir(
cls.expand_dims15,
(lv770,),
out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
)
lv772 = R.call_tir(
cls.expand_dims16,
(lv771,),
out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
)
lv773 = R.call_tir(
cls.layout_transform25,
(lv772,),
out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
)
lv774 = R.call_tir(
cls.multiply24,
(lv767, lv773),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv775 = R.call_tir(
cls.negative6,
(stage4_unit2_bn1_moving_mean,),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv776 = R.call_tir(
cls.multiply23,
(lv775, lv770),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv777 = R.call_tir(
cls.add26,
(lv776, stage4_unit2_bn1_beta),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv778 = R.call_tir(
cls.expand_dims15,
(lv777,),
out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
)
lv779 = R.call_tir(
cls.expand_dims16,
(lv778,),
out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
)
lv780 = R.call_tir(
cls.layout_transform25,
(lv779,),
out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
)
lv781 = R.call_tir(
cls.add27,
(lv774, lv780),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv782 = R.call_tir(
cls.relu8,
(lv781,),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv783 = R.call_tir(
cls.add15,
(
stage4_unit2_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv784 = R.call_tir(
cls.rsqrt4, (lv783,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv785 = R.call_tir(
cls.multiply13,
(lv784, stage4_unit2_bn2_gamma),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv786 = R.call_tir(
cls.expand_dims9,
(lv785,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv787 = R.call_tir(
cls.squeeze4, (lv786,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv788 = R.call_tir(
cls.expand_dims14,
(lv787,),
out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
)
lv789 = R.call_tir(
cls.multiply25,
(stage4_unit2_conv1_weight, lv788),
out_sinfo=R.Tensor((512, 2048, 1, 1), dtype="float32"),
)
lv790 = R.call_tir(
cls.layout_transform26,
(lv789,),
out_sinfo=R.Tensor((128, 512, 1, 1, 4, 4), dtype="float32"),
)
lv791 = R.call_tir(
cls.contrib_conv2d_NCHWc19,
(lv782, lv790),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv792 = R.call_tir(
cls.negative4,
(stage4_unit2_bn2_moving_mean,),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv793 = R.call_tir(
cls.multiply13,
(lv792, lv785),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv794 = R.call_tir(
cls.add16,
(lv793, stage4_unit2_bn2_beta),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv795 = R.call_tir(
cls.expand_dims9,
(lv794,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv796 = R.call_tir(
cls.expand_dims10,
(lv795,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv797 = R.call_tir(
cls.layout_transform13,
(lv796,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv798 = R.call_tir(
cls.add23,
(lv791, lv797),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv799 = R.call_tir(
cls.relu7,
(lv798,),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv800 = R.call_tir(
cls.add15,
(
stage4_unit2_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv801 = R.call_tir(
cls.rsqrt4, (lv800,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv802 = R.call_tir(
cls.multiply13,
(lv801, stage4_unit2_bn3_gamma),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv803 = R.call_tir(
cls.expand_dims9,
(lv802,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv804 = R.call_tir(
cls.squeeze4, (lv803,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv805 = R.call_tir(
cls.expand_dims14,
(lv804,),
out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
)
lv806 = R.call_tir(
cls.multiply22,
(stage4_unit2_conv2_weight, lv805),
out_sinfo=R.Tensor((512, 512, 3, 3), dtype="float32"),
)
lv807 = R.call_tir(
cls.layout_transform22,
(lv806,),
out_sinfo=R.Tensor((128, 128, 3, 3, 4, 4), dtype="float32"),
)
lv808 = R.call_tir(
cls.contrib_conv2d_NCHWc16,
(lv799, lv807),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv809 = R.call_tir(
cls.negative4,
(stage4_unit2_bn3_moving_mean,),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv810 = R.call_tir(
cls.multiply13,
(lv809, lv802),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv811 = R.call_tir(
cls.add16,
(lv810, stage4_unit2_bn3_beta),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv812 = R.call_tir(
cls.expand_dims9,
(lv811,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv813 = R.call_tir(
cls.expand_dims10,
(lv812,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv814 = R.call_tir(
cls.layout_transform13,
(lv813,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv815 = R.call_tir(
cls.add23,
(lv808, lv814),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv816 = R.call_tir(
cls.relu7,
(lv815,),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv817 = R.call_tir(
cls.layout_transform23,
(stage4_unit2_conv3_weight,),
out_sinfo=R.Tensor((512, 128, 1, 1, 4, 4), dtype="float32"),
)
lv818 = R.call_tir(
cls.contrib_conv2d_NCHWc17,
(lv816, lv817),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv819 = R.call_tir(
cls.add24,
(lv818, lv767),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv820 = R.call_tir(
cls.add25,
(
stage4_unit3_bn1_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv821 = R.call_tir(
cls.rsqrt6, (lv820,), out_sinfo=R.Tensor((2048,), dtype="float32")
)
lv822 = R.call_tir(
cls.multiply23,
(lv821, stage4_unit3_bn1_gamma),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv823 = R.call_tir(
cls.expand_dims15,
(lv822,),
out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
)
lv824 = R.call_tir(
cls.expand_dims16,
(lv823,),
out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
)
lv825 = R.call_tir(
cls.layout_transform25,
(lv824,),
out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
)
lv826 = R.call_tir(
cls.multiply24,
(lv819, lv825),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv827 = R.call_tir(
cls.negative6,
(stage4_unit3_bn1_moving_mean,),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv828 = R.call_tir(
cls.multiply23,
(lv827, lv822),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv829 = R.call_tir(
cls.add26,
(lv828, stage4_unit3_bn1_beta),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv830 = R.call_tir(
cls.expand_dims15,
(lv829,),
out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
)
lv831 = R.call_tir(
cls.expand_dims16,
(lv830,),
out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
)
lv832 = R.call_tir(
cls.layout_transform25,
(lv831,),
out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
)
lv833 = R.call_tir(
cls.add27,
(lv826, lv832),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv834 = R.call_tir(
cls.relu8,
(lv833,),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv835 = R.call_tir(
cls.add15,
(
stage4_unit3_bn2_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv836 = R.call_tir(
cls.rsqrt4, (lv835,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv837 = R.call_tir(
cls.multiply13,
(lv836, stage4_unit3_bn2_gamma),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv838 = R.call_tir(
cls.expand_dims9,
(lv837,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv839 = R.call_tir(
cls.squeeze4, (lv838,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv840 = R.call_tir(
cls.expand_dims14,
(lv839,),
out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
)
lv841 = R.call_tir(
cls.multiply25,
(stage4_unit3_conv1_weight, lv840),
out_sinfo=R.Tensor((512, 2048, 1, 1), dtype="float32"),
)
lv842 = R.call_tir(
cls.layout_transform26,
(lv841,),
out_sinfo=R.Tensor((128, 512, 1, 1, 4, 4), dtype="float32"),
)
lv843 = R.call_tir(
cls.contrib_conv2d_NCHWc19,
(lv834, lv842),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv844 = R.call_tir(
cls.negative4,
(stage4_unit3_bn2_moving_mean,),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv845 = R.call_tir(
cls.multiply13,
(lv844, lv837),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv846 = R.call_tir(
cls.add16,
(lv845, stage4_unit3_bn2_beta),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv847 = R.call_tir(
cls.expand_dims9,
(lv846,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv848 = R.call_tir(
cls.expand_dims10,
(lv847,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv849 = R.call_tir(
cls.layout_transform13,
(lv848,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv850 = R.call_tir(
cls.add23,
(lv843, lv849),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv851 = R.call_tir(
cls.relu7,
(lv850,),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv852 = R.call_tir(
cls.add15,
(
stage4_unit3_bn3_moving_var,
R.const(1.9999999494757503e-05, "float32"),
),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv853 = R.call_tir(
cls.rsqrt4, (lv852,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv854 = R.call_tir(
cls.multiply13,
(lv853, stage4_unit3_bn3_gamma),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv855 = R.call_tir(
cls.expand_dims9,
(lv854,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv856 = R.call_tir(
cls.squeeze4, (lv855,), out_sinfo=R.Tensor((512,), dtype="float32")
)
lv857 = R.call_tir(
cls.expand_dims14,
(lv856,),
out_sinfo=R.Tensor((512, 1, 1, 1), dtype="float32"),
)
lv858 = R.call_tir(
cls.multiply22,
(stage4_unit3_conv2_weight, lv857),
out_sinfo=R.Tensor((512, 512, 3, 3), dtype="float32"),
)
lv859 = R.call_tir(
cls.layout_transform22,
(lv858,),
out_sinfo=R.Tensor((128, 128, 3, 3, 4, 4), dtype="float32"),
)
lv860 = R.call_tir(
cls.contrib_conv2d_NCHWc16,
(lv851, lv859),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv861 = R.call_tir(
cls.negative4,
(stage4_unit3_bn3_moving_mean,),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv862 = R.call_tir(
cls.multiply13,
(lv861, lv854),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv863 = R.call_tir(
cls.add16,
(lv862, stage4_unit3_bn3_beta),
out_sinfo=R.Tensor((512,), dtype="float32"),
)
lv864 = R.call_tir(
cls.expand_dims9,
(lv863,),
out_sinfo=R.Tensor((512, 1, 1), dtype="float32"),
)
lv865 = R.call_tir(
cls.expand_dims10,
(lv864,),
out_sinfo=R.Tensor((1, 512, 1, 1), dtype="float32"),
)
lv866 = R.call_tir(
cls.layout_transform13,
(lv865,),
out_sinfo=R.Tensor((1, 128, 1, 1, 4), dtype="float32"),
)
lv867 = R.call_tir(
cls.add23,
(lv860, lv866),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv868 = R.call_tir(
cls.relu7,
(lv867,),
out_sinfo=R.Tensor((1, 128, 7, 7, 4), dtype="float32"),
)
lv869 = R.call_tir(
cls.layout_transform23,
(stage4_unit3_conv3_weight,),
out_sinfo=R.Tensor((512, 128, 1, 1, 4, 4), dtype="float32"),
)
lv870 = R.call_tir(
cls.contrib_conv2d_NCHWc17,
(lv868, lv869),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv871 = R.call_tir(
cls.add24,
(lv870, lv819),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv872 = R.call_tir(
cls.add25,
(bn1_moving_var, R.const(1.9999999494757503e-05, "float32")),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv873 = R.call_tir(
cls.rsqrt6, (lv872,), out_sinfo=R.Tensor((2048,), dtype="float32")
)
lv874 = R.call_tir(
cls.multiply23,
(lv873, bn1_gamma),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv875 = R.call_tir(
cls.expand_dims15,
(lv874,),
out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
)
lv876 = R.call_tir(
cls.expand_dims16,
(lv875,),
out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
)
lv877 = R.call_tir(
cls.layout_transform25,
(lv876,),
out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
)
lv878 = R.call_tir(
cls.multiply24,
(lv871, lv877),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv879 = R.call_tir(
cls.negative6,
(bn1_moving_mean,),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv880 = R.call_tir(
cls.multiply23,
(lv879, lv874),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv881 = R.call_tir(
cls.add26,
(lv880, bn1_beta),
out_sinfo=R.Tensor((2048,), dtype="float32"),
)
lv882 = R.call_tir(
cls.expand_dims15,
(lv881,),
out_sinfo=R.Tensor((2048, 1, 1), dtype="float32"),
)
lv883 = R.call_tir(
cls.expand_dims16,
(lv882,),
out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
)
lv884 = R.call_tir(
cls.layout_transform25,
(lv883,),
out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
)
lv885 = R.call_tir(
cls.add27,
(lv878, lv884),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv886 = R.call_tir(
cls.relu8,
(lv885,),
out_sinfo=R.Tensor((1, 512, 7, 7, 4), dtype="float32"),
)
lv887 = R.call_tir(
cls.global_avg_pool2d,
(lv886,),
out_sinfo=R.Tensor((1, 512, 1, 1, 4), dtype="float32"),
)
lv888 = R.call_tir(
cls.layout_transform27,
(lv887,),
out_sinfo=R.Tensor((1, 2048, 1, 1), dtype="float32"),
)
lv889 = R.call_tir(
cls.batch_flatten,
(lv888,),
out_sinfo=R.Tensor((1, 2048), dtype="float32"),
)
lv890 = R.call_tir(
cls.dense,
(lv889, fc1_weight),
out_sinfo=R.Tensor((1, 1000), dtype="float32"),
)
lv891 = R.call_tir(
cls.expand_dims17,
(fc1_bias,),
out_sinfo=R.Tensor((1, 1000), dtype="float32"),
)
lv892 = R.call_tir(
cls.add28,
(lv890, lv891),
out_sinfo=R.Tensor((1, 1000), dtype="float32"),
)
lv893 = R.call_tir(
cls.softmax, (lv892,), out_sinfo=R.Tensor((1, 1000), dtype="float32")
)
gv: R.Tensor((1, 1000), dtype="float32") = lv893
R.output(gv)
return gv