分区可选

# 分区可选

from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
def conv_bias_relu(x, w, b):
    """Build the Relay expression ``relu(bias_add(conv2d(x, w), b))``.

    Args:
        x: First argument forwarded to ``nn.conv2d``.
        w: Second argument forwarded to ``nn.conv2d``.
        b: Second argument forwarded to ``nn.bias_add``.

    Returns:
        The expression produced by applying ``nn.relu`` to the
        bias-added convolution.
    """
    nn = relay.op.nn
    return nn.relu(nn.bias_add(nn.conv2d(x, w), b))


def test_partition_option():
    """Check matching and partitioning for patterns built with ``optional``.

    Two patterns are constructed: in the first, ``nn.bias_add`` is attached
    to the ``nn.conv2d`` pattern via ``optional`` and then wrapped in
    ``nn.relu``; in the second, ``nn.relu`` is attached via ``optional`` to
    a ``nn.conv2d`` -> ``nn.bias_add`` pattern. Both are asserted to match
    the conv2d/bias_add/relu expression and to partition it into a call to
    the same expected function.
    """
    var_names = ("x", "w", "b")

    # Expression under test, built from its own set of free variables.
    x, w, b = (relay.var(name) for name in var_names)
    expr = conv_bias_relu(x, w, b)

    # Pattern 1: bias_add is the optional stage between conv2d and relu.
    conv_pat = is_op("nn.conv2d")(wildcard(), wildcard())
    maybe_biased = conv_pat.optional(
        lambda prev: is_op("nn.bias_add")(prev, wildcard())
    )
    pattern1 = is_op("nn.relu")(maybe_biased)

    # Pattern 2: relu is the optional stage after conv2d + bias_add.
    conv_pat = is_op("nn.conv2d")(wildcard(), wildcard())
    biased = is_op("nn.bias_add")(conv_pat, wildcard())
    pattern2 = biased.optional(lambda prev: is_op("nn.relu")(prev))

    # Expected partition result: a function over fresh variables carrying
    # the "PartitionedFromPattern" attribute, applied to the original vars.
    params = [relay.var(name) for name in var_names]
    expected_func = relay.Function(params, conv_bias_relu(*params)).with_attr(
        "PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_"
    )

    for pattern in (pattern1, pattern2):
        assert pattern.match(expr)
        assert tvm.ir.structural_equal(expected_func(x, w, b), pattern.partition(expr))