合并编译器区域#

菱形计算图#

定义如下数据依赖:

      O           O
     / \         /              \
    O   X  -->  O    +       +   X
     \ /                \ /
      O                  O

其中 O 表示 target 支持的算子,X 表示 target 不支持的算子。

注意,不能简单地把三个受支持的算子合并到同一个区域中,否则合并后的子图与包含 X 的子图会相互依赖(形成循环依赖)。

import tvm
from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.testing import run_opt_pass
def diamond_graph_fanouts():
    """Build the pre-annotated, diamond-shaped Relay function.

    ``abs`` (annotated for the ``"test"`` compiler) fans out into two
    branches: ``nn.relu`` annotated for ``"test"`` and ``tanh`` annotated as
    ``"default"``. Both branches are joined by an ``add`` annotated for
    ``"test"``.

    Returns:
        A ``relay.Function`` over the single input ``data`` of shape (10, 10).
    """
    data = relay.var("data", shape=(10, 10))

    # Top of the diamond. Its result leaves the region through two separate
    # compiler_end nodes, one per consumer branch.
    top = relay.abs(compiler_begin(data, "test"))
    top_to_relu = compiler_end(top, "test")
    top_to_tanh = compiler_end(top, "test")

    # Left branch: relu, annotated for the "test" compiler.
    relu_branch = compiler_end(
        relay.nn.relu(compiler_begin(top_to_relu, "test")), "test"
    )

    # Right branch: tanh, annotated as "default".
    tanh_branch = compiler_end(
        relay.tanh(compiler_begin(top_to_tanh, "default")), "default"
    )

    # Bottom of the diamond: add joins both branches inside a "test" region.
    joined = relay.add(
        compiler_begin(relu_branch, "test"),
        compiler_begin(tanh_branch, "test"),
    )
    return relay.Function([data], compiler_end(joined, "test"))
# Dump the annotated diamond module, apply MergeCompilerRegions, and dump it
# again so the two IR listings can be compared.
_diamond_func = diamond_graph_fanouts()
mod = tvm.IRModule.from_expr(_diamond_func)
print(f"合并之前:{mod}")

_merge_pass = relay.transform.MergeCompilerRegions()
mod = _merge_pass(mod)
print(f"合并之后:{mod}")
合并之前:def @main(%data: Tensor[(10, 10), float32]) {
  %0 = annotation.compiler_begin(%data, compiler="test");
  %1 = abs(%0);
  %2 = annotation.compiler_end(%1, compiler="test");
  %3 = annotation.compiler_begin(%2, compiler="test");
  %4 = nn.relu(%3);
  %5 = annotation.compiler_end(%4, compiler="test");
  %6 = annotation.compiler_end(%1, compiler="test");
  %7 = annotation.compiler_begin(%6, compiler="default");
  %8 = tanh(%7);
  %9 = annotation.compiler_end(%8, compiler="default");
  %10 = annotation.compiler_begin(%5, compiler="test");
  %11 = annotation.compiler_begin(%9, compiler="test");
  %12 = add(%10, %11);
  annotation.compiler_end(%12, compiler="test")
}

合并之后:def @main(%data: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %0 = annotation.compiler_begin(%data, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %1 = abs(%0) /* ty=Tensor[(10, 10), float32] */;
  %2 = nn.relu(%1) /* ty=Tensor[(10, 10), float32] */;
  %3 = annotation.compiler_end(%2, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %4 = annotation.compiler_end(%1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %5 = annotation.compiler_begin(%4, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %6 = tanh(%5) /* ty=Tensor[(10, 10), float32] */;
  %7 = annotation.compiler_end(%6, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %8 = annotation.compiler_begin(%3, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %9 = annotation.compiler_begin(%7, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %10 = add(%8, %9) /* ty=Tensor[(10, 10), float32] */;
  annotation.compiler_end(%10, compiler="test") /* ty=Tensor[(10, 10), float32] */
}

if-else 测试#

           O1 - - - |      O1 --|
            |       |               |
            X       |               X
            |       |                              |
    If cond ? O1: X | -->       +       +  If cond ? O1: X  +
            |       |                                           |
           O2 <- - -|                                          O2 <-|
# Compiler name used by the if-else example: it is the suffix of the
# "target.<name>" op attribute registered below and the first argument
# passed to AnnotateTarget.
target = "test_if_else_merge"

@tvm.ir.register_op_attr("sigmoid", "target." + target)
def sigmoid(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``"target." + target`` attribute of ``sigmoid``.

    Returns True for every ``expr``; in TVM's BYOC flow this marks the op as
    supported by the target.
    """
    return True

@tvm.ir.register_op_attr("erf", "target." + target)
def erf(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``"target." + target`` attribute of ``erf``.

    Returns True for every ``expr``; in TVM's BYOC flow this marks the op as
    supported by the target.
    """
    return True

@tvm.ir.register_op_attr("add", "target." + target)
def add(expr):  # pylint: disable=unused-variable
    """Callback registered as the ``"target." + target`` attribute of ``add``.

    Returns True for every ``expr``; in TVM's BYOC flow this marks the op as
    supported by the target.
    """
    return True
def get_mod():
    """Build the un-annotated module used by the if-else example.

    The graph computes ``add(data, data)`` and ``subtract`` of that result
    with ``data``, compares their sums, applies ``sigmoid`` to one of the two
    intermediates depending on the comparison, then feeds the ``If`` result
    through ``erf`` and adds it back to the first intermediate.

    Returns:
        A ``tvm.IRModule`` whose function takes ``data`` of shape (1, 32).
    """
    data = relay.var("data", shape=(1, 32))
    doubled = relay.add(data, data)
    reduced = relay.subtract(doubled, data)

    # Condition of the If: equality of the two intermediates' sums.
    cond = relay.equal(relay.sum(doubled), relay.sum(reduced))

    # Each branch applies sigmoid to a different intermediate.
    branch = relay.If(cond, relay.sigmoid(doubled), relay.sigmoid(reduced))

    out = relay.add(doubled, relay.erf(branch))
    return tvm.IRModule.from_expr(relay.Function([data], out))
# Only AnnotateTarget depends on the loop variable, so the other two passes
# are constructed once and applied in each iteration.
merge_regions = relay.transform.MergeCompilerRegions()
partition_graph = relay.transform.PartitionGraph()
for annotate_non_call_ops in (True, False):
    annotate = relay.transform.AnnotateTarget(target, annotate_non_call_ops)
    result = annotate(get_mod())
    merge = merge_regions(result)
    # Ensure partitioning finishes without a segmentation fault.
    partition = partition_graph(merge)

合并计算图示例#

参考:RFC 5830

def annotated():
    """Build the pre-annotated example module with ten (10, 10) float32 inputs.

    ``add`` operators are wrapped in ``"test"`` compiler regions and
    ``subtract`` operators in ``"default"`` regions. Every operator input is
    wrapped in its own ``compiler_begin`` and every region exit in its own
    ``compiler_end``.

    Returns:
        A ``tvm.IRModule`` built from the annotated function.
    """

    def test_in(expr):
        return compiler_begin(expr, "test")

    def test_out(expr):
        return compiler_end(expr, "test")

    def default_in(expr):
        return compiler_begin(expr, "default")

    def default_out(expr):
        return compiler_end(expr, "default")

    names = (
        "in_1", "in_2", "in_3", "in_4", "in_5",
        "in_6", "in_7", "in_8", "in_9", "in_10",
    )
    params = [relay.var(name, shape=(10, 10), dtype="float32") for name in names]
    in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10 = params

    # "test" tree: (in_1 + in_2) + (in_3 + in_4), each add in its own region.
    sum_12 = test_out(relay.add(test_in(in_1), test_in(in_2)))
    sum_34 = test_out(relay.add(test_in(in_3), test_in(in_4)))
    sum_1234 = test_out(relay.add(test_in(sum_12), test_in(sum_34)))

    # "default" chain: in_7 - (in_5 - in_6).
    diff_56 = default_out(relay.subtract(default_in(in_5), default_in(in_6)))
    diff_7 = default_out(relay.subtract(default_in(in_7), default_in(diff_56)))

    # Join of both sides; the result exits through two distinct compiler_end
    # nodes because it feeds one "default" and one "test" consumer.
    joined = relay.add(test_in(sum_1234), test_in(diff_7))
    joined_for_default = test_out(joined)
    joined_for_test = test_out(joined)

    diff_8 = default_out(
        relay.subtract(default_in(in_8), default_in(joined_for_default))
    )
    sum_9 = test_out(relay.add(test_in(in_9), test_in(joined_for_test)))

    # Final two "test" adds combine both consumers and the last input.
    tail = test_out(relay.add(test_in(diff_8), test_in(sum_9)))
    result = test_out(relay.add(test_in(in_10), test_in(tail)))

    return tvm.IRModule.from_expr(relay.Function(params, result))
# Run MergeCompilerRegions followed by InferType on the annotated example
# module, then print the resulting IR.
mod = annotated()
for _pass in (relay.transform.MergeCompilerRegions(), relay.transform.InferType()):
    mod = _pass(mod)
print(mod)
def @main(%in_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_4: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_5: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_6: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_7: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_8: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_9: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_10: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
  %0 = annotation.compiler_begin(%in_1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %1 = annotation.compiler_begin(%in_2, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %2 = annotation.compiler_begin(%in_3, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %3 = annotation.compiler_begin(%in_4, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %4 = add(%0, %1) /* ty=Tensor[(10, 10), float32] */;
  %5 = add(%2, %3) /* ty=Tensor[(10, 10), float32] */;
  %6 = annotation.compiler_begin(%in_5, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %7 = annotation.compiler_begin(%in_6, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %8 = annotation.compiler_begin(%in_7, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %9 = subtract(%6, %7) /* ty=Tensor[(10, 10), float32] */;
  %10 = subtract(%8, %9) /* ty=Tensor[(10, 10), float32] */;
  %11 = annotation.compiler_end(%10, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %12 = add(%4, %5) /* ty=Tensor[(10, 10), float32] */;
  %13 = annotation.compiler_begin(%11, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %14 = add(%12, %13) /* ty=Tensor[(10, 10), float32] */;
  %15 = annotation.compiler_end(%14, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %16 = annotation.compiler_begin(%in_8, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %17 = annotation.compiler_begin(%15, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %18 = subtract(%16, %17) /* ty=Tensor[(10, 10), float32] */;
  %19 = annotation.compiler_end(%18, compiler="default") /* ty=Tensor[(10, 10), float32] */;
  %20 = annotation.compiler_begin(%in_9, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %21 = add(%20, %14) /* ty=Tensor[(10, 10), float32] */;
  %22 = annotation.compiler_end(%21, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %23 = annotation.compiler_begin(%19, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %24 = annotation.compiler_begin(%22, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %25 = annotation.compiler_begin(%in_10, compiler="test") /* ty=Tensor[(10, 10), float32] */;
  %26 = add(%23, %24) /* ty=Tensor[(10, 10), float32] */;
  %27 = add(%25, %26) /* ty=Tensor[(10, 10), float32] */;
  annotation.compiler_end(%27, compiler="test") /* ty=Tensor[(10, 10), float32] */
}