MergeComposite#

The merge composite pass finds groups of Relay operators that match a given pattern and combines them into a single Relay function.

For example, suppose we have the following computation graph:

    conv2d
      |            (merge composite pass)
   bias_add            ====>           conv2d_bias_relu
      |            (our goal)
     relu

Relay IR before the merge composite pass:

    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
        %0 = nn.conv2d(%data, %kernel, kernel_size=[1, 1])
            /* ty=Tensor[(1, 256, 28, 28), float32] */;
        %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
    }

Relay IR after the merge composite pass:

    fn (%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32],
            %bias: Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] {
      %2 = fn (%x: Tensor[(1, 512, 28, 28), float32], %y: Tensor[(256, 512, 1, 1), float32],
            %z: Tensor[(256), float32], Primitive=1, Composite="conv2d_bias_relu") ->
            Tensor[(1, 256, 28, 28), float32] {
        %0 = nn.conv2d(%x, %y, kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
        %1 = nn.bias_add(%0, %z) /* ty=Tensor[(1, 256, 28, 28), float32] */;
        nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */
      };
      %2(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */
    }

As you can see in the second Relay example, the specified pattern has been wrapped in a function. That function is then called, producing the same result as the first Relay example.

A convenient use of this pass is to offload multiple operators into a single external code-generation function.
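As a rough sketch of that flow (the AnnotateTarget/PartitionGraph steps and the "mytarget" name below are assumptions about a typical bring-your-own-codegen pipeline, not something exercised in this document), the pass is driven with a pattern table of (name, pattern) or (name, pattern, check) entries:

import tvm
from tvm import relay

def merge_for_external_codegen(func, pattern_table, target="mytarget"):
    """Hypothetical helper: merge composites, then annotate and partition
    the module for an external codegen target (assumed to be registered)."""
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.MergeComposite(pattern_table)(mod)
    mod = relay.transform.AnnotateTarget(target)(mod)  # mark supported regions
    mod = relay.transform.PartitionGraph()(mod)        # split them into functions
    return mod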

The helper functions below build the dataflow patterns used throughout the examples in this document:
from tvm.relay.dataflow_pattern import TupleGetItemPattern, is_op, wildcard

def make_add_sub_mul_pattern():
    r"""Create a pattern to match the following graph.

    add  sub
     \   /
      \ /
      mul
    """
    x = wildcard()
    y = wildcard()
    return (x + y) * (x - y)


def make_add_relu_pattern():
    r"""Create a pattern to match the following graph.

     add
      |
    relu
    """
    add_node = wildcard() + wildcard()
    r = is_op("nn.relu")(add_node)
    return r


def make_conv_bias_relu_pattern():
    r"""Create a pattern to match the following graph.

     conv2d
       |
    bias_add
       |
     relu
    """
    x = wildcard()
    y = wildcard()
    z = wildcard()
    conv_node = is_op("nn.conv2d")(x, y)
    bias_node = is_op("nn.bias_add")(conv_node, z)
    r = is_op("nn.relu")(bias_node)
    return r


def make_pattern_with_optional():
    r"""Create a pattern to match the following graph. Note that relu is optinal.

     conv2d
       |
    bias_add
       |
     (relu)
    """
    x = wildcard()
    y = wildcard()
    z = wildcard()
    conv_node = is_op("nn.conv2d")(x, y)
    bias_node = is_op("nn.bias_add")(conv_node, z)
    r = bias_node.optional(lambda x: is_op("nn.relu")(x))
    return r


def make_add_add_add_pattern():
    r"""Create a pattern to match the following graph.
       Useful for testing re-using a call node.

        x    y
      /  \  /
      |  add
       \  |  \
         add |
          | /
         add
    """
    x = wildcard()
    y = wildcard()
    add_node = is_op("add")(x, y)
    add_node_1 = is_op("add")(x, add_node)
    r = is_op("add")(add_node_1, add_node)
    return r


def make_bn_relu_pattern():
    r"""Create a pattern to match the following graph.

     batch_norm
         |
    TupleGetItem(0)
         |
       relu
    """
    x = wildcard()
    gamma = wildcard()
    beta = wildcard()
    moving_mean = wildcard()
    moving_var = wildcard()
    bn_node = is_op("nn.batch_norm")(x, gamma, beta, moving_mean, moving_var)
    tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
    r = is_op("nn.relu")(tuple_get_item_node)
    return r

Simple merge#

Test that a composite function is correctly produced from a simple graph. We expect the pattern make_add_relu_pattern to be merged into the single operator add_relu.

        a  b
        \ /               a  b
        add    ====>      \ /
         |             add_relu
       relu
pattern_table = [("add_relu", make_add_relu_pattern())]

Check the result:

import tvm
from tvm import relay, tir
from tvm.relay.testing import run_opt_pass

def before():
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    add_node = relay.add(a, b)
    r = relay.nn.relu(add_node)
    return relay.Function([a, b], r)
graph = before()
tvm.IRModule.from_expr(graph).show()
def @main(%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32]) {
  %0 = add(%a, %b);
  nn.relu(%0)
}
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%a: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %b: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %1 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_nn.relu_", Composite="add_relu") -> Tensor[(10, 10), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    nn.relu(%0) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %1(%a, %b) /* ty=Tensor[(10, 10), float32] */
}

Branch merge#

Test that composite functions are correctly produced from a branching graph.

We expect the pattern make_add_sub_mul_pattern to be merged into the single operator add_sub_mul.

       a  b  a  b
        \/    \/
        add  sub                       a  b
         \   /                          \/
          \ /                      add_sub_mul
          mul                     c     |
          /  \                     \    |
       c /  c |       ====>        add_sub_mul
       \/   \/                          |
       add  sub                         |
        \   /                         relu
         \ /
         mul
          |
          |
        relu
pattern_table = [("add_sub_mul", make_add_sub_mul_pattern())]
def before():
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    c = relay.var("c", shape=(10, 10))
    add_node = relay.add(a, b)
    sub_node = relay.subtract(a, b)
    mul_node = relay.multiply(add_node, sub_node)
    add_node_2 = relay.add(c, mul_node)
    sub_node_2 = relay.subtract(c, mul_node)
    mul_node_2 = relay.multiply(add_node_2, sub_node_2)
    r = relay.nn.relu(mul_node_2)
    return relay.Function([a, b, c], r)

graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32], %c: Tensor[(10, 10), float32]) {
  %0 = add(%a, %b);
  %1 = subtract(%a, %b);
  %2 = multiply(%0, %1);
  %3 = add(%c, %2);
  %4 = subtract(%c, %2);
  %5 = multiply(%3, %4);
  nn.relu(%5)
}
def @main(%a: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %b: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %c: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %4 = fn (%FunctionVar_1_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_1_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="add_sub_mul") -> Tensor[(10, 10), float32] {
    %2 = add(%FunctionVar_1_0, %FunctionVar_1_1) /* ty=Tensor[(10, 10), float32] */;
    %3 = subtract(%FunctionVar_1_0, %FunctionVar_1_1) /* ty=Tensor[(10, 10), float32] */;
    multiply(%2, %3) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %5 = %4(%a, %b) /* ty=Tensor[(10, 10), float32] */;
  %6 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="add_sub_mul") -> Tensor[(10, 10), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    %1 = subtract(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    multiply(%0, %1) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %7 = %6(%c, %5) /* ty=Tensor[(10, 10), float32] */;
  nn.relu(%7) /* ty=Tensor[(10, 10), float32] */
}

Merging with a reused call node#

Test that composite functions are correctly produced from a simple graph that reuses a call node.

We expect the pattern make_add_add_add_pattern to be merged into the single operator add_add_add.

        x     y
         \   / \
          sub  |           x     y
        /  |  /             \   / |
        | add      ====>     sub  |
         \ |  \               |  /
          add |           add_add_add
           | /
          add
def before():
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    sub_node = relay.subtract(a, b)

    # pattern
    add_node = relay.add(sub_node, b)
    add_node_1 = relay.add(sub_node, add_node)
    r = relay.add(add_node_1, add_node)

    return relay.Function([a, b], r)


pattern_table = [("add_add_add", make_add_add_add_pattern())]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32]) {
  %0 = subtract(%a, %b);
  %1 = add(%0, %b);
  %2 = add(%0, %1);
  add(%2, %1)
}
def @main(%a: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %b: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %2 = subtract(%a, %b) /* ty=Tensor[(10, 10), float32] */;
  %3 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_add_add_", Composite="add_add_add") -> Tensor[(10, 10), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    %1 = add(%FunctionVar_0_0, %0) /* ty=Tensor[(10, 10), float32] */;
    add(%1, %0) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %3(%2, %b) /* ty=Tensor[(10, 10), float32] */
}

Merging multiple patterns#

Test that different patterns in the same graph are merged correctly.

We expect the pattern make_conv_bias_relu_pattern to be merged into the single operator conv2d_bias_relu, and make_add_relu_pattern to be merged into the single operator add_relu.

        data   kernel
          \      /
           \    /
           conv2d                   data   kernel   bias
             |                         \      |      /
             |   bias                 conv2d_bias_relu
             |   /                            |
          bias_add        ====>               |    a
             |                                |   /
           relu  a                        add_relu
             \  /                             |
             add                              |  b
              |                               | /
            relu  b                          mul
              |  /
             mul
def before():
    data = relay.var("data", shape=(1, 512, 28, 28))
    kernel = relay.var("kernel", shape=(256, 512, 1, 1))
    bias = relay.var("bias", shape=(256,))
    a = relay.var("a", shape=(1, 256, 28, 28))
    b = relay.var("b", shape=(1, 256, 28, 28))

    conv_node = relay.nn.conv2d(
        data, kernel, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1)
    )

    bias_node = relay.nn.bias_add(conv_node, bias)
    relu_node = relay.nn.relu(bias_node)
    add_node = relay.add(relu_node, a)
    relu_node_2 = relay.nn.relu(add_node)
    r = relay.multiply(relu_node_2, b)
    return relay.Function([data, kernel, bias, a, b], r)


pattern_table = [
    ("conv2d_bias_relu", make_conv_bias_relu_pattern()),
    ("add_relu", make_add_relu_pattern()),
]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32], %bias: Tensor[(256), float32], %a: Tensor[(1, 256, 28, 28), float32], %b: Tensor[(1, 256, 28, 28), float32]) {
  %0 = nn.conv2d(%data, %kernel, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
  %1 = nn.bias_add(%0, %bias);
  %2 = nn.relu(%1);
  %3 = add(%2, %a);
  %4 = nn.relu(%3);
  multiply(%4, %b)
}
def @main(%data: Tensor[(1, 512, 28, 28), float32] /* ty=Tensor[(1, 512, 28, 28), float32] */, %kernel: Tensor[(256, 512, 1, 1), float32] /* ty=Tensor[(256, 512, 1, 1), float32] */, %bias: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, %a: Tensor[(1, 256, 28, 28), float32] /* ty=Tensor[(1, 256, 28, 28), float32] */, %b: Tensor[(1, 256, 28, 28), float32] /* ty=Tensor[(1, 256, 28, 28), float32] */) -> Tensor[(1, 256, 28, 28), float32] {
  %3 = fn (%FunctionVar_0_01: Tensor[(1, 512, 28, 28), float32] /* ty=Tensor[(1, 512, 28, 28), float32] */, %FunctionVar_0_11: Tensor[(256, 512, 1, 1), float32] /* ty=Tensor[(256, 512, 1, 1), float32] */, %FunctionVar_0_2: Tensor[(256), float32] /* ty=Tensor[(256), float32] */, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_", Composite="conv2d_bias_relu") -> Tensor[(1, 256, 28, 28), float32] {
    %1 = nn.conv2d(%FunctionVar_0_01, %FunctionVar_0_11, padding=[0, 0, 0, 0], kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
    %2 = nn.bias_add(%1, %FunctionVar_0_2) /* ty=Tensor[(1, 256, 28, 28), float32] */;
    nn.relu(%2) /* ty=Tensor[(1, 256, 28, 28), float32] */
  } /* ty=fn (Tensor[(1, 512, 28, 28), float32], Tensor[(256, 512, 1, 1), float32], Tensor[(256), float32]) -> Tensor[(1, 256, 28, 28), float32] */;
  %4 = %3(%data, %kernel, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
  %5 = fn (%FunctionVar_0_0: Tensor[(1, 256, 28, 28), float32] /* ty=Tensor[(1, 256, 28, 28), float32] */, %FunctionVar_0_1: Tensor[(1, 256, 28, 28), float32] /* ty=Tensor[(1, 256, 28, 28), float32] */, PartitionedFromPattern="add_nn.relu_", Composite="add_relu") -> Tensor[(1, 256, 28, 28), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(1, 256, 28, 28), float32] */;
    nn.relu(%0) /* ty=Tensor[(1, 256, 28, 28), float32] */
  } /* ty=fn (Tensor[(1, 256, 28, 28), float32], Tensor[(1, 256, 28, 28), float32]) -> Tensor[(1, 256, 28, 28), float32] */;
  %6 = %5(%4, %a) /* ty=Tensor[(1, 256, 28, 28), float32] */;
  multiply(%6, %b) /* ty=Tensor[(1, 256, 28, 28), float32] */
}

Merging patterns with optional operators#

Test patterns that contain optional operators. A pattern can be defined with some of its operators marked as optional. The merge composite pass will create composite functions for every matched pattern, but with different "PartitionedFromPattern" attributes. The backend codegen is expected to analyze this attribute to determine which operators were actually matched.

Pattern:       Match case A:      Match case B:

 conv2d        conv2d             conv2d
   |             |                  |
bias_add      bias_add           bias_add
   |             |
 (relu)         relu

In the example above, the composite function for match case A will have PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_", while for match case B it will be "nn.conv2d_nn.bias_add_".
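As a hypothetical sketch of that analysis (the dispatch function and the emitter names below are invented for illustration; only the attribute values come from the pass), a backend might branch on the attribute like this:

def lower_layer_composite(func):
    """Hypothetical backend hook: `func` is a composite relay.Function
    produced by MergeComposite with Composite="layer"."""
    pattern = func.attrs["PartitionedFromPattern"]
    if pattern == "nn.conv2d_nn.bias_add_nn.relu_":
        return emit_conv_bias_relu(func)  # hypothetical emitter, relu fused
    elif pattern == "nn.conv2d_nn.bias_add_":
        return emit_conv_bias(func)       # hypothetical emitter, no relu
    raise ValueError(f"unexpected pattern: {pattern}")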

def before():
    x = relay.var("x", shape=(1, 3, 7, 7))
    w1 = relay.var("w", shape=(3, 3, 1, 1))
    b1 = relay.var("b", shape=(3,))
    w2 = relay.var("w", shape=(3, 3, 1, 1))
    b2 = relay.var("b", shape=(3,))
    conv = relay.nn.conv2d(x, w1, kernel_size=(1, 1))
    bias = relay.nn.bias_add(conv, b1)
    relu = relay.nn.relu(bias)
    conv = relay.nn.conv2d(relu, w2, kernel_size=(1, 1))
    bias = relay.nn.bias_add(conv, b2)
    return relay.Function([x, w1, w2, b1, b2], bias)


pattern_table = [("layer", make_pattern_with_optional())]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%x: Tensor[(1, 3, 7, 7), float32], %w: Tensor[(3, 3, 1, 1), float32], %w1: Tensor[(3, 3, 1, 1), float32], %b: Tensor[(3), float32], %b1: Tensor[(3), float32]) {
  %0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
  %1 = nn.bias_add(%0, %b);
  %2 = nn.relu(%1);
  %3 = nn.conv2d(%2, %w1, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
  nn.bias_add(%3, %b1)
}
def @main(%x: Tensor[(1, 3, 7, 7), float32] /* ty=Tensor[(1, 3, 7, 7), float32] */, %w: Tensor[(3, 3, 1, 1), float32] /* ty=Tensor[(3, 3, 1, 1), float32] */, %w1: Tensor[(3, 3, 1, 1), float32] /* ty=Tensor[(3, 3, 1, 1), float32] */, %b: Tensor[(3), float32] /* ty=Tensor[(3), float32] */, %b1: Tensor[(3), float32] /* ty=Tensor[(3), float32] */) -> Tensor[(1, 3, 7, 7), float32] {
  %3 = fn (%FunctionVar_1_0: Tensor[(1, 3, 7, 7), float32] /* ty=Tensor[(1, 3, 7, 7), float32] */, %FunctionVar_1_1: Tensor[(3, 3, 1, 1), float32] /* ty=Tensor[(3, 3, 1, 1), float32] */, %FunctionVar_1_2: Tensor[(3), float32] /* ty=Tensor[(3), float32] */, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_", Composite="layer") -> Tensor[(1, 3, 7, 7), float32] {
    %1 = nn.conv2d(%FunctionVar_1_0, %FunctionVar_1_1, padding=[0, 0, 0, 0], kernel_size=[1, 1]) /* ty=Tensor[(1, 3, 7, 7), float32] */;
    %2 = nn.bias_add(%1, %FunctionVar_1_2) /* ty=Tensor[(1, 3, 7, 7), float32] */;
    nn.relu(%2) /* ty=Tensor[(1, 3, 7, 7), float32] */
  } /* ty=fn (Tensor[(1, 3, 7, 7), float32], Tensor[(3, 3, 1, 1), float32], Tensor[(3), float32]) -> Tensor[(1, 3, 7, 7), float32] */;
  %4 = %3(%x, %w, %b) /* ty=Tensor[(1, 3, 7, 7), float32] */;
  %5 = fn (%FunctionVar_0_0: Tensor[(1, 3, 7, 7), float32] /* ty=Tensor[(1, 3, 7, 7), float32] */, %FunctionVar_0_1: Tensor[(3, 3, 1, 1), float32] /* ty=Tensor[(3, 3, 1, 1), float32] */, %FunctionVar_0_2: Tensor[(3), float32] /* ty=Tensor[(3), float32] */, PartitionedFromPattern="nn.conv2d_nn.bias_add_", Composite="layer") -> Tensor[(1, 3, 7, 7), float32] {
    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0], kernel_size=[1, 1]) /* ty=Tensor[(1, 3, 7, 7), float32] */;
    nn.bias_add(%0, %FunctionVar_0_2) /* ty=Tensor[(1, 3, 7, 7), float32] */
  } /* ty=fn (Tensor[(1, 3, 7, 7), float32], Tensor[(3, 3, 1, 1), float32], Tensor[(3), float32]) -> Tensor[(1, 3, 7, 7), float32] */;
  %5(%4, %w1, %b1) /* ty=Tensor[(1, 3, 7, 7), float32] */
}

Merging in order#

Test that patterns are merged in the order they appear in the pattern table.

In some cases one pattern may be a subgraph of another, and it is unclear which match should take priority. Priority should follow the order in which the patterns are declared in the pattern table: the first declared pattern is merged with the highest priority, the last with the lowest.

from tvm.relay.dataflow_pattern import is_op, wildcard
def pattern_A():
    x = wildcard()
    y = wildcard()
    out = is_op("add")(x, y)
    out = is_op("abs")(out)
    out = is_op("nn.relu")(out)
    return out

def pattern_B():
    x = wildcard()
    y = wildcard()
    out = is_op("add")(x, y)
    out = is_op("abs")(out)
    return out

def pattern_C():
    x = wildcard()
    out = is_op("abs")(x)
    out = is_op("nn.relu")(out)
    return out

def before():
    input_1 = relay.var("input_1", shape=(10, 10))
    input_2 = relay.var("input_2", shape=(10, 10))
    out = relay.add(input_1, input_2)
    out = relay.abs(out)
    out = relay.nn.relu(out)
    return relay.Function([input_1, input_2], out)

Check that A has the highest priority:

pattern_table = [
    ("A", pattern_A()),
    ("B", pattern_B()),
    ("C", pattern_C()),
]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%input_1: Tensor[(10, 10), float32], %input_2: Tensor[(10, 10), float32]) {
  %0 = add(%input_1, %input_2);
  %1 = abs(%0);
  nn.relu(%1)
}
def @main(%input_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %2 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_abs_nn.relu_", Composite="A") -> Tensor[(10, 10), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    %1 = abs(%0) /* ty=Tensor[(10, 10), float32] */;
    nn.relu(%1) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %2(%input_1, %input_2) /* ty=Tensor[(10, 10), float32] */
}

Check that B has the highest priority:

pattern_table = [
    ("B", pattern_B()),
    ("C", pattern_C()),
    ("A", pattern_A()),
]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%input_1: Tensor[(10, 10), float32], %input_2: Tensor[(10, 10), float32]) {
  %0 = add(%input_1, %input_2);
  %1 = abs(%0);
  nn.relu(%1)
}
def @main(%input_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %1 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_abs_", Composite="B") -> Tensor[(10, 10), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    abs(%0) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %2 = %1(%input_1, %input_2) /* ty=Tensor[(10, 10), float32] */;
  nn.relu(%2) /* ty=Tensor[(10, 10), float32] */
}

Check that C has the highest priority:

pattern_table = [
    ("C", pattern_C()),
    ("A", pattern_A()),
    ("B", pattern_B()),
]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%input_1: Tensor[(10, 10), float32], %input_2: Tensor[(10, 10), float32]) {
  %0 = add(%input_1, %input_2);
  %1 = abs(%0);
  nn.relu(%1)
}
def @main(%input_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %1 = add(%input_1, %input_2) /* ty=Tensor[(10, 10), float32] */;
  %2 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="abs_nn.relu_", Composite="C") -> Tensor[(10, 10), float32] {
    %0 = abs(%FunctionVar_0_0) /* ty=Tensor[(10, 10), float32] */;
    nn.relu(%0) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %2(%1) /* ty=Tensor[(10, 10), float32] */
}

Parallel merge#

Test that parallel patterns relying on the same inputs are merged correctly.

The test graph is hard to draw in ASCII. It is essentially two parallel add-sub-mul units that both consume input_1 and input_2 and whose results are multiplied to produce the output. We expect both parallel branches to be merged, and both should still consume the same input variables input_1 and input_2.

def before():
    input_1 = relay.var("input_1", shape=(10, 10))
    input_2 = relay.var("input_2", shape=(10, 10))
    branch_1_add = relay.add(input_1, input_2)
    branch_1_sub = relay.subtract(input_1, input_2)
    branch_1 = relay.multiply(branch_1_add, branch_1_sub)
    branch_2_add = relay.add(input_1, input_2)
    branch_2_sub = relay.subtract(input_1, input_2)
    branch_2 = relay.multiply(branch_2_add, branch_2_sub)
    out = relay.multiply(branch_1, branch_2)
    return relay.Function([input_1, input_2], out)

pattern_table = [("add_sub_mul", make_add_sub_mul_pattern())]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%input_1: Tensor[(10, 10), float32], %input_2: Tensor[(10, 10), float32]) {
  %0 = add(%input_1, %input_2);
  %1 = subtract(%input_1, %input_2);
  %2 = add(%input_1, %input_2);
  %3 = subtract(%input_1, %input_2);
  %4 = multiply(%0, %1);
  %5 = multiply(%2, %3);
  multiply(%4, %5)
}
def @main(%input_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %2 = fn (%FunctionVar_1_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_1_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="add_sub_mul") -> Tensor[(10, 10), float32] {
    %0 = add(%FunctionVar_1_0, %FunctionVar_1_1) /* ty=Tensor[(10, 10), float32] */;
    %1 = subtract(%FunctionVar_1_0, %FunctionVar_1_1) /* ty=Tensor[(10, 10), float32] */;
    multiply(%0, %1) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %5 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="add_sub_mul") -> Tensor[(10, 10), float32] {
    %3 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    %4 = subtract(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    multiply(%3, %4) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %6 = %2(%input_1, %input_2) /* ty=Tensor[(10, 10), float32] */;
  %7 = %5(%input_1, %input_2) /* ty=Tensor[(10, 10), float32] */;
  multiply(%6, %7) /* ty=Tensor[(10, 10), float32] */
}

Subgraphs with multiple inputs#

Test the case where multiple input subgraphs feed into another subgraph.

     (1)    (2)    (3)    (4)
    add    add    add    add
     |      |      |      |
    relu   relu   relu   relu
     \      /      \      /
      \   /         \   /
       add           sub
        \            /
          \        /
            \    /
              mul

    ----> when 1=3 and 2=4 (case 'A')

    add_relu  add_relu
       \         /
        \      /
       add_sub_mul

    ----> when 1!=3 and 2!=4 (case 'B')

    add_relu  add_relu  add_relu  add_relu
       \       /           \       /
         \   /               \   /
          add                 sub
           \                  /
            --------     -----
                   \    /
                    mul

The difference in behavior arises because add_sub_mul expects the inputs of add and sub to be identical (the same two Relay expressions). So when there are four distinct inputs, the pattern should not be merged.
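For comparison, here is a hedged sketch of a variant pattern (not used in this test) with four independent wildcards; it would also match case B, because add and sub would no longer be forced to share inputs:

def make_add_sub_mul_pattern_independent_inputs():
    """Hypothetical variant of make_add_sub_mul_pattern: four independent
    wildcards, so `add` and `sub` need not consume the same expressions."""
    a, b = wildcard(), wildcard()
    c, d = wildcard(), wildcard()
    return (a + b) * (c - d)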

def before():
    before_funcs = {}
    inputs = [relay.var("input_" + str(i), shape=(10, 10)) for i in range(8)]
    add_relu_1 = relay.add(inputs[0], inputs[1])
    add_relu_1 = relay.nn.relu(add_relu_1)
    add_relu_2 = relay.add(inputs[2], inputs[3])
    add_relu_2 = relay.nn.relu(add_relu_2)
    add_relu_3 = relay.add(inputs[4], inputs[5])
    add_relu_3 = relay.nn.relu(add_relu_3)
    add_relu_4 = relay.add(inputs[6], inputs[7])
    add_relu_4 = relay.nn.relu(add_relu_4)
    add = relay.add(add_relu_1, add_relu_2)
    sub = relay.subtract(add_relu_3, add_relu_4)
    out = relay.multiply(add, sub)
    before_funcs["B"] = relay.Function(inputs, out)
    sub = relay.subtract(add_relu_1, add_relu_2)
    out = relay.multiply(add, sub)
    before_funcs["A"] = relay.Function(inputs[:4], out)
    return before_funcs

pattern_table = [
    ("add_sub_mul", make_add_sub_mul_pattern()),
    ("add_relu", make_add_relu_pattern()),
]
graph = before()["A"]
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%input_0: Tensor[(10, 10), float32], %input_1: Tensor[(10, 10), float32], %input_2: Tensor[(10, 10), float32], %input_3: Tensor[(10, 10), float32]) {
  %0 = add(%input_0, %input_1);
  %1 = add(%input_2, %input_3);
  %2 = nn.relu(%0);
  %3 = nn.relu(%1);
  %4 = add(%2, %3);
  %5 = subtract(%2, %3);
  multiply(%4, %5)
}
def @main(%input_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %3 = fn (%FunctionVar_1_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_1_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_nn.relu_", Composite="add_relu") -> Tensor[(10, 10), float32] {
    %2 = add(%FunctionVar_1_0, %FunctionVar_1_1) /* ty=Tensor[(10, 10), float32] */;
    nn.relu(%2) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %5 = fn (%FunctionVar_0_01: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_11: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_nn.relu_", Composite="add_relu") -> Tensor[(10, 10), float32] {
    %4 = add(%FunctionVar_0_01, %FunctionVar_0_11) /* ty=Tensor[(10, 10), float32] */;
    nn.relu(%4) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %6 = %3(%input_0, %input_1) /* ty=Tensor[(10, 10), float32] */;
  %7 = %5(%input_2, %input_3) /* ty=Tensor[(10, 10), float32] */;
  %8 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_subtract_multiply_", Composite="add_sub_mul") -> Tensor[(10, 10), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    %1 = subtract(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    multiply(%0, %1) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %8(%6, %7) /* ty=Tensor[(10, 10), float32] */
}
graph = before()["B"]
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%input_0: Tensor[(10, 10), float32], %input_1: Tensor[(10, 10), float32], %input_2: Tensor[(10, 10), float32], %input_3: Tensor[(10, 10), float32], %input_4: Tensor[(10, 10), float32], %input_5: Tensor[(10, 10), float32], %input_6: Tensor[(10, 10), float32], %input_7: Tensor[(10, 10), float32]) {
  %0 = add(%input_0, %input_1);
  %1 = add(%input_2, %input_3);
  %2 = nn.relu(%0);
  %3 = nn.relu(%1);
  %4 = add(%input_4, %input_5);
  %5 = add(%input_6, %input_7);
  %6 = nn.relu(%4);
  %7 = nn.relu(%5);
  %8 = add(%2, %3);
  %9 = subtract(%6, %7);
  multiply(%8, %9)
}
def @main(%input_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_4: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_5: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_6: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %input_7: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %1 = fn (%FunctionVar_3_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_3_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_nn.relu_", Composite="add_relu") -> Tensor[(10, 10), float32] {
    %0 = add(%FunctionVar_3_0, %FunctionVar_3_1) /* ty=Tensor[(10, 10), float32] */;
    nn.relu(%0) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %3 = fn (%FunctionVar_2_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_2_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_nn.relu_", Composite="add_relu") -> Tensor[(10, 10), float32] {
    %2 = add(%FunctionVar_2_0, %FunctionVar_2_1) /* ty=Tensor[(10, 10), float32] */;
    nn.relu(%2) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %4 = %1(%input_0, %input_1) /* ty=Tensor[(10, 10), float32] */;
  %5 = %3(%input_2, %input_3) /* ty=Tensor[(10, 10), float32] */;
  %7 = fn (%FunctionVar_1_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_1_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_nn.relu_", Composite="add_relu") -> Tensor[(10, 10), float32] {
    %6 = add(%FunctionVar_1_0, %FunctionVar_1_1) /* ty=Tensor[(10, 10), float32] */;
    nn.relu(%6) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %9 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="add_nn.relu_", Composite="add_relu") -> Tensor[(10, 10), float32] {
    %8 = add(%FunctionVar_0_0, %FunctionVar_0_1) /* ty=Tensor[(10, 10), float32] */;
    nn.relu(%8) /* ty=Tensor[(10, 10), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(10, 10), float32] */;
  %10 = %7(%input_4, %input_5) /* ty=Tensor[(10, 10), float32] */;
  %11 = %9(%input_6, %input_7) /* ty=Tensor[(10, 10), float32] */;
  %12 = add(%4, %5) /* ty=Tensor[(10, 10), float32] */;
  %13 = subtract(%10, %11) /* ty=Tensor[(10, 10), float32] */;
  multiply(%12, %13) /* ty=Tensor[(10, 10), float32] */
}

Merging TupleGetItem#

Test that composite functions can be merged from patterns that contain TupleGetItem nodes.

pattern_table = [("bn_relu", make_bn_relu_pattern())]

def before():
    x = relay.var("x", shape=(1, 8))
    gamma = relay.var("gamma", shape=(8,))
    beta = relay.var("beta", shape=(8,))
    moving_mean = relay.var("moving_mean", shape=(8,))
    moving_var = relay.var("moving_var", shape=(8,))
    bn_node = relay.nn.batch_norm(x, gamma, beta, moving_mean, moving_var)
    tuple_get_item_node = bn_node[0]
    r = relay.nn.relu(tuple_get_item_node)
    return relay.Function([x, gamma, beta, moving_mean, moving_var], r)
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%x: Tensor[(1, 8), float32], %gamma: Tensor[(8), float32], %beta: Tensor[(8), float32], %moving_mean: Tensor[(8), float32], %moving_var: Tensor[(8), float32]) {
  %0 = nn.batch_norm(%x, %gamma, %beta, %moving_mean, %moving_var);
  %1 = %0.0;
  nn.relu(%1)
}
def @main(%x: Tensor[(1, 8), float32] /* ty=Tensor[(1, 8), float32] */, %gamma: Tensor[(8), float32] /* ty=Tensor[(8), float32] */, %beta: Tensor[(8), float32] /* ty=Tensor[(8), float32] */, %moving_mean: Tensor[(8), float32] /* ty=Tensor[(8), float32] */, %moving_var: Tensor[(8), float32] /* ty=Tensor[(8), float32] */) -> Tensor[(1, 8), float32] {
  %2 = fn (%FunctionVar_0_0: Tensor[(1, 8), float32] /* ty=Tensor[(1, 8), float32] */, %FunctionVar_0_1: Tensor[(8), float32] /* ty=Tensor[(8), float32] */, %FunctionVar_0_2: Tensor[(8), float32] /* ty=Tensor[(8), float32] */, %FunctionVar_0_3: Tensor[(8), float32] /* ty=Tensor[(8), float32] */, %FunctionVar_0_4: Tensor[(8), float32] /* ty=Tensor[(8), float32] */, PartitionedFromPattern="nn.batch_norm_TupleGetItem0_nn.relu_", Composite="bn_relu") -> Tensor[(1, 8), float32] {
    %0 = nn.batch_norm(%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, %FunctionVar_0_3, %FunctionVar_0_4) /* ty=(Tensor[(1, 8), float32], Tensor[(8), float32], Tensor[(8), float32]) */;
    %1 = %0.0 /* ty=Tensor[(1, 8), float32] */;
    nn.relu(%1) /* ty=Tensor[(1, 8), float32] */
  } /* ty=fn (Tensor[(1, 8), float32], Tensor[(8), float32], Tensor[(8), float32], Tensor[(8), float32], Tensor[(8), float32]) -> Tensor[(1, 8), float32] */;
  %2(%x, %gamma, %beta, %moving_mean, %moving_var) /* ty=Tensor[(1, 8), float32] */
}

Merging with a check function#

A pattern table entry may carry a check function as an optional third element; the matched subgraph is only merged into a composite function when the check returns True for the extracted expression.

def before():
    x = relay.var("x", shape=(1, 10, 10, 10))
    w = relay.var("w", shape=(10, 10, 3, 3))
    b = relay.var("b", shape=(8,))
    conv = relay.nn.conv2d(x, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC")
    bias = relay.nn.bias_add(conv, b)
    relu = relay.nn.relu(bias)
    return relay.Function([x, w, b], relu)

def _check_true(extract):
    conv = extract.args[0].args[0]
    return conv.attrs.data_layout == "NHWC"

def _check_false(extract):
    conv = extract.args[0].args[0]
    return conv.attrs.data_layout == "NCHW"
pattern_table = [("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false)]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%x: Tensor[(1, 10, 10, 10), float32], %w: Tensor[(10, 10, 3, 3), float32], %b: Tensor[(8), float32]) {
  %0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC");
  %1 = nn.bias_add(%0, %b);
  nn.relu(%1)
}
def @main(%x: Tensor[(1, 10, 10, 10), float32] /* ty=Tensor[(1, 10, 10, 10), float32] */, %w: Tensor[(10, 10, 3, 3), float32] /* ty=Tensor[(10, 10, 3, 3), float32] */, %b: Tensor[(8), float32] /* ty=Tensor[(8), float32] */) -> Tensor[(1, 8, 8, 10), float32] {
  %0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC") /* ty=Tensor[(1, 8, 8, 10), float32] */;
  %1 = nn.bias_add(%0, %b) /* ty=Tensor[(1, 8, 8, 10), float32] */;
  nn.relu(%1) /* ty=Tensor[(1, 8, 8, 10), float32] */
}
pattern_table = [("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true)]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%x: Tensor[(1, 10, 10, 10), float32], %w: Tensor[(10, 10, 3, 3), float32], %b: Tensor[(8), float32]) {
  %0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC");
  %1 = nn.bias_add(%0, %b);
  nn.relu(%1)
}
def @main(%x: Tensor[(1, 10, 10, 10), float32] /* ty=Tensor[(1, 10, 10, 10), float32] */, %w: Tensor[(10, 10, 3, 3), float32] /* ty=Tensor[(10, 10, 3, 3), float32] */, %b: Tensor[(8), float32] /* ty=Tensor[(8), float32] */) -> Tensor[(1, 8, 8, 10), float32] {
  %2 = fn (%FunctionVar_0_0: Tensor[(1, 10, 10, 10), float32] /* ty=Tensor[(1, 10, 10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10, 3, 3), float32] /* ty=Tensor[(10, 10, 3, 3), float32] */, %FunctionVar_0_2: Tensor[(8), float32] /* ty=Tensor[(8), float32] */, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_", Composite="conv_bias_relu") -> Tensor[(1, 8, 8, 10), float32] {
    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC") /* ty=Tensor[(1, 8, 8, 10), float32] */;
    %1 = nn.bias_add(%0, %FunctionVar_0_2) /* ty=Tensor[(1, 8, 8, 10), float32] */;
    nn.relu(%1) /* ty=Tensor[(1, 8, 8, 10), float32] */
  } /* ty=fn (Tensor[(1, 10, 10, 10), float32], Tensor[(10, 10, 3, 3), float32], Tensor[(8), float32]) -> Tensor[(1, 8, 8, 10), float32] */;
  %2(%x, %w, %b) /* ty=Tensor[(1, 8, 8, 10), float32] */
}

Diamond case that should not merge#

The pattern on the left should not match the structure on the right.

    relu             relu
     | \              | \
     | clip           | add
     |  /             |  |
     mul              | clip
                      |  /
                      mul
def get_pattern():
    conv = make_conv_bias_relu_pattern()
    clip = is_op("clip")(conv, wildcard(), wildcard())
    return is_op("multiply")(conv, clip)

def get_net():
    data = relay.var("data", shape=(1, 512, 28, 28))
    kernel = relay.var("kernel", shape=(256, 512, 1, 1))
    conv = relay.nn.conv2d(data, kernel, kernel_size=(1, 1), padding=(0, 0), strides=(1, 1))
    bias = relay.nn.bias_add(conv, relay.var("bias", shape=(256,)))
    relu = relay.nn.relu(bias)
    add = relay.op.add(relu, relay.const(1.0))
    clip2 = relay.op.clip(add, 0, 255)
    mul = relay.op.multiply(relu, clip2)
    return relay.Function(relay.analysis.free_vars(mul), mul)

pattern_table = [("pat", get_pattern())]
graph = get_net()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%data: Tensor[(1, 512, 28, 28), float32], %kernel: Tensor[(256, 512, 1, 1), float32], %bias: Tensor[(256), float32]) {
  %0 = nn.conv2d(%data, %kernel, padding=[0, 0, 0, 0], kernel_size=[1, 1]);
  %1 = nn.bias_add(%0, %bias);
  %2 = nn.relu(%1);
  %3 = add(%2, 1f);
  %4 = clip(%3, a_min=0f, a_max=255f);
  multiply(%2, %4)
}
def @main(%data: Tensor[(1, 512, 28, 28), float32] /* ty=Tensor[(1, 512, 28, 28), float32] */, %kernel: Tensor[(256, 512, 1, 1), float32] /* ty=Tensor[(256, 512, 1, 1), float32] */, %bias: Tensor[(256), float32] /* ty=Tensor[(256), float32] */) -> Tensor[(1, 256, 28, 28), float32] {
  %0 = nn.conv2d(%data, %kernel, padding=[0, 0, 0, 0], kernel_size=[1, 1]) /* ty=Tensor[(1, 256, 28, 28), float32] */;
  %1 = nn.bias_add(%0, %bias) /* ty=Tensor[(1, 256, 28, 28), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(1, 256, 28, 28), float32] */;
  %3 = add(%2, 1f /* ty=float32 */) /* ty=Tensor[(1, 256, 28, 28), float32] */;
  %4 = clip(%3, a_min=0f, a_max=255f) /* ty=Tensor[(1, 256, 28, 28), float32] */;
  multiply(%2, %4) /* ty=Tensor[(1, 256, 28, 28), float32] */
}

Querying tensor types#

A check function can also query the inferred tensor types of matched expressions via checked_type; here before() runs InferType so that those types are available.

def before():
    x = relay.var("x", shape=(1, 10, 10, 10))
    w = relay.var("w", shape=(10, 10, 3, 3))
    b = relay.var("b", shape=(8,))
    add = relay.op.add(x, x)
    relu = relay.nn.relu(add)
    conv = relay.nn.conv2d(
        relu, w, kernel_size=(3, 3), kernel_layout="OIHW", data_layout="NHWC"
    )
    bias = relay.nn.bias_add(conv, b)
    relu2 = relay.nn.relu(bias)
    return run_opt_pass(relay.Function([x, w, b], relu2), relay.transform.InferType())

def _check_type_true(extract):
    conv = extract.args[0].args[0]
    typ = conv.checked_type
    return bool(typ.shape[0] == 1)

def _check_type_false(extract):
    conv = extract.args[0].args[0]
    typ = conv.checked_type
    return bool(typ.shape[0] != 1)
pattern_table = [
    ("add_relu", make_add_relu_pattern()),
    ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_false),
]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%x: Tensor[(1, 10, 10, 10), float32] /* ty=Tensor[(1, 10, 10, 10), float32] */, %w: Tensor[(10, 10, 3, 3), float32] /* ty=Tensor[(10, 10, 3, 3), float32] */, %b: Tensor[(8), float32] /* ty=Tensor[(8), float32] */) -> Tensor[(1, 8, 8, 10), float32] {
  %0 = add(%x, %x) /* ty=Tensor[(1, 10, 10, 10), float32] */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 10, 10, 10), float32] */;
  %2 = nn.conv2d(%1, %w, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC") /* ty=Tensor[(1, 8, 8, 10), float32] */;
  %3 = nn.bias_add(%2, %b) /* ty=Tensor[(1, 8, 8, 10), float32] */;
  nn.relu(%3) /* ty=Tensor[(1, 8, 8, 10), float32] */
}
def @main(%x: Tensor[(1, 10, 10, 10), float32] /* ty=Tensor[(1, 10, 10, 10), float32] */, %w: Tensor[(10, 10, 3, 3), float32] /* ty=Tensor[(10, 10, 3, 3), float32] */, %b: Tensor[(8), float32] /* ty=Tensor[(8), float32] */) -> Tensor[(1, 8, 8, 10), float32] {
  %1 = fn (%FunctionVar_0_0: Tensor[(1, 10, 10, 10), float32] /* ty=Tensor[(1, 10, 10, 10), float32] */, PartitionedFromPattern="add_nn.relu_", Composite="add_relu") -> Tensor[(1, 10, 10, 10), float32] {
    %0 = add(%FunctionVar_0_0, %FunctionVar_0_0) /* ty=Tensor[(1, 10, 10, 10), float32] */;
    nn.relu(%0) /* ty=Tensor[(1, 10, 10, 10), float32] */
  } /* ty=fn (Tensor[(1, 10, 10, 10), float32]) -> Tensor[(1, 10, 10, 10), float32] */;
  %2 = %1(%x) /* ty=Tensor[(1, 10, 10, 10), float32] */;
  %3 = nn.conv2d(%2, %w, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC") /* ty=Tensor[(1, 8, 8, 10), float32] */;
  %4 = nn.bias_add(%3, %b) /* ty=Tensor[(1, 8, 8, 10), float32] */;
  nn.relu(%4) /* ty=Tensor[(1, 8, 8, 10), float32] */
}
pattern_table = [
    ("add_relu", make_add_relu_pattern()),
    ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_type_true),
]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%x: Tensor[(1, 10, 10, 10), float32] /* ty=Tensor[(1, 10, 10, 10), float32] */, %w: Tensor[(10, 10, 3, 3), float32] /* ty=Tensor[(10, 10, 3, 3), float32] */, %b: Tensor[(8), float32] /* ty=Tensor[(8), float32] */) -> Tensor[(1, 8, 8, 10), float32] {
  %0 = add(%x, %x) /* ty=Tensor[(1, 10, 10, 10), float32] */;
  %1 = nn.relu(%0) /* ty=Tensor[(1, 10, 10, 10), float32] */;
  %2 = nn.conv2d(%1, %w, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC") /* ty=Tensor[(1, 8, 8, 10), float32] */;
  %3 = nn.bias_add(%2, %b) /* ty=Tensor[(1, 8, 8, 10), float32] */;
  nn.relu(%3) /* ty=Tensor[(1, 8, 8, 10), float32] */
}
def @main(%x: Tensor[(1, 10, 10, 10), float32] /* ty=Tensor[(1, 10, 10, 10), float32] */, %w: Tensor[(10, 10, 3, 3), float32] /* ty=Tensor[(10, 10, 3, 3), float32] */, %b: Tensor[(8), float32] /* ty=Tensor[(8), float32] */) -> Tensor[(1, 8, 8, 10), float32] {
  %3 = fn (%FunctionVar_0_01: Tensor[(1, 10, 10, 10), float32] /* ty=Tensor[(1, 10, 10, 10), float32] */, PartitionedFromPattern="add_nn.relu_", Composite="add_relu") -> Tensor[(1, 10, 10, 10), float32] {
    %2 = add(%FunctionVar_0_01, %FunctionVar_0_01) /* ty=Tensor[(1, 10, 10, 10), float32] */;
    nn.relu(%2) /* ty=Tensor[(1, 10, 10, 10), float32] */
  } /* ty=fn (Tensor[(1, 10, 10, 10), float32]) -> Tensor[(1, 10, 10, 10), float32] */;
  %4 = %3(%x) /* ty=Tensor[(1, 10, 10, 10), float32] */;
  %5 = fn (%FunctionVar_0_0: Tensor[(1, 10, 10, 10), float32] /* ty=Tensor[(1, 10, 10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10, 3, 3), float32] /* ty=Tensor[(10, 10, 3, 3), float32] */, %FunctionVar_0_2: Tensor[(8), float32] /* ty=Tensor[(8), float32] */, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_", Composite="conv_bias_relu") -> Tensor[(1, 8, 8, 10), float32] {
    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0], kernel_size=[3, 3], data_layout="NHWC") /* ty=Tensor[(1, 8, 8, 10), float32] */;
    %1 = nn.bias_add(%0, %FunctionVar_0_2) /* ty=Tensor[(1, 8, 8, 10), float32] */;
    nn.relu(%1) /* ty=Tensor[(1, 8, 8, 10), float32] */
  } /* ty=fn (Tensor[(1, 10, 10, 10), float32], Tensor[(10, 10, 3, 3), float32], Tensor[(8), float32]) -> Tensor[(1, 8, 8, 10), float32] */;
  %5(%4, %w, %b) /* ty=Tensor[(1, 8, 8, 10), float32] */
}

No error with the einsum operator#

Test that merging a pattern whose operator takes a tuple input (einsum) does not raise an error.

from tvm.relay.dataflow_pattern import TuplePattern

def make_einsum_reshape_pattern():
    x = wildcard()
    x = is_op("reshape")(x) | x
    y = wildcard()
    y = is_op("reshape")(y) | y
    z = is_op("einsum")(TuplePattern([x, y]))
    r = is_op("reshape")(z) | z
    return r

def before():
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    c = relay.reshape(a, [20, 5])
    d = relay.reshape(b, [20, 5])
    r = relay.einsum([c, d], "...ab,...cb->...ac")
    return relay.Function([a, b], r)

pattern_table = [
    (
        "einsum_reshape",
        make_einsum_reshape_pattern(),
    )
]
graph = before()
tvm.IRModule.from_expr(graph).show()
result = run_opt_pass(
    graph, relay.transform.MergeComposite(pattern_table), import_prelude=False
)
assert not relay.analysis.free_vars(result), f"Found free variables in the graph: {result}"
tvm.IRModule.from_expr(result).show()
def @main(%a: Tensor[(10, 10), float32], %b: Tensor[(10, 10), float32]) {
  %0 = reshape(%a, newshape=[20, 5]);
  %1 = reshape(%b, newshape=[20, 5]);
  %2 = (%0, %1);
  einsum(%2, equation="...ab,...cb->...ac")
}
def @main(%a: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %b: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(20, 20), float32] {
  %3 = fn (%FunctionVar_0_0: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %FunctionVar_0_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, PartitionedFromPattern="reshape_reshape_Tuple_einsum_", Composite="einsum_reshape") -> Tensor[(20, 20), float32] {
    %0 = reshape(%FunctionVar_0_0, newshape=[20, 5]) /* ty=Tensor[(20, 5), float32] */;
    %1 = reshape(%FunctionVar_0_1, newshape=[20, 5]) /* ty=Tensor[(20, 5), float32] */;
    %2 = (%0, %1) /* ty=(Tensor[(20, 5), float32], Tensor[(20, 5), float32]) */;
    einsum(%2, equation="...ab,...cb->...ac") /* ty=Tensor[(20, 20), float32] */
  } /* ty=fn (Tensor[(10, 10), float32], Tensor[(10, 10), float32]) -> Tensor[(20, 20), float32] */;
  %3(%a, %b) /* ty=Tensor[(20, 20), float32] */
}