# 回溯 (Backtracking)

from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
import tvm.testing
from tvm.relax.dpl import *
def test_backtrack_if_rewriter_returns_no_op():
    """Rewriter participates in the pattern matching

    Sometimes, the pattern-matching syntax is insufficient to check if
    a replacement may be performed.  In this case, the `rewriter`
    function may perform additional validation.  If this validation
    fails, the `rewriter` function can return the original expression,
    and no replacement is performed.

    In addition, when the `rewriter` returns the original expression,
    the pattern match should backtrack to determine if another branch
    of the match may have produced a replacement.

    This functionality allows pattern replacements to be composed.
    """

    # Matches any `relax.add`; the rewriter below never replaces a
    # match made through this branch.
    pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())

    # Matches `arg + relax.zeros(...)`, which the rewriter simplifies
    # to `arg`.
    pat_arg = wildcard()
    pat_zeros = is_op("relax.zeros")(wildcard())
    pat_add = is_op("relax.add")(pat_arg, pat_zeros)

    # OR conditions are checked in the order that they occur.  Because
    # `pat_match_no_rewrite` is a superset of `pat_add`, it will
    # always match first.
    pat = pat_match_no_rewrite | pat_add

    def rewriter(expr, matches):
        if pat_match_no_rewrite in matches:
            # This branch simulates a rewrite whose precondition has
            # failed.  If the pattern-matching treats this as a
            # successful match with no replacement required, then no
            # rewrite would be performed.  On the other hand, if the
            # pattern-matching treats this as an unsuccessful match,
            # then it can backtrack and attempt `pat_add` instead.
            return expr
        if pat_add in matches:
            # `arg + zeros(...)` is replaced by `arg` itself.
            return matches[pat_arg]
        raise RuntimeError("Pattern matched, but neither branch matched")

    # `C = A + zeros` should collapse to `A` once matching backtracks
    # from `pat_match_no_rewrite` to `pat_add`.
    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.ones([64, 128], "int32")
            B = R.zeros([64, 128], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    # After the rewrite, `C` is bound directly to the `R.ones` value and
    # the now-unused `R.zeros` binding is gone.
    @R.function(private=True)
    def expected():
        with R.dataflow():
            C = R.ones([64, 128], "int32")

            R.output(C)
        return C

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)


def test_backtrack_for_no_op_rewriter_does_not_match_on_var():
    """The matches should always contain the bound value

    This is a regression test.  In versions from
    https://github.com/apache/tvm/pull/16732 to
    https://github.com/apache/tvm/pull/16828, the `rewrite_call`
    function could erroneously call the rewriter with `expr` and
    `matches[pat]` set to a variable (`C`) instead of the value to
    which it is bound (`R.add(A,B)`).
    """
    # Two identical alternatives, so that a no-op result from the first
    # branch leaves the second branch available to be tried.
    pat_a = is_op("relax.add")(wildcard(), wildcard())
    pat_b = is_op("relax.add")(wildcard(), wildcard())
    pat = pat_a | pat_b

    # Every invocation is recorded so the test cannot pass vacuously:
    # if the rewriter were never called, its assertions would never run
    # and the no-op comparison below would still succeed.
    rewriter_calls = []

    def rewriter(expr, matches):
        rewriter_calls.append(expr)
        # Both the expression handed to the rewriter and the recorded
        # match must be the bound value (a call node), never the
        # variable it is bound to.
        assert isinstance(expr, rx.Call)
        assert isinstance(matches[pat], rx.Call)
        # Returning the input unchanged makes this a no-op rewrite.
        return expr

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.ones([64, 128], "int32")
            B = R.zeros([64, 128], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    # The rewriter never changes anything, so the function is unchanged.
    expected = before
    after = rewrite_call(pat, rewriter, before)
    assert rewriter_calls, "rewriter was never invoked for the `R.add` binding"
    tvm.ir.assert_structural_equal(expected, after)