测试匹配结构信息

测试匹配结构信息

from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
import tvm.testing
from tvm.relax.dpl import *

测试匹配时带有结构信息更新的通配符:

def test_wildcard_with_struct_info_updates_when_matching():
    """Rewrite a matching `R.add` into `R.multiply`.

    A DFPattern may be restricted to a specific StructInfo.  Both
    operands below are `R.Tensor([2, 3])`, so the pattern matches and
    the addition is replaced by a multiplication of the same operands.
    """

    def shaped_operand():
        # Fresh wildcard constrained to a [2, 3] tensor on every call.
        return wildcard().has_struct_info(R.Tensor([2, 3]))

    lhs_pat, rhs_pat = shaped_operand(), shaped_operand()
    add_pat = is_op("relax.add")(lhs_pat, rhs_pat)

    def to_multiply(_expr, bound):
        # Keep the matched operands; only the operator changes.
        return rx.op.multiply(bound[lhs_pat], bound[rhs_pat])

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    @R.function(private=True)
    def expected():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.multiply(A, B)

            R.output(C)
        return C

    rewritten = rewrite_call(add_pat, to_multiply, before)
    tvm.ir.assert_structural_equal(expected, rewritten)


def test_wildcard_with_struct_info_is_no_op_when_not_matching():
    """StructInfoPattern requires the StructInfo provided

    Here, the pattern would match, except that the function has
    `R.Tensor([16,32])`, and the pattern requires `R.Tensor([2,3])`.
    Because the StructInfo does not match, the function must be left
    unchanged.
    """

    pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
    pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
    pat = is_op("relax.add")(pat_lhs, pat_rhs)

    def rewriter(expr, matches):
        lhs = matches[pat_lhs]
        rhs = matches[pat_rhs]
        return rx.op.multiply(lhs, rhs)

    @R.function(private=True)
    def before():
        with R.dataflow():
            # This R.add has shape [16, 32], which differs from the
            # [2, 3] shape the pattern requires, so it must NOT be
            # updated.
            A = R.zeros([16, 32], "int32")
            B = R.ones([16, 32], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    # No rewrite is expected, so the result should equal the input.
    expected = before

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)


def test_wildcard_struct_info_for_unknown_dtype():
    """A TensorStructInfo without a dtype accepts operands of any dtype.

    The pattern constrains only the shape `[2, 3]`, so both the `int32`
    and the `float32` additions below are rewritten to multiplications.
    """

    def any_dtype_operand():
        # R.Tensor is given a shape but no dtype.
        return wildcard().has_struct_info(R.Tensor([2, 3]))

    operand_pats = (any_dtype_operand(), any_dtype_operand())
    add_pat = is_op("relax.add")(*operand_pats)

    def rewrite_as_multiply(_expr, bound):
        lhs, rhs = (bound[p] for p in operand_pats)
        return rx.op.multiply(lhs, rhs)

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.add(A, B)

            D = R.zeros([2, 3], "float32")
            E = R.ones([2, 3], "float32")
            F = R.add(D, E)

            output = (C, F)
            R.output(output)
        return output

    @R.function(private=True)
    def expected():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.multiply(A, B)

            D = R.zeros([2, 3], "float32")
            E = R.ones([2, 3], "float32")
            F = R.multiply(D, E)

            output = (C, F)
            R.output(output)
        return output

    rewritten = rewrite_call(add_pat, rewrite_as_multiply, before)
    tvm.ir.assert_structural_equal(expected, rewritten)


def test_wildcard_struct_info_with_symbolic_vars():
    """StructInfoPattern may define symbolic vars

    Both operands share the symbolic shape `[m, n]`, so the pattern
    finds the elementwise `R.add` (`[64, 128]` + `[64, 128]`) while
    ignoring the broadcasted `R.add` (`[64, 128]` + `[1, 128]`).
    """

    m, n = (tir.Var(name, "int64") for name in ("m", "n"))

    def symbolic_operand():
        # The same m and n are reused, tying the two operand shapes together.
        return wildcard().has_struct_info(R.Tensor([m, n]))

    lhs_pat = symbolic_operand()
    rhs_pat = symbolic_operand()
    add_pat = is_op("relax.add")(lhs_pat, rhs_pat)

    def rewrite_as_multiply(_expr, bound):
        return rx.op.multiply(bound[lhs_pat], bound[rhs_pat])

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.zeros([64, 128], "int32")
            B = R.ones([64, 128], "int32")
            C = R.add(A, B)

            D = R.zeros([64, 128], "float32")
            E = R.ones([1, 128], "float32")
            F = R.add(D, E)

            output = (C, F)
            R.output(output)
        return output

    @R.function(private=True)
    def expected():
        with R.dataflow():
            A = R.zeros([64, 128], "int32")
            B = R.ones([64, 128], "int32")
            C = R.multiply(A, B)

            D = R.zeros([64, 128], "float32")
            E = R.ones([1, 128], "float32")
            F = R.add(D, E)

            output = (C, F)
            R.output(output)
        return output

    rewritten = rewrite_call(add_pat, rewrite_as_multiply, before)
    tvm.ir.assert_structural_equal(expected, rewritten)