# 分区检查

# 分区检查

from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
def test_partition_check():
    """Check that ``partition`` respects a user-supplied ``check`` predicate.

    The pattern is relu(conv2d(input, *)). The predicate accepts a match only
    when the conv2d's ``data_layout`` attribute is "NCHW":

    * for a conv2d built without an explicit layout, the partitioned result
      must be structurally equal to a call of a function tagged with
      ``PartitionedFromPattern``;
    * for a conv2d built with ``data_layout="NHWC"``, the result must compare
      equal (``==``) to the input expression.
    """
    conv_relu_pattern = is_op("nn.relu")(is_op("nn.conv2d")(is_var("input"), wildcard()))

    def layout_is_nchw(matched):
        # ``matched`` is the relu call; its first argument is the conv2d call.
        conv = matched.args[0]
        return conv.attrs.data_layout == "NCHW"

    data = relay.var("input")
    kernel = relay.var("weight")

    # Expected outcome: the same relu(conv2d) body wrapped in a function that
    # carries the partition attribute, applied to the outer variables.
    inner_data = relay.var("input")
    inner_kernel = relay.var("weight")
    inner_body = relay.op.nn.relu(relay.op.nn.conv2d(inner_data, inner_kernel))
    expected_func = relay.Function([inner_data, inner_kernel], inner_body)
    expected_func = expected_func.with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_")
    expected = expected_func(data, kernel)

    # Case 1: no explicit layout -> the predicate accepts the match.
    accepted_expr = relay.op.nn.relu(relay.op.nn.conv2d(data, kernel))
    actual = conv_relu_pattern.partition(accepted_expr, check=layout_is_nchw)
    assert tvm.ir.structural_equal(actual, expected)

    # Case 2: NHWC layout -> the predicate rejects the match.
    rejected_expr = relay.op.nn.relu(relay.op.nn.conv2d(data, kernel, data_layout="NHWC"))
    assert rejected_expr == conv_relu_pattern.partition(rejected_expr, check=layout_is_nchw)
def test_partition_check_types():
    """Check that a ``partition`` predicate can inspect inferred types.

    The pattern is relu(conv2d(*, *)). The predicate accepts a match only when
    the conv2d's ``data_layout`` is "NCHW" *and* the first dimension of its
    ``checked_type`` shape equals 1. Each expression is therefore run through
    ``InferType`` before partitioning so that ``checked_type`` is populated.

    Three cases are exercised:

    * first dim 1, no explicit layout -> partitioned; the resulting call's
      function carries ``PartitionedFromPattern == "nn.conv2d_nn.relu_"``;
    * first dim 1, ``data_layout="NHWC"`` -> result compares equal to input;
    * first dim 2, no explicit layout -> result compares equal to input.
    """
    # ``run_opt_pass`` is not among the names imported at the top of the file,
    # so it is imported here to keep this test self-contained.
    from tvm.relay.testing import run_opt_pass

    pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))

    def check(pre):
        # ``pre`` is the relu call; its first argument is the conv2d call.
        conv = pre.args[0]
        return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1)

    def typed_conv_relu(first_dim, **conv_attrs):
        # Build relu(conv2d(input, weight)) and annotate it with inferred types.
        data = relay.var("input", shape=(first_dim, 10, 10, 10))
        kernel = relay.var("weight", shape=(10, 10, 3, 3))
        expr = relay.op.nn.relu(relay.op.nn.conv2d(data, kernel, **conv_attrs))
        return run_opt_pass(expr, relay.transform.InferType())

    # Accepted: layout and first dimension both satisfy the predicate.
    relu = typed_conv_relu(1)
    partitioned = pattern.partition(relu, check=check)
    assert partitioned.op.attrs["PartitionedFromPattern"] == "nn.conv2d_nn.relu_"

    # Rejected: layout is NHWC.
    relu = typed_conv_relu(1, data_layout="NHWC")
    assert relu == pattern.partition(relu, check=check)

    # Rejected: first dimension is 2.
    relu = typed_conv_relu(2)
    assert relu == pattern.partition(relu, check=check)