import functools
import math

from tvm.script import relax as R
from tvm.script import tir as T
from tvm import relax as rx
from tvm import relay, tir
from tvm.relax.analysis import get_var2val
import tvm.testing
from tvm.relax.dpl import *


same_shape_func_type = tvm.testing.parameter(
    "same_static_shape",
    "same_dynamic_shape",
    "different_static_shape",
    "different_dynamic_shape",
)


def test_iterative_rewrite_without_trivial_binding():
    """Avoid introducing common sub-expressions

    Pattern replacement may produce the same intermediate, which
    should appear only once in the final result.
    """

    @R.function(private=True)
    def before(x: R.Tensor((1024,))):
        with R.dataflow():
            a = R.strided_slice(x, [0], [0], [512], [1])
            b = R.strided_slice(x, [0], [512], [1024], [1])
            c = R.add(a, b)
            R.output(c)
        return c

    @R.function(private=True)
    def expected(x: R.Tensor((1024,))):
        with R.dataflow():
            x_split = R.split(x, 2)
            a = x_split[0]
            b = x_split[1]
            c = R.add(a, b)
            R.output(c)
        return c

    pattern_arg = wildcard()
    pattern_axes = wildcard()
    pattern_begin = wildcard()
    pattern_end = wildcard()
    pattern_strides = wildcard()
    pattern = is_op("relax.strided_slice")(
        pattern_arg, pattern_axes, pattern_begin, pattern_end, pattern_strides
    )

    def rewriter(expr, matches):
        """Rewrite a single-axis, unit-stride slice as one section of a split.

        Returns `expr` unchanged whenever the matched slice cannot be
        expressed as exactly one equal-width section of the argument.
        """
        arg = matches[pattern_arg]
        axes = matches[pattern_axes]
        begin = matches[pattern_begin]
        end = matches[pattern_end]
        strides = matches[pattern_strides]

        if arg.struct_info.shape is None:
            return expr

        # Only slices along a single axis are considered.
        if len(axes) != 1:
            return expr

        begin = begin[0].value
        end = end[0].value
        stride = strides[0].value

        if stride != 1:
            return expr

        # The arithmetic below needs concrete integers, so the extent
        # and both slice bounds must all be static.
        size = arg.struct_info.shape[0]
        if (
            isinstance(size, tir.IntImm)
            and isinstance(begin, tir.IntImm)
            and isinstance(end, tir.IntImm)
        ):
            size = size.value
            begin = begin.value
            end = end.value
        else:
            return expr

        # If the slice is exactly one chunk of width `gcd`, it is the
        # `begin // gcd`-th of `size // gcd` equal sections.
        gcd = functools.reduce(math.gcd, [begin, end, size])
        if (end - begin) // gcd == 1:
            return rx.op.split(arg, size // gcd)[begin // gcd]

        return expr

    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(after, expected)


def test_iterative_rewrite_with_removed_intermediates():
    """Pattern replacement may require canonicalization

    A pattern may replace a tuple returned by a function with a tuple
    whose contents are known by Relax.  In that case, canonicalization
    is required to unwrap the TupleGetItem instances into the known
    contents.

    This test case shows the intermediate results produced in the
    process of pattern-matching.
    """

    @R.function(private=True)
    def before(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
        with R.dataflow():
            c = R.concat([a, b])
            d = R.split(c, 2)
            e = d[0]
            f = d[1]
            g = R.add(a, e)
            h = R.add(f, g)
            R.output(h)
        return h

    # First pattern rewrite.  The concat/split can be unwrapped, so
    # `d` is rewritten from `R.split(c, 2)` into `(a, b)`.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         c = R.concat([a, b])
    #         d = (a,b)
    #         e = d[0]
    #         f = d[1]
    #         g = R.add(a, e)
    #         h = R.add(f, g)
    #         R.output(h)

    # Canonicalization step.  Because `d` is known to be `(a,b)`,
    # canonicalization can rewrite `d[0]` into `a` and `d[1]` into
    # `b`.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         c = R.concat([a, b])
    #         d = (a,b)
    #         e = a
    #         f = b
    #         g = R.add(a, a)
    #         h = R.add(b, g)
    #         R.output(h)

    # Dead-code-elimination step.  This technically isn't required
    # until the pattern matching has converged, but performing it now
    # prevents testing for matches on dead code.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         g = R.add(a, a)
    #         h = R.add(b, g)
    #         R.output(h)

    # Second pattern-matching step.  Now, the `R.add(a,a)` can match
    # the other option in our pattern, and be rewritten as
    # `R.multiply(a,R.const(2))`.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         g = R.multiply(a, R.const(2))
    #         h = R.add(b, g)
    #         R.output(h)

    # Canonicalization and dead-code-elimination are applied again,
    # but have no effect this time.

    @R.function(private=True)
    def expected(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
        with R.dataflow():
            g = R.multiply(a, R.const(2))
            h = R.add(b, g)
            R.output(h)
        return h

    pat_args = wildcard()

    op_concat = is_op("relax.concat")
    pat_concat = op_concat(pat_args).has_attr({"axis": 0})

    op_split = is_op("relax.split")
    pat_split = op_split(pat_concat).has_attr({"axis": 0, "indices_or_sections": T.int64(2)})

    pat_unwrap_concat_split = pat_split

    pat_arg = wildcard()
    op_add = is_op("relax.add")
    pat_add_self = op_add(pat_arg, pat_arg)

    pattern = pat_unwrap_concat_split | pat_add_self

    def rewriter(expr, matches):
        """Unwrap split(concat(a, b)) into (a, b), or turn x + x into x * 2.

        Returns `expr` unchanged when neither rewrite applies.
        """
        if pat_unwrap_concat_split in matches:
            args = matches[pat_args]

            # Only unwrap when both concatenated arguments have equal
            # struct info, so the two-way split hands back those
            # same arguments.
            if len(args) == 2 and tvm.ir.structural_equal(
                args[0].struct_info, args[1].struct_info
            ):
                return args

        elif pat_add_self in matches:
            arg = matches[pat_arg]
            return arg * rx.const(2)

        return expr

    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)