

import functools
import math

from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
import tvm.testing
from tvm.relax.dpl import *

# NOTE(review): the argument list of this parameter was missing (the call was
# left unclosed, breaking the rest of the module).  These values are
# reconstructed from upstream TVM and are presumably consumed by a
# same-shape pattern test elsewhere in the file -- confirm before relying on them.
same_shape_func_type = tvm.testing.parameter(
    "same_static_shape",
    "same_dynamic_shape",
    "different_static_shape",
    "different_dynamic_shape",
)


def test_iterative_rewrite_without_trivial_binding():
    """Avoid introducing common sub-expressions

    Pattern replacement may produce the same intermediate, which
    should appear only once in the final result.  Here each
    `R.strided_slice` is rewritten into one element of `R.split(x, 2)`;
    the resulting identical `R.split` expressions must collapse into a
    single `x_split` binding.
    """

    @R.function(private=True)
    def before(x: R.Tensor((1024,))):
        with R.dataflow():
            a = R.strided_slice(x, [0], [0], [512], [1])
            b = R.strided_slice(x, [0], [512], [1024], [1])
            c = R.add(a, b)
            R.output(c)
        return c

    @R.function(private=True)
    def expected(x: R.Tensor((1024,))):
        with R.dataflow():
            x_split = R.split(x, 2)
            a = x_split[0]
            b = x_split[1]
            c = R.add(a, b)
            R.output(c)
        return c

    pattern_arg = wildcard()
    pattern_axes = wildcard()
    pattern_begin = wildcard()
    pattern_end = wildcard()
    pattern_strides = wildcard()
    pattern = is_op("relax.strided_slice")(
        pattern_arg, pattern_axes, pattern_begin, pattern_end, pattern_strides
    )

    def rewriter(expr, matches):
        """Rewrite a unit-stride slice along axis 0 into an element of `R.split`.

        Returns `expr` unchanged whenever the slice is not exactly one
        equal-sized section of axis 0 with statically known extents.
        """
        arg = matches[pattern_arg]
        axes = matches[pattern_axes]
        begin = matches[pattern_begin]
        end = matches[pattern_end]
        strides = matches[pattern_strides]

        if arg.struct_info.shape is None:
            return expr

        if len(axes) != 1:
            return expr

        axis = axes[0].value
        begin = begin[0].value
        end = end[0].value
        stride = strides[0].value

        # `size` is read from dimension 0 and `rx.op.split` below splits
        # along its default axis 0, so only axis-0 slices can be rewritten.
        if not isinstance(axis, tir.IntImm) or axis.value != 0:
            return expr

        if stride != 1:
            return expr

        size = arg.struct_info.shape[0]
        if not (
            isinstance(size, tir.IntImm)
            and isinstance(begin, tir.IntImm)
            and isinstance(end, tir.IntImm)
        ):
            # Divisibility can only be checked for static extents.
            return expr

        size = size.value
        begin = begin.value
        end = end.value

        # Out-of-range or empty slices would otherwise produce an invalid
        # split index (e.g. begin == size).
        if not 0 <= begin < end <= size:
            return expr

        # gcd divides begin, end and size, so [begin, end) is exactly one
        # of the `size // gcd` equal sections iff its length equals gcd.
        gcd = functools.reduce(math.gcd, [begin, end, size])
        if (end - begin) // gcd == 1:
            return rx.op.split(arg, size // gcd)[begin // gcd]

        return expr

    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(after, expected)


def test_iterative_rewrite_with_removed_intermediates():
    """Pattern replacement may require canonicalization

    A pattern may replace a tuple returned by a function with a tuple
    whose contents are known by Relax.  In that case, canonicalization
    is required to unwrap the TupleGetItem instances into the known
    contents.

    This test case shows the intermediate results produced in the
    process of pattern-matching.
    """

    @R.function(private=True)
    def before(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
        with R.dataflow():
            c = R.concat([a, b])
            d = R.split(c, 2)
            e = d[0]
            f = d[1]
            g = R.add(a, e)
            h = R.add(f, g)
            R.output(h)
        return h

    # First pattern rewrite.  The concat/split pair can be unwrapped, so
    # `d` is rewritten from `R.split(c, 2)` into `(a, b)`.
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         c = R.concat([a, b])
    #         d = (a,b)
    #         e = d[0]
    #         f = d[1]
    #         g = R.add(a, e)
    #         h = R.add(f, g)
    #         R.output(h)
    #     return h

    # Canonicalization step.  Because `d` is known to be `(a,b)`,
    # canonicalization can rewrite `d[0]` into `a` and `d[1]` into
    # `b`.
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         c = R.concat([a, b])
    #         d = (a,b)
    #         e = a
    #         f = b
    #         g = R.add(a, a)
    #         h = R.add(b, g)
    #         R.output(h)
    #     return h

    # Dead-code-elimination step.  This technically isn't required
    # until the pattern matching has converged, but performing it now
    # prevents testing for matches on dead code.
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         g = R.add(a, a)
    #         h = R.add(b, g)
    #         R.output(h)
    #     return h

    # Second pattern-matching step.  Now, the `R.add(a,a)` can match
    # the other option in our pattern, and be rewritten as
    # `R.multiply(a,R.const(2))`.
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         g = R.multiply(a, R.const(2))
    #         h = R.add(b, g)
    #         R.output(h)
    #     return h

    # Canonicalization and dead-code-elimination are applied again,
    # but have no effect this time.

    @R.function(private=True)
    def expected(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
        with R.dataflow():
            g = R.multiply(a, R.const(2))
            h = R.add(b, g)
            R.output(h)
        return h

    pat_args = wildcard()

    op_concat = is_op("relax.concat")
    pat_concat = op_concat(pat_args).has_attr({"axis": 0})

    op_split = is_op("relax.split")
    pat_split = op_split(pat_concat).has_attr({"axis": 0, "indices_or_sections": T.int64(2)})

    pat_unwrap_concat_split = pat_split

    pat_arg = wildcard()
    op_add = is_op("relax.add")
    pat_add_self = op_add(pat_arg, pat_arg)

    pattern = pat_unwrap_concat_split | pat_add_self

    def rewriter(expr, matches):
        """Unwrap `split(concat(...), 2)` and rewrite `x + x` as `x * 2`.

        Returns `expr` unchanged when neither alternative applies.
        """
        if pat_unwrap_concat_split in matches:
            args = matches[pat_args]

            # Splitting the concatenation of two same-typed tensors into two
            # sections reproduces the original tensors.
            if len(args) == 2 and tvm.ir.structural_equal(args[0].struct_info, args[1].struct_info):
                return args

        elif pat_add_self in matches:
            arg = matches[pat_arg]
            return arg * rx.const(2)

        return expr

    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)