测试相同形状模式

测试相同形状模式

from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
import tvm.testing
from tvm.relax.dpl import *

# Names of the Relax function variants that test_same_shape_pattern can build;
# each one is passed to tvm.testing.parameter as a separate parameter value.
_SAME_SHAPE_FUNC_TYPES = (
    "same_static_shape",
    "same_dynamic_shape",
    "different_static_shape",
    "different_dynamic_shape",
)
same_shape_func_type = tvm.testing.parameter(*_SAME_SHAPE_FUNC_TYPES)
def test_same_shape_pattern(same_shape_func_type):
    """Check that a ``same_shape_as`` constraint gates a dataflow-block match.

    Builds one of four private Relax functions, selected by
    ``same_shape_func_type``, whose two tensor parameters have either
    identical or differing shapes (static or symbolic).  An add pattern over
    two wildcards, with the wildcards tied together via ``same_shape_as``, is
    then matched against the first block of the function body.

    Parameters
    ----------
    same_shape_func_type : str
        Name of the function variant to build.  Variants whose name contains
        ``"same"`` are expected to produce a match; all others are expected
        to produce ``None``.

    Raises
    ------
    ValueError
        If ``same_shape_func_type`` is not one of the four known variants.
    """
    # Only the selected variant is defined, so only that one is processed by
    # the R.function decorator.
    if same_shape_func_type == "same_static_shape":
        # Both parameters are annotated with the static shape (1024, 128).

        @R.function(private=True)
        def relax_func(
            x: R.Tensor((1024, 128), "float32"),
            y: R.Tensor((1024, 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                scaled = R.multiply(x, R.const(2.0))
                summed = R.add(y, scaled)
                result = summed
                R.output(result)
            return result

    elif same_shape_func_type == "same_dynamic_shape":
        # Both parameters use the same symbolic leading dimension "n".

        @R.function(private=True)
        def relax_func(
            x: R.Tensor(("n", 128), "float32"),
            y: R.Tensor(("n", 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                scaled = R.multiply(x, R.const(2.0))
                summed = R.add(y, scaled)
                result = summed
                R.output(result)
            return result

    elif same_shape_func_type == "different_static_shape":
        # Static shapes differ in the leading dimension: 1024 versus 1.

        @R.function(private=True)
        def relax_func(
            x: R.Tensor((1024, 128), "float32"),
            y: R.Tensor((1, 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                scaled = R.multiply(x, R.const(2.0))
                summed = R.add(y, scaled)
                result = summed
                R.output(result)
            return result

    elif same_shape_func_type == "different_dynamic_shape":
        # Leading dimensions use two distinct symbolic names, "n" and "m".

        @R.function(private=True)
        def relax_func(
            x: R.Tensor(("n", 128), "float32"),
            y: R.Tensor(("m", 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                scaled = R.multiply(x, R.const(2.0))
                summed = R.add(y, scaled)
                result = summed
                R.output(result)
            return result

    else:
        raise ValueError(f"Unknown value of same_shape_func_type={same_shape_func_type}")

    # Pattern under test: relax.add applied to two wildcard operands, with the
    # operands additionally constrained through same_shape_as.
    with PatternContext() as pattern_ctx:
        lhs_pattern = wildcard()
        rhs_pattern = wildcard()
        # NOTE(review): add_pattern is not referenced again in this test; the
        # binding is kept so the pattern object stays alive until the test
        # ends -- confirm whether that is required.
        add_pattern = is_op("relax.add")(lhs_pattern, rhs_pattern)
        lhs_pattern.same_shape_as(rhs_pattern)

    # Match against the first binding block of the selected function.
    found = pattern_ctx.match_dfb(relax_func.body.blocks[0])

    # Variants without "same" in their name must not match at all.
    if "same" not in same_shape_func_type:
        assert found is None
        return

    assert found