def test_backtrack_if_rewriter_returns_no_op():
    """Rewriter participates in the pattern matching

    Sometimes, the pattern-matching syntax is insufficient to check if
    a replacement may be performed.  In this case, the `rewriter`
    function may perform additional validation.  If this validation
    fails, the `rewriter` function can return the original expression,
    and no replacement is performed.

    In addition, when the `rewriter` returns the original expression,
    the pattern match should backtrack to determine if another branch
    of the match may have produced a replacement.

    This functionality allows pattern replacements to be composed.
    """

    pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())

    pat_arg = wildcard()
    pat_zeros = is_op("relax.zeros")(wildcard())
    pat_add = is_op("relax.add")(pat_arg, pat_zeros)

    # OR conditions are checked in the order that they occur.  Because
    # `pat_match_no_rewrite` is a superset of `pat_add`, it will
    # always match first.
    pat = pat_match_no_rewrite | pat_add

    def rewriter(expr, matches):
        if pat_match_no_rewrite in matches:
            # This branch simulates a rewrite whose precondition has
            # failed.  If the pattern-matching treats this as a
            # successful match with no replacement required, then no
            # rewrite would be performed.  On the other hand, if the
            # pattern-matching treats this as an unsuccessful match,
            # then it can backtrack and attempt `pat_add` instead.
            return expr
        elif pat_add in matches:
            return matches[pat_arg]
        else:
            raise RuntimeError("Pattern matched, but neither branch matched")

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.ones([64, 128], "int32")
            B = R.zeros([64, 128], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    @R.function(private=True)
    def expected():
        with R.dataflow():
            C = R.ones([64, 128], "int32")

            R.output(C)
        return C

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)


def test_backtrack_for_no_op_rewriter_does_not_match_on_var():
    """The matches should always contain the bound value

    This is a regression test.  In versions from
    https://github.com/apache/tvm/pull/16732 to
    https://github.com/apache/tvm/pull/16828, the `rewrite_call`
    function could erroneously call the rewriter with `expr` and
    `matches[pat]` set to a variable (`C`) instead of the value to
    which it is bound (`R.add(A,B)`).
    """
    pat_a = is_op("relax.add")(wildcard(), wildcard())
    pat_b = is_op("relax.add")(wildcard(), wildcard())
    pat = pat_a | pat_b

    def rewriter(expr, matches):
        assert isinstance(matches[pat], rx.Call)
        return expr

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.ones([64, 128], "int32")
            B = R.zeros([64, 128], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    expected = before

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)
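

# A minimal, self-contained sketch of the backtracking semantics that the
# two tests above exercise.  This is NOT TVM's implementation of
# `rewrite_call`; it is an illustrative model under stated assumptions.
# Expressions are modeled as hypothetical toy tuples such as
# ("add", lhs, rhs), and an OR-pattern such as `pat_a | pat_b` is modeled
# as an ordered list of (name, predicate) pairs.  All `_sketch_*` names
# are hypothetical helpers introduced only for this illustration.  The
# "no-op" check uses object identity (`is`), which is an assumption about
# how an unchanged result is detected.


def _sketch_rewrite_first(alternatives, rewriter, expr):
    """Return the first rewrite of `expr` that is not a no-op, else `expr`."""
    for name, predicate in alternatives:
        if not predicate(expr):
            continue
        # Only the alternative that matched is reported to the rewriter,
        # so `matches` never contains a branch that did not match.
        result = rewriter(expr, {name: expr})
        if result is not expr:
            return result
        # The rewriter returned its input unchanged, which is treated as
        # a failed match: backtrack and try the next alternative.
    return expr


def _sketch_is_add(expr):
    # Toy stand-in for `is_op("relax.add")(wildcard(), wildcard())`.
    return isinstance(expr, tuple) and len(expr) == 3 and expr[0] == "add"


def _sketch_is_add_zeros(expr):
    # Toy stand-in for `is_op("relax.add")(pat_arg, pat_zeros)`.
    return _sketch_is_add(expr) and expr[2] == ("zeros",)


def _sketch_rewriter(expr, matches):
    if "match_no_rewrite" in matches:
        # Simulated failed precondition: decline by returning the input.
        return expr
    if "add_zeros" in matches:
        # x + zeros -> x
        return expr[1]
    raise RuntimeError("Pattern matched, but neither branch matched")


# Usage, mirroring `test_backtrack_if_rewriter_returns_no_op`:
#     _sketch_rewrite_first(
#         [("match_no_rewrite", _sketch_is_add),
#          ("add_zeros", _sketch_is_add_zeros)],
#         _sketch_rewriter,
#         ("add", ("ones",), ("zeros",)),
#     )
# The first alternative matches but the rewriter declines, so the sketch
# backtracks to "add_zeros" and returns ("ones",).  Without backtracking,
# the no-op from the first alternative would end the search and the
# expression would be left unchanged.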