def test_partition_fuzzy_function_args():
    """Partition calls to inline functions of varying arity with one fuzzy pattern.

    The pattern passes ``None`` both as the ``FunctionPattern`` parameter list
    and as the argument list of the call on it. The assertions below expect this
    single pattern to match calls to one-, two- and three-parameter functions
    whose body is an ``add``, each followed by ``+ b``. Every match is expected
    to be outlined into a fresh function tagged with a
    ``PartitionedFromPattern`` attribute and then invoked on the original
    outer arguments.
    """
    # (fn(<params unconstrained>) { _ + _ })(<args unconstrained>) + _
    pattern = FunctionPattern(None, wildcard() + wildcard())(None) + wildcard()

    x, y, z, b = (relay.var(name) for name in ("x", "y", "z", "b"))
    xp, yp, zp = (relay.var(name) for name in ("xp", "yp", "zp"))
    outer_args = [x, y, z]

    def expected_partition(call):
        """Build the graph ``partition(call + b)`` should produce for ``call``."""
        arity = len(call.op.params)
        # One fresh parameter per argument of the inner call, plus one for ``b``.
        fresh = [relay.var(str(idx)) for idx in range(arity + 1)]
        inner_call = relay.Call(call.op, fresh[:-1])
        # The three-parameter body ``xp + yp + zp`` holds two adds instead of
        # one, so its expected label carries an extra leading "add_".
        prefix = "add_" if arity == 3 else ""
        outlined = relay.Function(fresh, inner_call + fresh[-1]).with_attr(
            "PartitionedFromPattern", prefix + "add_FunctionCall_add_"
        )
        return outlined(*(outer_args[:arity] + [b]))

    cases = [
        relay.Function([xp], xp + xp)(x),
        relay.Function([xp, yp], xp + yp)(x, y),
        relay.Function([xp, yp, zp], xp + yp + zp)(x, y, z),
    ]
    for call in cases:
        partitioned = pattern.partition(call + b)
        assert tvm.ir.structural_equal(partitioned, expected_partition(call))