def test_backtrack_if_rewriter_returns_no_op():
    """Rewriter participates in the pattern matching

    Sometimes, the pattern-matching syntax is insufficient to check if
    a replacement may be performed.  In this case, the `rewriter`
    function may perform additional validation.  If this validation
    fails, the `rewriter` function can return the original expression,
    and no replacement is performed.

    In addition, when the `rewriter` returns the original expression,
    the pattern match should backtrack to determine if another branch
    of the match may have produced a replacement.

    This functionality allows pattern replacements to be composed.
    """

    # Matches any `R.add`, including every expression `pat_add` matches.
    pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())

    # Matches `x + R.zeros(...)`, which can be simplified to `x`.
    pat_arg = wildcard()
    pat_zeros = is_op("relax.zeros")(wildcard())
    pat_add = is_op("relax.add")(pat_arg, pat_zeros)

    # OR conditions are checked in the order that they occur.  Because
    # `pat_match_no_rewrite` is a superset of `pat_add`, it will
    # always match first.
    pat = pat_match_no_rewrite | pat_add

    def rewriter(expr, matches):
        if pat_match_no_rewrite in matches:
            # This branch simulates a rewrite whose precondition has
            # failed.  If the pattern-matching treats this as a
            # successful match with no replacement required, then no
            # rewrite would be performed.  On the other hand, if the
            # pattern-matching treats this as an unsuccessful match,
            # then it can backtrack and attempt `pat_add` instead.
            return expr
        elif pat_add in matches:
            # `x + zeros(...)` is replaced by `x`.
            return matches[pat_arg]
        else:
            raise RuntimeError("Pattern matched, but neither branch matched")

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.ones([64, 128], "int32")
            B = R.zeros([64, 128], "int32")
            C = R.add(A, B)
            R.output(C)
        return C

    # The addition of zeros is gone only if backtracking reached
    # `pat_add`; otherwise `before` would be returned unchanged.
    @R.function(private=True)
    def expected():
        with R.dataflow():
            C = R.ones([64, 128], "int32")
            R.output(C)
        return C

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)
def test_backtrack_for_no_op_rewriter_does_not_match_on_var():
    """The matches should always contain the bound value

    This is a regression test.  In versions from
    https://github.com/apache/tvm/pull/16732 to
    https://github.com/apache/tvm/pull/16828, the `rewrite_call`
    function could erroneously call the rewriter with `expr` and
    `matches[pat]` set to a variable (`C`) instead of the value to
    which it is bound (`R.add(A,B)`).
    """

    # Two identical alternatives, so that a rejected first branch
    # forces the matcher to backtrack into the second one.
    first_add = is_op("relax.add")(wildcard(), wildcard())
    second_add = is_op("relax.add")(wildcard(), wildcard())
    pat = first_add | second_add

    def rewriter(expr, matches):
        # Only inspects the match; returning `expr` unchanged means
        # no replacement is requested.
        matched_value = matches[pat]
        assert isinstance(matched_value, rx.Call)
        return expr

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.ones([64, 128], "int32")
            B = R.zeros([64, 128], "int32")
            C = R.add(A, B)
            R.output(C)
        return C

    after = rewrite_call(pat, rewriter, before)

    # The rewriter never substitutes anything, so the function must
    # come back structurally identical to its input.
    tvm.ir.assert_structural_equal(before, after)