DPL Pattern Context#

import tvm
from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
from tvm.relax.dpl import *

QKV_proj#

@tvm.script.ir_module
class QKV_proj:
    @R.function
    def main(
        x: R.Tensor((2, 1024, 640), "float32"),
        w0: R.Tensor((640, 640), "float32"),
        w1: R.Tensor((640, 640), "float32"),
        w2: R.Tensor((640, 640), "float32"),
    ) -> R.Tensor:
        with R.dataflow():
            lv0 = R.matmul(x, w0)
            lv1 = R.matmul(x, w1)
            lv2 = R.matmul(x, w2)
            out = (lv0, lv1, lv2)
            R.output(out)
        return out
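
The three matmul patterns below share a single inp_pat wildcard. Inside a PatternContext, shared sub-patterns act as a structural constraint on the dataflow graph: ctx.match_dfb succeeds only when one and the same variable feeds all three projections.
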
with PatternContext() as ctx:
    inp_pat = wildcard()
    Q_weight_pat = wildcard()
    K_weight_pat = wildcard()
    V_weight_pat = wildcard()

    matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
    matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
    matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)

    dfb = QKV_proj["main"].body.blocks[0]
    out = ctx.match_dfb(dfb)

    assert out[Q_weight_pat].name_hint == "w0"
    assert out[K_weight_pat].name_hint == "w1"
    assert out[V_weight_pat].name_hint == "w2"
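
The match result also records what the shared wildcard was bound to; a small check (assuming, as with the weight patterns above, that match_dfb reports wildcard bindings):

assert out[inp_pat].name_hint == "x"

Conversely, when the projections read from different inputs, the shared inp_pat constraint cannot be satisfied and match_dfb returns None:
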
def test_attention_fake_qkv():
    @tvm.script.ir_module
    class QKV_proj:
        @R.function
        def main(
            x1: R.Tensor((2, 1024, 640), "float32"),
            x2: R.Tensor((2, 1024, 640), "float32"),
            w0: R.Tensor((640, 640), "float32"),
            w1: R.Tensor((640, 640), "float32"),
            w2: R.Tensor((640, 640), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                lv0 = R.matmul(x1, w0)
                lv1 = R.matmul(x2, w1)
                lv2 = R.matmul(x2, w2)
                out = (lv0, lv1, lv2)
                R.output(out)
            return out

    with PatternContext() as ctx:
        inp_pat = wildcard()
        Q_weight_pat = wildcard()
        K_weight_pat = wildcard()
        V_weight_pat = wildcard()

        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)

        dfb = QKV_proj["main"].body.blocks[0]
        assert ctx.match_dfb(dfb) is None

Rewriting QKV_proj#

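To rewrite matched bindings rather than merely detect them, pair the PatternContext with a callback. The callback receives the dictionary of matched expressions and returns a map from each matched variable to its replacement; rewrite_bindings then splices the replacements into the dataflow block. Here the three projections fuse into a single wide matmul followed by three slices:
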
def get_qkv_proj_rewriter():
    with PatternContext() as ctx:
        inp_pat = wildcard()
        Q_weight_pat = wildcard()
        K_weight_pat = wildcard()
        V_weight_pat = wildcard()

        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)

    def qkv_proj_rewriter(matchings, _):
        inp = matchings[inp_pat]
        Q_weight = matchings[Q_weight_pat]
        K_weight = matchings[K_weight_pat]
        V_weight = matchings[V_weight_pat]
        width = Q_weight.struct_info.shape[1]

        concat = R.concat([Q_weight, K_weight, V_weight], axis=1)
        matmul = R.matmul(inp, concat)
        Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width])
        K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2])
        V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3])

        return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V}

    return ctx, qkv_proj_rewriter
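
Note the scoping: the patterns must be constructed inside the with PatternContext() block so that they register with ctx, while qkv_proj_rewriter itself runs later, outside the context, and uses the pattern objects only as keys into matchings.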

Combining matmuls twice#

@R.function(private=True)
def qkv_x2(
    x1: R.Tensor((2, 1024, 640), "float32"),
    x2: R.Tensor((2, 1024, 640), "float32"),
    w0: R.Tensor((640, 640), "float32"),
    w1: R.Tensor((640, 640), "float32"),
    w2: R.Tensor((640, 640), "float32"),
    w3: R.Tensor((640, 640), "float32"),
    w4: R.Tensor((640, 640), "float32"),
    w5: R.Tensor((640, 640), "float32"),
):
    with R.dataflow():
        lv0 = R.matmul(x1, w0)
        lv1 = R.matmul(x1, w1)
        lv2 = R.matmul(x1, w2)
        lv3 = R.matmul(x2, w3)
        lv4 = R.matmul(x2, w4)
        lv5 = R.matmul(x2, w5)
        out = (lv0, lv1, lv2, lv3, lv4, lv5)
        R.output(out)
    return out
qkv_x2.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), x2: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32"), w3: R.Tensor((640, 640), dtype="float32"), w4: R.Tensor((640, 640), dtype="float32"), w5: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
    with R.dataflow():
        lv0: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w0, out_dtype="void")
        lv1: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w1, out_dtype="void")
        lv2: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w2, out_dtype="void")
        lv3: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x2, w3, out_dtype="void")
        lv4: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x2, w4, out_dtype="void")
        lv5: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x2, w5, out_dtype="void")
        out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1, lv2, lv3, lv4, lv5
        R.output(out)
    return out
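
Applying the rewriter: rewrite_bindings fires the pattern once per distinct input, so the six matmuls above fuse into two concat-matmul-slice groups, one driven by x1 and one by x2.
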
ctx, rewriter = get_qkv_proj_rewriter()
rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
rewritten.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), x2: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32"), w3: R.Tensor((640, 640), dtype="float32"), w4: R.Tensor((640, 640), dtype="float32"), w5: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
    with R.dataflow():
        lv: R.Tensor((640, 1920), dtype="float32") = R.concat((w0, w1, w2), axis=1)
        lv1: R.Tensor((2, 1024, 1920), dtype="float32") = R.matmul(x1, lv, out_dtype="void")
        lv0: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(0),), (R.prim_value(640),), assume_inbound=False)
        lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(640),), (R.prim_value(1280),), assume_inbound=False)
        lv2: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(1280),), (R.prim_value(1920),), assume_inbound=False)
        lv2_1: R.Tensor((640, 1920), dtype="float32") = R.concat((w3, w4, w5), axis=1)
        lv3: R.Tensor((2, 1024, 1920), dtype="float32") = R.matmul(x2, lv2_1, out_dtype="void")
        lv3_1: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv3, (R.prim_value(2),), (R.prim_value(0),), (R.prim_value(640),), assume_inbound=False)
        lv4: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv3, (R.prim_value(2),), (R.prim_value(640),), (R.prim_value(1280),), assume_inbound=False)
        lv5: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv3, (R.prim_value(2),), (R.prim_value(1280),), (R.prim_value(1920),), assume_inbound=False)
        out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1_1, lv2, lv3_1, lv4, lv5
        R.output(out)
    return out
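
Each triple of projections has collapsed into one concat along axis 1, one wide (640, 1920) matmul, and three strided slices, leaving two large matmuls instead of six.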

Testing that a dataflow block may begin with a match cast#

The input to rewrite_bindings may contain R.match_cast bindings.

This is a regression test: in an earlier implementation, applying rewrite_bindings would segfault when an R.match_cast was the first binding in an R.dataflow block.

@R.function(private=True)
def before(
    x_untyped: R.Tensor,
    w0_untyped: R.Tensor,
    w1_untyped: R.Tensor,
    w2_untyped: R.Tensor,
):
    with R.dataflow():
        x = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), "float32"))
        w0 = R.match_cast(w0_untyped, R.Tensor((640, 640), "float32"))
        w1 = R.match_cast(w1_untyped, R.Tensor((640, 640), "float32"))
        w2 = R.match_cast(w2_untyped, R.Tensor((640, 640), "float32"))
        out_0 = R.matmul(x, w0)
        out_1 = R.matmul(x, w1)
        out_2 = R.matmul(x, w2)
        out = (out_0, out_1, out_2)
        R.output(out)
    return out
before.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x_untyped: R.Tensor, w0_untyped: R.Tensor, w1_untyped: R.Tensor, w2_untyped: R.Tensor) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
    with R.dataflow():
        x: R.Tensor((2, 1024, 640), dtype="float32") = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), dtype="float32"))
        w0: R.Tensor((640, 640), dtype="float32") = R.match_cast(w0_untyped, R.Tensor((640, 640), dtype="float32"))
        w1: R.Tensor((640, 640), dtype="float32") = R.match_cast(w1_untyped, R.Tensor((640, 640), dtype="float32"))
        w2: R.Tensor((640, 640), dtype="float32") = R.match_cast(w2_untyped, R.Tensor((640, 640), dtype="float32"))
        out_0: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x, w0, out_dtype="void")
        out_1: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x, w1, out_dtype="void")
        out_2: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x, w2, out_dtype="void")
        out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = out_0, out_1, out_2
        R.output(out)
    return out
ctx, rewriter = get_qkv_proj_rewriter()
rewritten = rewrite_bindings(ctx, rewriter, before)
rewritten.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x_untyped: R.Tensor, w0_untyped: R.Tensor, w1_untyped: R.Tensor, w2_untyped: R.Tensor) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
    with R.dataflow():
        x: R.Tensor((2, 1024, 640), dtype="float32") = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), dtype="float32"))
        w0: R.Tensor((640, 640), dtype="float32") = R.match_cast(w0_untyped, R.Tensor((640, 640), dtype="float32"))
        w1: R.Tensor((640, 640), dtype="float32") = R.match_cast(w1_untyped, R.Tensor((640, 640), dtype="float32"))
        w2: R.Tensor((640, 640), dtype="float32") = R.match_cast(w2_untyped, R.Tensor((640, 640), dtype="float32"))
        lv: R.Tensor((640, 1920), dtype="float32") = R.concat((w0, w1, w2), axis=1)
        lv1: R.Tensor((2, 1024, 1920), dtype="float32") = R.matmul(x, lv, out_dtype="void")
        out_0: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(0),), (R.prim_value(640),), assume_inbound=False)
        out_1: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(640),), (R.prim_value(1280),), assume_inbound=False)
        out_2: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(1280),), (R.prim_value(1920),), assume_inbound=False)
        out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = out_0, out_1, out_2
        R.output(out)
    return out
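
The match_cast bindings pass through untouched, the matmuls that follow them are still fused, and the segfault no longer occurs.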

Testing the emission order of combined matmuls#

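rewrite_bindings must emit the fused bindings only after every value they depend on has been defined. In the function below the weight transposes are interleaved with the matmuls, and w1 is even transposed twice, so the rewrite has to hoist w0_t, w1_t_t, and w2_t ahead of the single fused matmul:
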
@R.function(private=True)
def main(
    x1: R.Tensor((2, 1024, 640), "float32"),
    w0: R.Tensor((640, 640), "float32"),
    w1: R.Tensor((640, 640), "float32"),
    w2: R.Tensor((640, 640), "float32"),
):
    with R.dataflow():
        w0_t = R.permute_dims(w0, axes=None)
        lv0 = R.matmul(x1, w0_t)
        w1_t = R.permute_dims(w1, axes=None)
        w1_t_t = R.permute_dims(w1_t, axes=None)
        lv1 = R.matmul(x1, w1_t_t)
        w2_t = R.permute_dims(w2, axes=None)
        lv2 = R.matmul(x1, w2_t)
        out = (lv0, lv1, lv2)
        R.output(out)
    return out

main.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
    with R.dataflow():
        w0_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w0, axes=None)
        lv0: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w0_t, out_dtype="void")
        w1_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w1, axes=None)
        w1_t_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w1_t, axes=None)
        lv1: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w1_t_t, out_dtype="void")
        w2_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w2, axes=None)
        lv2: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w2_t, out_dtype="void")
        out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1, lv2
        R.output(out)
    return out
ctx, rewriter = get_qkv_proj_rewriter()

rewritten = rewrite_bindings(ctx, rewriter, main)
rewritten.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
    with R.dataflow():
        w0_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w0, axes=None)
        w1_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w1, axes=None)
        w1_t_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w1_t, axes=None)
        w2_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w2, axes=None)
        lv: R.Tensor((640, 1920), dtype="float32") = R.concat((w0_t, w1_t_t, w2_t), axis=1)
        lv1: R.Tensor((2, 1024, 1920), dtype="float32") = R.matmul(x1, lv, out_dtype="void")
        lv0: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(0),), (R.prim_value(640),), assume_inbound=False)
        lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(640),), (R.prim_value(1280),), assume_inbound=False)
        lv2: R.Tensor((2, 1024, 640), dtype="float32") = R.strided_slice(lv1, (R.prim_value(2),), (R.prim_value(1280),), (R.prim_value(1920),), assume_inbound=False)
        out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1_1, lv2
        R.output(out)
    return out
# make sure it builds
mod = tvm.IRModule()
mod["main"] = rewritten

rx.build(mod, target="llvm")
<tvm.relax.vm_build.Executable at 0x7f4ec7fd4920>
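
A successful build confirms the rewritten module is well-formed; in particular, the hoisted transposes are all defined before the fused matmul that consumes them.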

Combining transposed matmuls twice#

@R.function(private=True)
def main(
    x1: R.Tensor((2, 1024, 640), "float32"),
    x2: R.Tensor((2, 1024, 640), "float32"),
    w0: R.Tensor((640, 640), "float32"),
    w1: R.Tensor((640, 640), "float32"),
    w2: R.Tensor((640, 640), "float32"),
    w3: R.Tensor((640, 640), "float32"),
):
    with R.dataflow():
        w0_t = R.permute_dims(w0, axes=None)
        lv0 = R.matmul(x1, w0_t)
        w1_t = R.permute_dims(w1, axes=None)
        lv1 = R.matmul(x1, w1_t)
        w2_t = R.permute_dims(w2, axes=None)
        lv2 = R.matmul(x2, w2_t)
        w3_t = R.permute_dims(w3, axes=None)
        lv3 = R.matmul(x2, w3_t)
        out = (lv0, lv1, lv2, lv3)
        R.output(out)
    return out
main.show()
# from tvm.script import relax as R

@R.function(private=True)
def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), x2: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32"), w3: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
    with R.dataflow():
        w0_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w0, axes=None)
        lv0: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w0_t, out_dtype="void")
        w1_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w1, axes=None)
        lv1: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x1, w1_t, out_dtype="void")
        w2_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w2, axes=None)
        lv2: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x2, w2_t, out_dtype="void")
        w3_t: R.Tensor((640, 640), dtype="float32") = R.permute_dims(w3, axes=None)
        lv3: R.Tensor((2, 1024, 640), dtype="float32") = R.matmul(x2, w3_t, out_dtype="void")
        out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1, lv2, lv3
        R.output(out)
    return out
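
This time each weight is transposed before its matmul, so the pattern nests is_op("relax.permute_dims") inside the matmul's second operand. The rewrite exploits the identity concat([w1, w2], axis=0).T == concat([w1.T, w2.T], axis=1): concatenating the untransposed weights along axis 0 and transposing once yields the same operand as transposing each weight separately, and R.split along the last axis recovers the two individual products.
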
with PatternContext() as ctx:
    inp_pat = wildcard()
    w1_pat = wildcard()
    w2_pat = wildcard()
    matmul1 = is_op("relax.matmul")(inp_pat, is_op("relax.permute_dims")(w1_pat))
    matmul2 = is_op("relax.matmul")(inp_pat, is_op("relax.permute_dims")(w2_pat))

    def rewriter(matchings, _):
        inp = matchings[inp_pat]
        w1 = matchings[w1_pat]
        w2 = matchings[w2_pat]

        concat = R.concat([w1, w2], axis=0)
        matmul = R.matmul(inp, R.permute_dims(concat))
        sections = [w1.struct_info.shape[0]]

        chunks = R.split(matmul, sections, -1)

        return {
            matchings[matmul1]: chunks[0],
            matchings[matmul2]: chunks[1],
        }

    rewritten = rewrite_bindings(ctx, rewriter, main)

    # make sure it builds
    mod = tvm.IRModule()
    mod["main"] = rewritten
    mod.show()

    rx.build(mod, target="llvm")
# from tvm.script import ir as I
# from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def main(x1: R.Tensor((2, 1024, 640), dtype="float32"), x2: R.Tensor((2, 1024, 640), dtype="float32"), w0: R.Tensor((640, 640), dtype="float32"), w1: R.Tensor((640, 640), dtype="float32"), w2: R.Tensor((640, 640), dtype="float32"), w3: R.Tensor((640, 640), dtype="float32")) -> R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((1280, 640), dtype="float32") = R.concat((w0, w1), axis=0)
            lv1: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv, axes=None)
            lv2: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(x1, lv1, out_dtype="void")
            lv3: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = R.split(lv2, indices_or_sections=[640], axis=-1)
            lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv3[0]
            lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3[1]
            lv4: R.Tensor((1280, 640), dtype="float32") = R.concat((w2, w3), axis=0)
            lv5: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv4, axes=None)
            lv6: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(x2, lv5, out_dtype="void")
            lv7: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = R.split(lv6, indices_or_sections=[640], axis=-1)
            lv2_1: R.Tensor((2, 1024, 640), dtype="float32") = lv7[0]
            lv3_1: R.Tensor((2, 1024, 640), dtype="float32") = lv7[1]
            out: R.Tuple(R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32"), R.Tensor((2, 1024, 640), dtype="float32")) = lv0, lv1_1, lv2_1, lv3_1
            R.output(out)
        return out
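
To go beyond "it builds", the rewritten function can also be checked numerically. The sketch below is illustrative rather than canonical: it assumes an llvm target is available, and it first gives the private function a global symbol so the VM can look it up by name.

import numpy as np

# Expose the private function, then build and load it into the Relax VM.
mod = tvm.IRModule()
mod["main"] = rewritten.with_attr("global_symbol", "main")
ex = rx.build(mod, target="llvm")
vm = rx.VirtualMachine(ex, tvm.cpu())

# Random inputs matching the declared shapes.
x1_np, x2_np = (np.random.rand(2, 1024, 640).astype("float32") for _ in range(2))
ws = [np.random.rand(640, 640).astype("float32") for _ in range(4)]
lv0, lv1, lv2, lv3 = vm["main"](*(tvm.nd.array(a) for a in (x1_np, x2_np, *ws)))

# In the original function, lv0 = matmul(x1, permute_dims(w0)) = x1 @ w0.T.
np.testing.assert_allclose(lv0.numpy(), x1_np @ ws[0].T, rtol=1e-4, atol=1e-4)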