def test_partition_function():
    """Partition a call to a conv2d ``FunctionPattern`` followed by an add.

    The graph under test is ``func(x, w) + b + b``, where ``func`` is a Relay
    function wrapping a single ``nn.conv2d``. The pattern matches the inner
    ``func(x, w) + b``. Partitioning should lift that sub-expression into a new
    function carrying the ``PartitionedFromPattern`` attribute, leaving the
    trailing ``+ b`` outside the partition.
    """
    # Free variables of the graph; the expected result must reuse these same
    # Var objects.
    data = relay.var("x")
    weight = relay.var("w")
    bias = relay.var("b")

    # Parameters of the inner conv2d function that gets called.
    inner_data = relay.var("x1")
    inner_weight = relay.var("w1")

    # Pattern: (fn(p0, p1) { conv2d(p0, p1) })(a0, a1) + a2.
    # The same wildcard objects must appear both as the FunctionPattern's
    # parameters and inside its body.
    param_data_pat = wildcard()
    param_weight_pat = wildcard()
    conv_fn_pat = FunctionPattern(
        [param_data_pat, param_weight_pat],
        is_op("nn.conv2d")(param_data_pat, param_weight_pat),
    )
    arg_data_pat = wildcard()
    arg_weight_pat = wildcard()
    arg_bias_pat = wildcard()
    pattern = conv_fn_pat(arg_data_pat, arg_weight_pat) + arg_bias_pat

    # Input graph: conv_fn(x, w) + b + b, i.e. ((conv_fn(x, w) + b) + b).
    conv_fn = relay.Function(
        [inner_data, inner_weight], relay.nn.conv2d(inner_data, inner_weight)
    )
    graph = conv_fn(data, weight) + bias + bias

    # Expected: the matched ``conv_fn(x, w) + b`` becomes a new function over
    # fresh parameters, called with the original free vars, then ``+ b``.
    part_data = relay.var("x2")
    part_weight = relay.var("w2")
    part_bias = relay.var("b2")
    partitioned_fn = relay.Function(
        [part_data, part_weight, part_bias],
        conv_fn(part_data, part_weight) + part_bias,
    ).with_attr("PartitionedFromPattern", "nn.conv2d_FunctionCall_add_")
    expected = partitioned_fn(data, weight, bias) + bias

    actual = pattern.partition(graph)
    assert tvm.ir.structural_equal(actual, expected)