合并编译器区域#

参考:tvm/tests/python/relay/test_pass_merge_compiler_regions.py

compiler_begin() 和 compiler_end() 函数的主要作用是在 Relay 表达式中标记区域的开始和结束,该区域将由指定的编译器处理。这两个函数通常用于优化和调度,允许不同的编译器处理不同的表达式区域。

import tvm
from tvm import relay
import tvm.relay.transform as transform
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.testing import run_opt_pass

处理菱形区域#

本例测试菱形计算图中存在的数据依赖关系是否被合并过程正确处理。

    O = 目标支持
    X = 目标不支持

       O         O
      / \       /               \
     O   X --> O    +       +    X
      \ /              \ /
       O                O

注意不能简单地将三个支持的算子合并在一起,否则两个子图都会依赖于另一个。

def diamond_graph_fanouts():
    """Build the annotated diamond-shaped Relay function.

    ``abs`` feeds both ``nn.relu`` (annotated for "test") and ``tanh``
    (annotated for "default"); the two results are joined by ``add``
    (annotated for "test"). Every op is wrapped in its own
    ``compiler_begin`` / ``compiler_end`` pair.

    Returns:
        A ``relay.Function`` over the single ``(10, 10)`` input ``data``.
    """
    supported, fallback = "test", "default"
    data = relay.var("data", shape=(10, 10))

    # Top of the diamond. Its result is closed by two distinct compiler_end
    # markers, one for each consumer.
    top = relay.abs(compiler_begin(data, supported))
    top_for_left = compiler_end(top, supported)
    top_for_right = compiler_end(top, supported)

    # Left arm: annotated for the "test" compiler.
    left = relay.nn.relu(compiler_begin(top_for_left, supported))
    left_out = compiler_end(left, supported)

    # Right arm: annotated for the "default" compiler.
    right = relay.tanh(compiler_begin(top_for_right, fallback))
    right_out = compiler_end(right, fallback)

    # Bottom of the diamond joins both arms.
    joined = relay.add(
        compiler_begin(left_out, supported),
        compiler_begin(right_out, supported),
    )
    return relay.Function([data], compiler_end(joined, supported))
# Build the annotated diamond graph and print it wrapped in an IRModule.
expr = diamond_graph_fanouts()
tvm.IRModule.from_expr(expr).show()
Hide code cell output
def @main(%data: Tensor[(10, 10), float32]) {
  %0 = annotation.compiler_begin(%data, compiler="test");
  %1 = abs(%0);
  %2 = annotation.compiler_end(%1, compiler="test");
  %3 = annotation.compiler_begin(%2, compiler="test");
  %4 = nn.relu(%3);
  %5 = annotation.compiler_end(%4, compiler="test");
  %6 = annotation.compiler_end(%1, compiler="test");
  %7 = annotation.compiler_begin(%6, compiler="default");
  %8 = tanh(%7);
  %9 = annotation.compiler_end(%8, compiler="default");
  %10 = annotation.compiler_begin(%5, compiler="test");
  %11 = annotation.compiler_begin(%9, compiler="test");
  %12 = add(%10, %11);
  annotation.compiler_end(%12, compiler="test")
}
# Apply MergeCompilerRegions to the annotated function and print the result.
opt_expr = run_opt_pass(expr, relay.transform.MergeCompilerRegions())
tvm.IRModule.from_expr(opt_expr).show()
Hide code cell output
def @main(%data: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %0 = annotation.compiler_begin(%data, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %1 = abs(%0) /* ty=Tensor[(10, 10), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(10, 10), float32] */;
  %3 = annotation.compiler_end(%2, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %4 = annotation.compiler_end(%1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %5 = annotation.compiler_begin(%4, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %6 = tanh(%5) /* ty=Tensor[(10, 10), float32] */;
  %7 = annotation.compiler_end(%6, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %8 = annotation.compiler_begin(%3, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %9 = annotation.compiler_begin(%7, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %10 = add(%8, %9) /* ty=Tensor[(10, 10), float32] */;
  annotation.compiler_end(%10, compiler="test") /* ty=Tensor[(10, 10), float32] */
}

计算图示例#

参见 RFC 5830。蓝色节点是加法(目标:test),红色节点是减法(目标:default)。

原始计算图:

Hide code cell content
def annotated():
    """Build the annotated example graph and return it as an IRModule.

    The graph has ten ``(10, 10)`` float32 inputs ``in_1`` .. ``in_10``.
    Every ``add`` is annotated for the "test" compiler and every
    ``subtract`` for the "default" compiler; each op has its own
    ``compiler_begin`` / ``compiler_end`` markers.

    Returns:
        ``tvm.IRModule`` created from the annotated function.
    """
    test, default = "test", "default"

    def binary(op, compiler, lhs, rhs):
        # Wrap both operands in compiler_begin for `compiler`, then apply `op`.
        return op(compiler_begin(lhs, compiler), compiler_begin(rhs, compiler))

    inputs = [
        relay.var(f"in_{idx}", shape=(10, 10), dtype="float32")
        for idx in range(1, 11)
    ]
    in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10 = inputs

    # Three chained adds over in_1 .. in_4.
    sum_12 = compiler_end(binary(relay.add, test, in_1, in_2), test)
    sum_34 = compiler_end(binary(relay.add, test, in_3, in_4), test)
    sum_1234 = compiler_end(binary(relay.add, test, sum_12, sum_34), test)

    # Two chained subtracts over in_5 .. in_7.
    diff_56 = compiler_end(binary(relay.subtract, default, in_5, in_6), default)
    diff_7 = compiler_end(binary(relay.subtract, default, in_7, diff_56), default)

    # This add has two consumers, so it is closed by two separate
    # compiler_end markers.
    fan_out = binary(relay.add, test, sum_1234, diff_7)
    fan_for_subtract = compiler_end(fan_out, test)
    fan_for_add = compiler_end(fan_out, test)

    diff_8 = compiler_end(
        binary(relay.subtract, default, in_8, fan_for_subtract), default
    )
    sum_9 = compiler_end(binary(relay.add, test, in_9, fan_for_add), test)

    # Re-join both consumers, then fold in the last input.
    rejoined = compiler_end(binary(relay.add, test, diff_8, sum_9), test)
    output = compiler_end(binary(relay.add, test, in_10, rejoined), test)

    return tvm.IRModule.from_expr(relay.Function(inputs, output))
# Build the annotated module and print it before any merging is applied.
mod = annotated()
mod.show()
Hide code cell output
def @main(%in_1: Tensor[(10, 10), float32], %in_2: Tensor[(10, 10), float32], %in_3: Tensor[(10, 10), float32], %in_4: Tensor[(10, 10), float32], %in_5: Tensor[(10, 10), float32], %in_6: Tensor[(10, 10), float32], %in_7: Tensor[(10, 10), float32], %in_8: Tensor[(10, 10), float32], %in_9: Tensor[(10, 10), float32], %in_10: Tensor[(10, 10), float32]) {
  %0 = annotation.compiler_begin(%in_1, compiler="test");
  %1 = annotation.compiler_begin(%in_2, compiler="test");
  %2 = add(%0, %1);
  %3 = annotation.compiler_end(%2, compiler="test");
  %4 = annotation.compiler_begin(%in_3, compiler="test");
  %5 = annotation.compiler_begin(%in_4, compiler="test");
  %6 = add(%4, %5);
  %7 = annotation.compiler_end(%6, compiler="test");
  %8 = annotation.compiler_begin(%3, compiler="test");
  %9 = annotation.compiler_begin(%7, compiler="test");
  %10 = add(%8, %9);
  %11 = annotation.compiler_end(%10, compiler="test");
  %12 = annotation.compiler_begin(%in_5, compiler="default");
  %13 = annotation.compiler_begin(%in_6, compiler="default");
  %14 = subtract(%12, %13);
  %15 = annotation.compiler_end(%14, compiler="default");
  %16 = annotation.compiler_begin(%in_7, compiler="default");
  %17 = annotation.compiler_begin(%15, compiler="default");
  %18 = subtract(%16, %17);
  %19 = annotation.compiler_end(%18, compiler="default");
  %20 = annotation.compiler_begin(%11, compiler="test");
  %21 = annotation.compiler_begin(%19, compiler="test");
  %22 = add(%20, %21);
  %23 = annotation.compiler_end(%22, compiler="test");
  %24 = annotation.compiler_begin(%in_8, compiler="default");
  %25 = annotation.compiler_begin(%23, compiler="default");
  %26 = subtract(%24, %25);
  %27 = annotation.compiler_end(%26, compiler="default");
  %28 = annotation.compiler_end(%22, compiler="test");
  %29 = annotation.compiler_begin(%in_9, compiler="test");
  %30 = annotation.compiler_begin(%28, compiler="test");
  %31 = add(%29, %30);
  %32 = annotation.compiler_end(%31, compiler="test");
  %33 = annotation.compiler_begin(%27, compiler="test");
  %34 = annotation.compiler_begin(%32, compiler="test");
  %35 = add(%33, %34);
  %36 = annotation.compiler_end(%35, compiler="test");
  %37 = annotation.compiler_begin(%in_10, compiler="test");
  %38 = annotation.compiler_begin(%36, compiler="test");
  %39 = add(%37, %38);
  annotation.compiler_end(%39, compiler="test")
}

优化之后的计算图:

# Apply MergeCompilerRegions, then InferType, and print the resulting module.
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.InferType()(mod)
mod.show()
Hide code cell output
def @main(%in_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_4: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_5: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_6: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_7: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_8: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_9: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_10: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %0 = annotation.compiler_begin(%in_1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %1 = annotation.compiler_begin(%in_2, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %2 = annotation.compiler_begin(%in_3, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %3 = annotation.compiler_begin(%in_4, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %4 = add(%0, %1) /* ty=Tensor[(10, 10), float32] */;
  %5 = add(%2, %3) /* ty=Tensor[(10, 10), float32] */;
  %6 = annotation.compiler_begin(%in_5, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %7 = annotation.compiler_begin(%in_6, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %8 = annotation.compiler_begin(%in_7, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %9 = subtract(%6, %7) /* ty=Tensor[(10, 10), float32] */;
  %10 = subtract(%8, %9) /* ty=Tensor[(10, 10), float32] */;
  %11 = annotation.compiler_end(%10, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %12 = add(%4, %5) /* ty=Tensor[(10, 10), float32] */;
  %13 = annotation.compiler_begin(%11, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %14 = add(%12, %13) /* ty=Tensor[(10, 10), float32] */;
  %15 = annotation.compiler_end(%14, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %16 = annotation.compiler_begin(%in_8, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %17 = annotation.compiler_begin(%15, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %18 = subtract(%16, %17) /* ty=Tensor[(10, 10), float32] */;
  %19 = annotation.compiler_end(%18, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %20 = annotation.compiler_begin(%in_9, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %21 = add(%20, %14) /* ty=Tensor[(10, 10), float32] */;
  %22 = annotation.compiler_end(%21, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %23 = annotation.compiler_begin(%19, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %24 = annotation.compiler_begin(%22, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %25 = annotation.compiler_begin(%in_10, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %26 = add(%23, %24) /* ty=Tensor[(10, 10), float32] */;
  %27 = add(%25, %26) /* ty=Tensor[(10, 10), float32] */;
  annotation.compiler_end(%27, compiler="test") /* ty=Tensor[(10, 10), float32] */
}

测试条件语句#

该测试验证限制区域能够在 if_else 控制流中正确传播。

    O = 目标支持
    X = 目标不支持

           O1 - - - |      O1 --|
            |       |               |
            X       |               X
            |       |                              |
    If cond ? O1: X | -->       +       +  If cond ? O1: X  +
            |       |                                           |
           O2 <- - -|                                          O2 <-|

避免 O1 合并到 O2。

# Compiler name used for the "target.<name>" op attributes and for AnnotateTarget below.
target = "test_if_else_merge"

@tvm.ir.register_op_attr("sigmoid", f"target.{target}")
def sigmoid(expr):  # pylint: disable=unused-variable
    """Registered as the ``target.<target>`` attribute of ``sigmoid``; always returns True."""
    return True

@tvm.ir.register_op_attr("erf", f"target.{target}")
def erf(expr):  # pylint: disable=unused-variable
    """Registered as the ``target.<target>`` attribute of ``erf``; always returns True."""
    return True

@tvm.ir.register_op_attr("add", f"target.{target}")
def add(expr):  # pylint: disable=unused-variable
    """Registered as the ``target.<target>`` attribute of ``add``; always returns True."""
    return True
def get_mod():
    """Build the un-annotated if/else example module.

    ``data + data`` feeds a ``subtract``, the ``If`` condition, the true
    branch and the final ``add``. The condition compares the sums of the
    ``add`` and ``subtract`` results; each branch applies ``sigmoid``, and
    ``erf`` is applied to the ``If`` result.

    Returns:
        ``tvm.IRModule`` created from the function over the ``(1, 32)``
        input ``data``.
    """
    data = relay.var("data", shape=(1, 32))
    doubled = relay.add(data, data)
    diff = relay.subtract(doubled, data)

    cond = relay.equal(relay.sum(doubled), relay.sum(diff))
    selected = relay.If(cond, relay.sigmoid(doubled), relay.sigmoid(diff))

    out = relay.add(doubled, relay.erf(selected))
    return tvm.IRModule.from_expr(relay.Function([data], out))
# Run annotate -> merge -> partition once for each value of annotate_non_call_ops.
for annotate_non_call_ops in [True, False]:
    result = transform.AnnotateTarget(target, annotate_non_call_ops)(get_mod())
    merge = transform.MergeCompilerRegions()(result)
    # Make sure partitioning completes without a segfault.
    partition = transform.PartitionGraph()(merge)
    
# `partition` holds the module from the last iteration (annotate_non_call_ops=False).
partition.show()
def @main(%data: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */) -> Tensor[(1, 32), float32] {
  %0 = @tvmgen_default_test_if_else_merge_main_0(%data) /* ty=(Tensor[(1, 32), float32], Tensor[(1, 32), float32]) */;
  %1 = %0.0 /* ty=Tensor[(1, 32), float32] */;
  %2 = subtract(%1, %data) /* ty=Tensor[(1, 32), float32] */;
  %3 = sum(%1) /* ty=float32 */;
  %4 = sum(%2) /* ty=float32 */;
  %5 = equal(%3, %4) /* ty=bool */;
  %6 = if (%5) {
    %0.1 /* ty=Tensor[(1, 32), float32] */
  } else {
    @tvmgen_default_test_if_else_merge_main_3(%2) /* ty=Tensor[(1, 32), float32] */
  };
  @tvmgen_default_test_if_else_merge_main_2(%1, %6) /* ty=Tensor[(1, 32), float32] */
}

def @tvmgen_default_test_if_else_merge_main_0(%test_if_else_merge_0_i0: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, Compiler="test_if_else_merge", Primitive=1, Inline=1, global_symbol="tvmgen_default_test_if_else_merge_main_0") -> (Tensor[(1, 32), float32], Tensor[(1, 32), float32]) {
  %7 = add(%test_if_else_merge_0_i0, %test_if_else_merge_0_i0) /* ty=Tensor[(1, 32), float32] */;
  %8 = sigmoid(%7) /* ty=Tensor[(1, 32), float32] */;
  (%7, %8) /* ty=(Tensor[(1, 32), float32], Tensor[(1, 32), float32]) */
}

def @tvmgen_default_test_if_else_merge_main_2(%test_if_else_merge_2_i0: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, %test_if_else_merge_2_i1: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, Compiler="test_if_else_merge", Primitive=1, Inline=1, global_symbol="tvmgen_default_test_if_else_merge_main_2") -> Tensor[(1, 32), float32] {
  %9 = erf(%test_if_else_merge_2_i1) /* ty=Tensor[(1, 32), float32] */;
  add(%test_if_else_merge_2_i0, %9) /* ty=Tensor[(1, 32), float32] */
}

def @tvmgen_default_test_if_else_merge_main_3(%test_if_else_merge_3_i0: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, Compiler="test_if_else_merge", Primitive=1, Inline=1, global_symbol="tvmgen_default_test_if_else_merge_main_3") -> Tensor[(1, 32), float32] {
  sigmoid(%test_if_else_merge_3_i0) /* ty=Tensor[(1, 32), float32] */
}