def test_partition_fuzzy_tuple():
    """Partition ``concatenate`` over tuples of one, two and three fields.

    The pattern is ``is_op("concatenate")`` applied to ``is_tuple(None)``.
    For each arity the partitioned expression must be structurally equal to
    a call of a ``relay.Function`` tagged with
    ``PartitionedFromPattern = "Tuple_concatenate_"`` whose body is the same
    concatenate expressed over fresh placeholder variables.
    """
    x = relay.var("x")
    y = relay.var("y")
    z = x + y  # a non-variable expression used as the third tuple field

    # NOTE(review): is_tuple(None) presumably leaves the tuple fields
    # unconstrained so any arity matches -- confirm against the pattern docs.
    fuzzy_tuple = is_tuple(None)
    concat_pattern = is_op("concatenate")(fuzzy_tuple)

    def concat(*fields):
        # Concatenate the given fields, packed into a Relay tuple, on axis 0.
        return relay.op.concatenate(relay.expr.Tuple(fields), axis=0)

    def create_func(params, body):
        # Build the function the partitioner is expected to emit, including
        # the attribute that records the matched pattern.
        func = relay.Function(params, body)
        return func.with_attr("PartitionedFromPattern", "Tuple_concatenate_")

    inputs = [x, y, z]
    placeholders = [relay.var(name) for name in ("xp", "yp", "zp")]

    # Check arities 1, 2 and 3 in turn, using the leading inputs/placeholders.
    for arity in range(1, len(inputs) + 1):
        call_args = inputs[:arity]
        params = placeholders[:arity]

        partitioned = concat_pattern.partition(concat(*call_args))
        expected = create_func(params, concat(*params))(*call_args)

        assert tvm.ir.structural_equal(partitioned, expected)