def test_partition_optional_function():
    """Partition a function call using an alternation of two function patterns.

    The pattern is ``sigmoid(conv2d)``-function-call OR ``relu(conv2d)``-function-call.
    The input expression calls a ``relu(conv2d)`` function and adds a bias, so the
    partitioned result is expected to wrap that call in a new function tagged with
    the ``PartitionedFromPattern`` attribute, leaving the add outside.
    """
    # Free variables of the outer expression.
    data = relay.var("x")
    weight = relay.var("w")
    bias = relay.var("b")

    # Wildcards for the call arguments and for the matched function's parameters.
    # The parameter wildcards are shared by both alternative function patterns.
    arg_data, arg_weight = wildcard(), wildcard()
    param_data, param_weight = wildcard(), wildcard()

    def conv2d_followed_by(op_name):
        # Function pattern whose body is `op_name(nn.conv2d(param_data, param_weight))`.
        conv = is_op("nn.conv2d")(param_data, param_weight)
        return FunctionPattern([param_data, param_weight], is_op(op_name)(conv))

    sigmoid_fn = conv2d_followed_by("sigmoid")
    relu_fn = conv2d_followed_by("nn.relu")
    pattern = sigmoid_fn(arg_data, arg_weight) | relu_fn(arg_data, arg_weight)

    # Input: (fn(x1, w1) { relu(conv2d(x1, w1)) })(x, w) + b
    inner_data = relay.var("x1")
    inner_weight = relay.var("w1")
    inner_body = relay.nn.relu(relay.nn.conv2d(inner_data, inner_weight))
    inner_fn = relay.Function([inner_data, inner_weight], inner_body)
    original = inner_fn(data, weight) + bias

    # Expected: the function call lifted into a new function carrying the pattern tag.
    outer_data = relay.var("x2")
    outer_weight = relay.var("w2")
    wrapped_fn = relay.Function(
        [outer_data, outer_weight], inner_fn(outer_data, outer_weight)
    ).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_FunctionCall_")
    expected = wrapped_fn(data, weight) + bias

    assert tvm.ir.structural_equal(pattern.partition(original), expected)