def test_same_shape_pattern(same_shape_func_type):
    """Check that ``DFPattern.same_shape_as`` constrains a dataflow-block match.

    A pattern ``relax.add(lhs, rhs)`` over two wildcards, with the constraint
    ``lhs.same_shape_as(rhs)``, is matched against the single dataflow block of
    a small private Relax function. The operand shapes of that function are
    selected by ``same_shape_func_type``:

    * ``"same_static_shape"`` / ``"same_dynamic_shape"``: both ``add`` operands
      have identical shapes, so a match is expected.
    * ``"different_static_shape"`` / ``"different_dynamic_shape"``: the operand
      shapes differ (broadcast or distinct symbolic dims), so no match is
      expected.

    Raises:
        ValueError: if ``same_shape_func_type`` is not one of the values above.
    """
    # In every variant ``c`` is ``a`` multiplied by a scalar constant, so it has
    # the shape of ``a``. The ``relax.add`` operands ``b`` and ``c`` therefore
    # share a shape exactly when ``a`` and ``b`` are declared with the same shape.
    if same_shape_func_type == "same_static_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor((1024, 128), "float32"),
            b: R.Tensor((1024, 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    elif same_shape_func_type == "same_dynamic_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor(("n", 128), "float32"),
            b: R.Tensor(("n", 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    elif same_shape_func_type == "different_static_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor((1024, 128), "float32"),
            b: R.Tensor((1, 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    elif same_shape_func_type == "different_dynamic_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor(("n", 128), "float32"),
            b: R.Tensor(("m", 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    else:
        raise ValueError(f"Unknown value of same_shape_func_type={same_shape_func_type}")

    # Only the "same_*" variants should satisfy the shape constraint. Test the
    # prefix explicitly. A bare ``"same" in ...`` substring check would silently
    # misclassify any future variant name that merely contains "same".
    expect_match = same_shape_func_type.startswith("same_")

    with PatternContext() as ctx:
        pat_lhs = wildcard()
        pat_rhs = wildcard()
        # Not referenced again in Python. It is kept so the add-pattern is built
        # while ``ctx`` is active (presumably how the context collects it; confirm
        # before removing).
        pat_sum = is_op("relax.add")(pat_lhs, pat_rhs)
        pat_lhs.same_shape_as(pat_rhs)

    block = func.body.blocks[0]
    match = ctx.match_dfb(block)

    if expect_match:
        assert match
    else:
        assert match is None