Partition matched-outside-but-dominated

from testing import viz_expr  # visualize Relay expressions
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *

The pattern matches the nn.conv2d/add/multiply flow. Even though the output of add is also consumed by sigmoid, sigmoid itself is dominated by multiply, so the partitioner can still lift the matched ops into a single function.

Build the computation graph:

in_mod = tvm.relay.parse(
    """
    #[version = "0.0.5"]
    def @main(%data: Tensor[(16, 16, 32, 32), float16], %weight: Tensor[(32, 16, 3, 3), float16], %bias: Tensor[(32), float32]) -> Tensor[(16, 32, 32, 32), float32] {
        %0 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC");
        %1 = layout_transform(%weight, src_layout="OIHW", dst_layout="OHWI");
        %2 = expand_dims(%bias, axis=1, num_newaxis=2);
        %3 = expand_dims(%2, axis=0);
        %4 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32");
        %5 = layout_transform(%3, src_layout="NCHW", dst_layout="NHWC");
        %6 = add(%4, %5);
        %7 = sigmoid(%6);
        %8 = multiply(%6, %7);
        layout_transform(%8, src_layout="NHWC", dst_layout="NCHW")
    }
    """
)
viz_expr(in_mod["main"])
(viz_expr output: SVG rendering of the @main dataflow graph)
print(in_mod)
def @main(%data: Tensor[(16, 16, 32, 32), float16] /* ty=Tensor[(16, 16, 32, 32), float16] span=from_string:4:31 */, %weight: Tensor[(32, 16, 3, 3), float16] /* ty=Tensor[(32, 16, 3, 3), float16] span=from_string:5:31 */, %bias: Tensor[(32), float32] /* ty=Tensor[(32), float32] span=from_string:6:26 */) -> Tensor[(16, 32, 32, 32), float32] {
  %0 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(16, 32, 32, 16), float16] span=from_string:8:24 */;
  %1 = layout_transform(%weight, src_layout="OIHW", dst_layout="OHWI") /* ty=Tensor[(32, 3, 3, 16), float16] span=from_string:8:28 */;
  %2 = expand_dims(%bias, axis=1, num_newaxis=2) /* ty=Tensor[(32, 1, 1), float32] span=from_string:7:26 */;
  %3 = expand_dims(%2, axis=0) /* ty=Tensor[(1, 32, 1, 1), float32] span=from_string:9:31 */;
  %4 = nn.conv2d(%0, %1, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32") /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:10:18 */;
  %5 = layout_transform(%3, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 1, 1, 32), float32] span=from_string:10:22 */;
  %6 = add(%4, %5) /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:12:24 */;
  %7 = sigmoid(%6) /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:12:28 */;
  %8 = multiply(%6, %7) /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:13:26 */;
  layout_transform(%8, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:4:9 */
}
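
The printed IR makes the premise concrete: %6 = add(%4, %5) is consumed by both sigmoid (%7) and multiply (%8), so one consumer of the matched add lies outside the pattern. A minimal sketch to verify this programmatically, assuming the graph above; CountUses is a hypothetical helper written for this check, not a TVM API:

from tvm.relay.expr_functor import ExprVisitor

class CountUses(ExprVisitor):
    """Count how many call sites consume `target` (hypothetical helper)."""

    def __init__(self, target):
        super().__init__()
        self.target = target
        self.uses = 0

    def visit_call(self, call):
        # A use is any call argument that is the exact same graph node.
        self.uses += sum(1 for arg in call.args if arg.same_as(self.target))
        super().visit_call(call)

mul = in_mod["main"].body.args[0]  # multiply(%6, %7)
add = mul.args[0]                  # add(%4, %5)
counter = CountUses(add)
counter.visit(in_mod["main"].body)
print(counter.uses)                # 2: consumed by sigmoid and by multiply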

Build the pattern:

pattern = is_op("multiply")(
    is_op("add")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard()), wildcard()
)
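
Before partitioning, the pattern can be sanity-checked against the graph with DFPattern.match, which reports whether the pattern matches an expression rooted at a given node. A small sketch, reusing the parsed module from above:

mul = in_mod["main"].body.args[0]              # multiply(%6, %7)
assert pattern.match(mul)                      # conv2d -> add -> multiply matches here
assert not pattern.match(in_mod["main"].body)  # the outer layout_transform does not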

Partition the graph:

partitioned = pattern.partition(in_mod["main"])
print(tvm.IRModule.from_expr(partitioned))
def @main(%data: Tensor[(16, 16, 32, 32), float16] /* ty=Tensor[(16, 16, 32, 32), float16] span=from_string:4:31 */, %weight: Tensor[(32, 16, 3, 3), float16] /* ty=Tensor[(32, 16, 3, 3), float16] span=from_string:5:31 */, %bias: Tensor[(32), float32] /* ty=Tensor[(32), float32] span=from_string:6:26 */) -> Tensor[(16, 32, 32, 32), float32] {
  %2 = expand_dims(%bias, axis=1, num_newaxis=2) /* ty=Tensor[(32, 1, 1), float32] span=from_string:7:26 */;
  %3 = expand_dims(%2, axis=0) /* ty=Tensor[(1, 32, 1, 1), float32] span=from_string:9:31 */;
  %4 = layout_transform(%data, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(16, 32, 32, 16), float16] span=from_string:8:24 */;
  %5 = layout_transform(%weight, src_layout="OIHW", dst_layout="OHWI") /* ty=Tensor[(32, 3, 3, 16), float16] span=from_string:8:28 */;
  %6 = nn.conv2d(%4, %5, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32") /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:10:18 */;
  %7 = layout_transform(%3, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 1, 1, 32), float32] span=from_string:10:22 */;
  %8 = add(%6, %7) /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:12:24 */;
  %9 = sigmoid(%8) /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:12:28 */;
  %10 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, %FunctionVar_0_3, PartitionedFromPattern="nn.conv2d_add_multiply_") {
    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="OHWI", out_dtype="float32") /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:10:18 */;
    %1 = add(%0, %FunctionVar_0_2) /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:12:24 */;
    multiply(%1, %FunctionVar_0_3) /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:13:26 */
  };
  %11 = %10(%4, %5, %7, %9);
  layout_transform(%11, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(16, 32, 32, 32), float32] span=from_string:4:9 */
}
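
Note how the partition handles the matched-outside-but-dominated consumer: sigmoid stays in the outer scope (%9 = sigmoid(%8)) and its result is passed into the generated function as %FunctionVar_0_3, while the nn.conv2d and add are duplicated, appearing inside the function (%0, %1) and again outside (%6, %8) to feed sigmoid. Because sigmoid is dominated by multiply, the conv2d/add/multiply chain can still be lifted into one PartitionedFromPattern function.

partition also accepts an attrs dict, attached to the generated function, and a check callback that can veto a match. A hedged sketch; the "Composite" value and check_layout are illustrative, not part of the original example:

def check_layout(pre):
    # `pre` is the matched subexpression rooted at multiply.
    conv = pre.args[0].args[0]  # multiply -> add -> nn.conv2d
    return conv.attrs.data_layout == "NHWC"

partitioned2 = pattern.partition(
    in_mod["main"],
    attrs={"Composite": "conv2d_add_multiply"},  # illustrative attribute value
    check=check_layout,
)
print(tvm.IRModule.from_expr(partitioned2))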