def conv_bias_relu(x, w, b):
    """Build the Relay expression ``relu(bias_add(conv2d(x, w), b))``.

    The three operators are created innermost-first: conv2d, then bias_add,
    then relu. ``conv2d`` is called with only its data and weight operands, so
    every other attribute keeps whatever default ``relay.op.nn.conv2d``
    defines.

    Args:
        x: Relay expression used as the conv2d data input.
        w: Relay expression used as the conv2d weight input.
        b: Relay expression added by ``bias_add``.

    Returns:
        The outermost Relay call node (the ``nn.relu`` call).
    """
    nn_ops = relay.op.nn
    convolved = nn_ops.conv2d(x, w)
    return nn_ops.relu(nn_ops.bias_add(convolved, b))
def test_partition_option():
    """Check that ``optional`` patterns both match and partition conv2d -> bias_add -> relu.

    Two patterns are exercised against the same expression:

    * pattern1: conv2d, optionally wrapped by bias_add, feeding relu.
    * pattern2: conv2d feeding bias_add, optionally wrapped by relu.

    Each must match the full chain and must partition it into the same call of
    a Relay function tagged with the ``"PartitionedFromPattern"`` attribute.
    """
    x = relay.var("x")
    w = relay.var("w")
    b = relay.var("b")

    # pattern1: the bias_add between conv2d and relu is optional.
    conv_for_p1 = is_op("nn.conv2d")(wildcard(), wildcard())
    maybe_biased = conv_for_p1.optional(
        lambda inner: is_op("nn.bias_add")(inner, wildcard())
    )
    pattern1 = is_op("nn.relu")(maybe_biased)

    # pattern2: the relu after conv2d + bias_add is optional.
    conv_for_p2 = is_op("nn.conv2d")(wildcard(), wildcard())
    biased = is_op("nn.bias_add")(conv_for_p2, wildcard())
    pattern2 = biased.optional(lambda inner: is_op("nn.relu")(inner))

    expr = conv_bias_relu(x, w, b)

    # Expected partition: the whole chain wrapped in a function over fresh
    # parameters, then called on the original x, w, b.
    x_param, w_param, b_param = relay.var("x"), relay.var("w"), relay.var("b")
    expected_fn = relay.Function(
        [x_param, w_param, b_param], conv_bias_relu(x_param, w_param, b_param)
    ).with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")

    for pattern in (pattern1, pattern2):
        assert pattern.match(expr)
        assert tvm.ir.structural_equal(expected_fn(x, w, b), pattern.partition(expr))