DPL Pattern Context

A PatternContext lets several patterns share constraints, such as a single wildcard that must feed multiple matmuls, and matches them against an entire R.dataflow block at once.
import tvm
from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
from tvm.relax.dpl import *
QKV_proj

The module below projects the same input x through three separate weights, the usual Q/K/V projection shape of an attention block.
@tvm.script.ir_module
class QKV_proj:
@R.function
def main(
x: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
) -> R.Tensor:
with R.dataflow():
lv0 = R.matmul(x, w0)
lv1 = R.matmul(x, w1)
lv2 = R.matmul(x, w2)
out = (lv0, lv1, lv2)
R.output(out)
return out
with PatternContext() as ctx:
inp_pat = wildcard()
Q_weight_pat = wildcard()
K_weight_pat = wildcard()
V_weight_pat = wildcard()
matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
dfb = QKV_proj["main"].body.blocks[0]
out = ctx.match_dfb(dfb)
assert out[Q_weight_pat].name_hint == "w0"
assert out[K_weight_pat].name_hint == "w1"
assert out[V_weight_pat].name_hint == "w2"
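The returned map covers every pattern registered in the context, not only the weights: each matmul pattern resolves to the dataflow var of the binding it matched. A small sanity check (an assumption consistent with the weight lookups above, since matmul1 pairs inp_pat with Q_weight_pat and w0 only feeds lv0):

assert out[matmul1].name_hint == "lv0"
assert out[matmul2].name_hint == "lv1"
assert out[matmul3].name_hint == "lv2"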
A negative case: the same context must not match when Q reads a different input than K and V, since inp_pat can only bind one value.

def test_attention_fake_qkv():
@tvm.script.ir_module
class QKV_proj:
@R.function
def main(
x1: R.Tensor((2, 1024, 640), "float32"),
x2: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
) -> R.Tensor:
with R.dataflow():
lv0 = R.matmul(x1, w0)
lv1 = R.matmul(x2, w1)
lv2 = R.matmul(x2, w2)
out = (lv0, lv1, lv2)
R.output(out)
return out
with PatternContext() as ctx:
inp_pat = wildcard()
Q_weight_pat = wildcard()
K_weight_pat = wildcard()
V_weight_pat = wildcard()
matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
dfb = QKV_proj["main"].body.blocks[0]
assert ctx.match_dfb(dfb) is None
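The match fails because inp_pat is a single pattern, so it must bind to one and the same input across all three matmuls, while the projections here read from both x1 and x2. As a hedged variant sketch (assuming match_dfb permits partial matches of the block, as the examples here suggest), giving the K/V branch its own wildcard should make the block match again:

with PatternContext() as ctx:
    q_inp_pat = wildcard()   # expected to bind x1
    kv_inp_pat = wildcard()  # expected to bind x2
    Q_weight_pat = wildcard()
    K_weight_pat = wildcard()
    V_weight_pat = wildcard()

    is_op("relax.matmul")(q_inp_pat, Q_weight_pat)
    is_op("relax.matmul")(kv_inp_pat, K_weight_pat)
    is_op("relax.matmul")(kv_inp_pat, V_weight_pat)

    assert ctx.match_dfb(dfb) is not None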
Rewriting QKV_proj

The rewriter below fuses the three projections into one matmul against the concatenated weight and slices the result back apart. matchings maps each pattern to the IR node it matched; the returned dict tells rewrite_bindings which bound var to replace with which expression.
def get_qkv_proj_rewriter():
with PatternContext() as ctx:
inp_pat = wildcard()
Q_weight_pat = wildcard()
K_weight_pat = wildcard()
V_weight_pat = wildcard()
matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
def qkv_proj_rewriter(matchings, _):
inp = matchings[inp_pat]
Q_weight = matchings[Q_weight_pat]
K_weight = matchings[K_weight_pat]
V_weight = matchings[V_weight_pat]
width = Q_weight.struct_info.shape[1]
concat = R.concat([Q_weight, K_weight, V_weight], axis=1)
matmul = R.matmul(inp, concat)
Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width])
K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2])
V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3])
return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V}
return ctx, qkv_proj_rewriter
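A minimal usage sketch (an assumed check, mirroring the rewrite_bindings calls below): hand the context/rewriter pair to rewrite_bindings and verify that the three projections collapse into a single fused matmul.

ctx, rewriter = get_qkv_proj_rewriter()
rewritten = rewrite_bindings(ctx, rewriter, QKV_proj["main"])
assert rewritten.script().count("R.matmul") == 1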
Combining the matmuls twice

qkv_x2 contains two independent QKV groups, one per input, so rewrite_bindings should apply the rewrite once for each group.
@R.function(private=True)
def qkv_x2(
x1: R.Tensor((2, 1024, 640), "float32"),
x2: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
w3: R.Tensor((640, 640), "float32"),
w4: R.Tensor((640, 640), "float32"),
w5: R.Tensor((640, 640), "float32"),
):
with R.dataflow():
lv0 = R.matmul(x1, w0)
lv1 = R.matmul(x1, w1)
lv2 = R.matmul(x1, w2)
lv3 = R.matmul(x2, w3)
lv4 = R.matmul(x2, w4)
lv5 = R.matmul(x2, w5)
out = (lv0, lv1, lv2, lv3, lv4, lv5)
R.output(out)
return out
qkv_x2.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), x2: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32"), w3: R.Tensor((640, 640), dtype="float32"), w4: R.Tensor((640, 640), dtype="float32"), w5: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
with R.dataflow():
lv0: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w0, out_dtype="void")
lv1: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w1, out_dtype="void")
lv2: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w2, out_dtype="void")
lv3: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x2, w3, out_dtype="void")
lv4: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x2, w4, out_dtype="void")
lv5: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x2, w5, out_dtype="void")
out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1, lv2, lv3, lv4, lv5
R.output(out)
return out
ctx, rewriter = get_qkv_proj_rewriter()
rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
rewritten.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), x2: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32"), w3: R.Tensor((640, 640), dtype="float32"), w4: R.Tensor((640, 640), dtype="float32"), w5: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
with R.dataflow():
lv: R.Tensor((640, 1920), dtype="float32") = R.concat((w0, w1, w2), axis=1)
lv1: R.Tensor((2, 1024, 1920), dtype="float32") = R.matmul(x1, lv, out_dtype="void")
lv0: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(0),), (R.prim_value(640),), assume_inbound=False)
lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(640),), (R.prim_value(1280),), assume_inbound=False)
lv2: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(1280),), (R.prim_value(1920),), assume_inbound=False)
lv2_1: R.Tensor((640, 1920), dtype="float32") = R.concat((w3, w4, w5), axis=1)
lv3: R.Tensor((2, 1024, 1920), dtype="float32") = R.matmul(x2, lv2_1, out_dtype="void")
lv3_1: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv3, (R.prim_value(2),), (R.prim_value(0),), (R.prim_value(640),), assume_inbound=False)
lv4: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv3, (R.prim_value(2),), (R.prim_value(640),), (R.prim_value(1280),), assume_inbound=False)
lv5: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv3, (R.prim_value(2),), (R.prim_value(1280),), (R.prim_value(1920),), assume_inbound=False)
out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1_1, lv2, lv3_1, lv4, lv5
R.output(out)
return out
Test that a dataflow block may start with match_cast

The input to rewrite_bindings may contain R.match_cast. This is a regression test: in an earlier implementation, applying rewrite_bindings when R.match_cast was the first binding of the R.dataflow block caused a segfault.
@R.function(private=True)
def before(
x_untyped: R.Tensor,
w0_untyped: R.Tensor,
w1_untyped: R.Tensor,
w2_untyped: R.Tensor,
):
with R.dataflow():
x = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), "float32"))
w0 = R.match_cast(w0_untyped, R.Tensor((640, 640), "float32"))
w1 = R.match_cast(w1_untyped, R.Tensor((640, 640), "float32"))
w2 = R.match_cast(w2_untyped, R.Tensor((640, 640), "float32"))
out_0 = R.matmul(x, w0)
out_1 = R.matmul(x, w1)
out_2 = R.matmul(x, w2)
out = (out_0, out_1, out_2)
R.output(out)
return out
before.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x_untyped: R.Tensor, w0_untyped: R.Tensor, w1_untyped: R.Tensor, w2_untyped: R.Tensor) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
with R.dataflow():
x: R.Tensor((2, 1024, 640), dtype="float32") = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), dtype="float32"))
w0: R.Tensor((640, 640), dtype="float32") = R.match_cast(w0_untyped, R.Tensor((640, 640), dtype="float32"))
w1: R.Tensor((640, 640), dtype="float32") = R.match_cast(w1_untyped, R.Tensor((640, 640), dtype="float32"))
w2: R.Tensor((640, 640), dtype="float32") = R.match_cast(w2_untyped, R.Tensor((640, 640), dtype="float32"))
out_0: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x, w0, out_dtype="void")
out_1: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x, w1, out_dtype="void")
out_2: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x, w2, out_dtype="void")
out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = out_0, out_1, out_2
R.output(out)
return out
ctx, rewriter = get_qkv_proj_rewriter()
rewritten = rewrite_bindings(ctx, rewriter, before)
rewritten.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x_untyped: R.Tensor, w0_untyped: R.Tensor, w1_untyped: R.Tensor, w2_untyped: R.Tensor) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
with R.dataflow():
x: R.Tensor((2, 1024, 640), dtype="float32") = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), dtype="float32"))
w0: R.Tensor((640, 640), dtype="float32") = R.match_cast(w0_untyped, R.Tensor((640, 640), dtype="float32"))
w1: R.Tensor((640, 640), dtype="float32") = R.match_cast(w1_untyped, R.Tensor((640, 640), dtype="float32"))
w2: R.Tensor((640, 640), dtype="float32") = R.match_cast(w2_untyped, R.Tensor((640, 640), dtype="float32"))
lv: R.Tensor((640, 1920), dtype="float32") = R.concat((w0, w1, w2), axis=1)
lv1: R.Tensor((2, 1024, 1920), dtype="float32") = R.matmul(x, lv, out_dtype="void")
out_0: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(0),), (R.prim_value(640),), assume_inbound=False)
out_1: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(640),), (R.prim_value(1280),), assume_inbound=False)
out_2: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(1280),), (R.prim_value(1920),), assume_inbound=False)
out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = out_0, out_1, out_2
R.output(out)
return out
Testing the emission order of the combined matmul

Here the matched matmuls are interleaved with permute_dims bindings, so the rewrite must emit the fused concat and matmul only after every transposed weight they consume.
@R.function(private=True)
def main(
x1: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
):
with R.dataflow():
w0_t = R.permute_dims(w0, axes=None)
lv0 = R.matmul(x1, w0_t)
w1_t = R.permute_dims(w1, axes=None)
w1_t_t = R.permute_dims(w1_t, axes=None)
lv1 = R.matmul(x1, w1_t_t)
w2_t = R.permute_dims(w2, axes=None)
lv2 = R.matmul(x1, w2_t)
out = (lv0, lv1, lv2)
R.output(out)
return out
main.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
with R.dataflow():
w0_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w0, axes=None)
lv0: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w0_t, out_dtype="void")
w1_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w1, axes=None)
w1_t_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w1_t, axes=None)
lv1: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w1_t_t, out_dtype="void")
w2_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w2, axes=None)
lv2: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w2_t, out_dtype="void")
out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1, lv2
R.output(out)
return out
ctx, rewriter = get_qkv_proj_rewriter()
rewritten = rewrite_bindings(ctx, rewriter, main)
rewritten.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
with R.dataflow():
w0_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w0, axes=None)
w1_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w1, axes=None)
w1_t_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w1_t, axes=None)
w2_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w2, axes=None)
lv: R.Tensor((640, 1920), dtype="float32") = R.concat((w0_t, w1_t_t, w2_t), axis=1)
lv1: R.Tensor((2, 1024, 1920), dtype="float32") = R.matmul(x1, lv, out_dtype="void")
lv0: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(0),), (R.prim_value(640),), assume_inbound=False)
lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(640),), (R.prim_value(1280),), assume_inbound=False)
lv2: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(1280),), (R.prim_value(1920),), assume_inbound=False)
out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1_1, lv2
R.output(out)
return out
# make sure it builds
mod = tvm.IRModule()
mod["main"] = rewritten
rx.build(mod, target="llvm")
<tvm.relax.vm_build.Executable at 0x7f4ec7fd4920>
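Beyond building, the numerics can be spot-checked. A sketch of the assumed Relax VM workflow (not part of the original test): run the rewritten function on random inputs and compare with numpy. Recall that lv0 and lv2 used transposed weights, while lv1's double transpose cancels out.

import numpy as np

vm = rx.VirtualMachine(rx.build(mod, target="llvm"), tvm.cpu())
x_np = np.random.rand(2, 1024, 640).astype("float32")
w_np = [np.random.rand(640, 640).astype("float32") for _ in range(3)]
lv0, lv1, lv2 = vm["main"](tvm.nd.array(x_np), *(tvm.nd.array(w) for w in w_np))
np.testing.assert_allclose(lv0.numpy(), x_np @ w_np[0].T, rtol=1e-3)
np.testing.assert_allclose(lv1.numpy(), x_np @ w_np[1], rtol=1e-3)
np.testing.assert_allclose(lv2.numpy(), x_np @ w_np[2].T, rtol=1e-3)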
Combining transposed matmuls twice

This variant matches matmuls whose weights arrive through permute_dims: it concatenates the raw weights along axis 0, transposes the fused weight once, and splits the result.
@R.function(private=True)
def main(
x1: R.Tensor((2, 1024, 640), "float32"),
x2: R.Tensor((2, 1024, 640), "float32"),
w0: R.Tensor((640, 640), "float32"),
w1: R.Tensor((640, 640), "float32"),
w2: R.Tensor((640, 640), "float32"),
w3: R.Tensor((640, 640), "float32"),
):
with R.dataflow():
w0_t = R.permute_dims(w0, axes=None)
lv0 = R.matmul(x1, w0_t)
w1_t = R.permute_dims(w1, axes=None)
lv1 = R.matmul(x1, w1_t)
w2_t = R.permute_dims(w2, axes=None)
lv2 = R.matmul(x2, w2_t)
w3_t = R.permute_dims(w3, axes=None)
lv3 = R.matmul(x2, w3_t)
out = (lv0, lv1, lv2, lv3)
R.output(out)
return out
main.show()
# from tvm.script import relax as R
@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), x2: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32"), w3: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
with R.dataflow():
w0_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w0, axes=None)
lv0: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w0_t, out_dtype="void")
w1_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w1, axes=None)
lv1: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w1_t, out_dtype="void")
w2_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w2, axes=None)
lv2: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x2, w2_t, out_dtype="void")
w3_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w3, axes=None)
lv3: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x2, w3_t, out_dtype="void")
out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1, lv2, lv3
R.output(out)
return out
with PatternContext() as ctx:
inp_pat = wildcard()
w1_pat = wildcard()
w2_pat = wildcard()
matmul1 = is_op("relax.matmul")(inp_pat, is_op("relax.permute_dims")(w1_pat))
matmul2 = is_op("relax.matmul")(inp_pat, is_op("relax.permute_dims")(w2_pat))
def rewriter(matchings, _):
inp = matchings[inp_pat]
w1 = matchings[w1_pat]
w2 = matchings[w2_pat]
concat = R.concat([w1, w2], axis=0)
matmul = R.matmul(inp, R.permute_dims(concat))
sections = [w1.struct_info.shape[0]]
chunks = R.split(matmul, sections, -1)
return {
matchings[matmul1]: chunks[0],
matchings[matmul2]: chunks[1],
}
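Note the design choice here: instead of a pair of strided_slice calls, this rewriter emits a single R.split with sections = [width], cutting the fused (2, 1024, 1280) result back into the two (2, 1024, 640) projections along the last axis; each matched matmul var is then mapped to the corresponding tuple element.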
rewritten = rewrite_bindings(ctx, rewriter, main)
# make sure it builds
mod = tvm.IRModule()
mod["main"] = rewritten
mod.show()
rx.build(mod, target="llvm")
# from tvm.script import ir as I
# from tvm.script import relax as R
@I.ir_module
class Module:
@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), x2: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32"), w3: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1280, 640), dtype="float32") = R.concat((w0, w1), axis=0)
lv1: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv, axes=None)
lv2: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(x1, lv1, out_dtype="void")
lv3: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = R.split(lv2, indices_or_sections=[640], axis=-1)
lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv3[0]
lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3[1]
lv4: R.Tensor((1280, 640), dtype="float32") = R.concat((w2, w3), axis=0)
lv5: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv4, axes=None)
lv6: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(x2, lv5, out_dtype="void")
lv7: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = R.split(lv6, indices_or_sections=[640], axis=-1)
lv2_1: R.Tensor((2, 1024, 640), dtype="float32") = lv7[0]
lv3_1: R.Tensor((2, 1024, 640), dtype="float32") = lv7[1]
out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1_1, lv2_1, lv3_1
R.output(out)
return out
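As a final hedged check (the same assumed VM workflow as above), the two fused transposed matmuls should reproduce all four original projections:

import numpy as np

vm = rx.VirtualMachine(rx.build(mod, target="llvm"), tvm.cpu())
x1_np, x2_np = (np.random.rand(2, 1024, 640).astype("float32") for _ in range(2))
w_np = [np.random.rand(640, 640).astype("float32") for _ in range(4)]
outs = vm["main"](tvm.nd.array(x1_np), tvm.nd.array(x2_np), *(tvm.nd.array(w) for w in w_np))
for o, x, w in zip(outs, [x1_np, x1_np, x2_np, x2_np], w_np):
    np.testing.assert_allclose(o.numpy(), x @ w.T, rtol=1e-3)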