import functools
import math

from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
import tvm.testing
from tvm.relax.dpl import *
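
# `tvm.testing.parameter` defines a pytest fixture; `same_shape_func_type` is
# presumably consumed by parametrized tests elsewhere in this file, since none
# of the tests shown below request it.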
same_shape_func_type = tvm.testing.parameter(
    "same_static_shape",
    "same_dynamic_shape",
    "different_static_shape",
    "different_dynamic_shape",
)


def test_iterative_rewrite_without_trivial_binding():
    """Avoid introducing common sub-expressions

    Pattern replacement may produce the same intermediate, which
    should appear only once in the final result.
    """

    @R.function(private=True)
    def before(x: R.Tensor((1024,))):
        with R.dataflow():
            a = R.strided_slice(x, [0], [0], [512], [1])
            b = R.strided_slice(x, [0], [512], [1024], [1])
            c = R.add(a, b)
            R.output(c)
        return c

    @R.function(private=True)
    def expected(x: R.Tensor((1024,))):
        with R.dataflow():
            x_split = R.split(x, 2)
            a = x_split[0]
            b = x_split[1]
            c = R.add(a, b)
            R.output(c)
        return c

    pattern_arg = wildcard()
    pattern_axes = wildcard()
    pattern_begin = wildcard()
    pattern_end = wildcard()
    pattern_strides = wildcard()
    pattern = is_op("relax.strided_slice")(
        pattern_arg, pattern_axes, pattern_begin, pattern_end, pattern_strides
    )
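
    # `rewrite_call` invokes the rewriter with the matched call and a map from
    # each sub-pattern to the expression it captured.  Returning `expr`
    # unchanged signals that no rewrite applies to this match.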
    def rewriter(expr, matches):
        arg = matches[pattern_arg]
        axes = matches[pattern_axes]
        begin = matches[pattern_begin]
        end = matches[pattern_end]
        strides = matches[pattern_strides]

        strided_slice = matches[pattern]

        # Only rewrite single-axis, unit-stride slices of arguments whose
        # shape is known statically.
        if arg.struct_info.shape is None:
            return expr

        if len(axes) != 1:
            return expr

        axis = axes[0].value
        begin = begin[0].value
        end = end[0].value
        stride = strides[0].value

        if stride != 1:
            return expr

        size = arg.struct_info.shape[0]

        if (
            isinstance(size, tir.IntImm)
            and isinstance(begin, tir.IntImm)
            and isinstance(end, tir.IntImm)
        ):
            size = size.value
            begin = begin.value
            end = end.value
        else:
            return expr

        # If [begin, end) covers exactly one part of an even split of the
        # full extent, replace the slice with the corresponding element of
        # R.split.
        gcd = functools.reduce(math.gcd, [begin, end, size])
        if (end - begin) // gcd == 1:
            return rx.op.split(arg, size // gcd)[begin // gcd]

        return expr
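
    # rewrite_call re-applies the pattern until the function converges, so
    # the identical `R.split(x, 2)` produced for both `a` and `b` appears
    # only once in the result.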
    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(after, expected)


def test_iterative_rewrite_with_removed_intermediates():
    """Pattern replacement may require canonicalization

    A pattern may replace a tuple returned by a function with a tuple
    whose contents are known by Relax.  In that case, canonicalization
    is required to unwrap the TupleGetItem instances into the known
    contents.

    This test case shows the intermediate results produced in the
    process of pattern-matching.
    """

    @R.function(private=True)
    def before(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
        with R.dataflow():
            c = R.concat([a, b])
            d = R.split(c, 2)
            e = d[0]
            f = d[1]
            g = R.add(a, e)
            h = R.add(f, g)
            R.output(h)
        return h

    # First pattern rewrite.  The concat/split pair can be unwrapped, so
    # `d` is rewritten from `R.split(c, 2)` into `(a, b)`.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         c = R.concat([a, b])
    #         d = (a, b)
    #         e = d[0]
    #         f = d[1]
    #         g = R.add(a, e)
    #         h = R.add(f, g)
    #         R.output(h)
    #
    # Canonicalization step.  Because `d` is known to be `(a, b)`,
    # canonicalization can rewrite `d[0]` into `a` and `d[1]` into `b`.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         c = R.concat([a, b])
    #         d = (a, b)
    #         e = a
    #         f = b
    #         g = R.add(a, a)
    #         h = R.add(b, g)
    #         R.output(h)
    #
    # Dead-code-elimination step.  This technically isn't required until
    # the pattern matching has converged, but performing it now prevents
    # testing for matches on dead code.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         g = R.add(a, a)
    #         h = R.add(b, g)
    #         R.output(h)
    #
    # Second pattern-matching step.  Now, the `R.add(a, a)` can match the
    # other option in our pattern, and be rewritten as
    # `R.multiply(a, R.const(2))`.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         g = R.multiply(a, R.const(2))
    #         h = R.add(b, g)
    #         R.output(h)
    #
    # Canonicalization and dead-code-elimination are applied again, but
    # have no effect this time.

    @R.function(private=True)
    def expected(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
        with R.dataflow():
            g = R.multiply(a, R.const(2))
            h = R.add(b, g)
            R.output(h)
        return h

    pat_args = wildcard()

    op_concat = is_op("relax.concat")
    pat_concat = op_concat(pat_args).has_attr({"axis": 0})

    op_split = is_op("relax.split")
    pat_split = op_split(pat_concat).has_attr(
        {"axis": 0, "indices_or_sections": T.int64(2)}
    )

    pat_unwrap_concat_split = pat_split

    pat_arg = wildcard()
    op_add = is_op("relax.add")
    pat_add_self = op_add(pat_arg, pat_arg)

    pattern = pat_unwrap_concat_split | pat_add_self
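
    # The two alternatives are combined with `|`.  Inside the rewriter,
    # `sub_pattern in matches` reports which alternative produced the
    # current match, so each branch can be handled separately.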
    def rewriter(expr, matches):
        if pat_unwrap_concat_split in matches:
            args = matches[pat_args]
            if len(args) == 2 and tvm.ir.structural_equal(
                args[0].struct_info, args[1].struct_info
            ):
                return args

        elif pat_add_self in matches:
            arg = matches[pat_arg]
            return arg * rx.const(2)

        return expr

    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)