def test_same_shape_pattern(same_shape_func_type):
    """Exercise ``same_shape_as`` on the two operands of a ``relax.add`` pattern.

    A small Relax function computing ``b + a * 2.0`` is built, with the
    annotated shapes of ``a`` and ``b`` chosen by ``same_shape_func_type``.
    A pattern context is then matched against the function's first block.

    Parameters
    ----------
    same_shape_func_type : str
        One of ``"same_static_shape"``, ``"same_dynamic_shape"``,
        ``"different_static_shape"`` or ``"different_dynamic_shape"``.

    Raises
    ------
    ValueError
        If ``same_shape_func_type`` is not one of the four values above.
    """
    known_kinds = (
        "same_static_shape",
        "same_dynamic_shape",
        "different_static_shape",
        "different_dynamic_shape",
    )
    # Reject unknown variants up front, before any function is constructed.
    if same_shape_func_type not in known_kinds:
        raise ValueError(f"Unknown value of same_shape_func_type={same_shape_func_type}")

    # NOTE(review): the R.function bodies below are left textually unchanged;
    # presumably the decorator parses their source, so restyle with care.
    if same_shape_func_type == "same_static_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor((1024, 128), "float32"),
            b: R.Tensor((1024, 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    elif same_shape_func_type == "same_dynamic_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor(("n", 128), "float32"),
            b: R.Tensor(("n", 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    elif same_shape_func_type == "different_static_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor((1024, 128), "float32"),
            b: R.Tensor((1, 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    else:  # "different_dynamic_shape" is the only remaining known kind

        @R.function(private=True)
        def func(
            a: R.Tensor(("n", 128), "float32"),
            b: R.Tensor(("m", 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    # Describe `add(lhs, rhs)` and tie the two operands together with a
    # same-shape constraint, all inside one pattern context.
    with PatternContext() as pattern_ctx:
        lhs_operand = wildcard()
        rhs_operand = wildcard()
        # Bound but not read again in this test; the call is kept as in the
        # original since it is made while the context is active.
        add_pattern = is_op("relax.add")(lhs_operand, rhs_operand)
        lhs_operand.same_shape_as(rhs_operand)

    first_block = func.body.blocks[0]
    result = pattern_ctx.match_dfb(first_block)

    # Only the "same_*" variants are expected to produce a match.
    shapes_agree = "same" in same_shape_func_type
    if not shapes_agree:
        assert result is None
    else:
        assert result