# 合并编译器区域

参考：`tvm/tests/python/relay/test_pass_merge_compiler_regions.py`

{func}`~tvm.relay.op.annotation.compiler_begin` 和 {func}`~tvm.relay.op.annotation.compiler_end` 函数的主要作用是在 Relay 表达式中标记区域的开始和结束，该区域将由指定的编译器处理。这两个函数通常用于优化和调度，允许不同的编译器处理不同的表达式区域。

In [1]:
import tvm
from tvm import relay
import tvm.relay.transform as transform
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.testing import run_opt_pass

## 处理菱形区域

测试菱形计算图中存在的数据依赖关系是否被合并过程正确解决。

````
    O = 目标支持
    X = 目标不支持

       O         O
      / \      /               \
     O   X --> O    +       +    X
     \ /             \ /
       O                O
````

注意不能简单地将三个支持的算子合并在一起，否则两个子图都会依赖于另一个。

In [2]:
def diamond_graph_fanouts():
    """Build an annotated diamond-shaped Relay function.

    The graph is ``abs -> (relu | tanh) -> add`` over a single (10, 10)
    input. ``abs``, ``relu`` and ``add`` are each wrapped in their own
    "test" compiler region, while ``tanh`` is wrapped in a "default"
    region. The output of ``abs`` is closed by two separate
    ``compiler_end`` annotations, one per branch.

    Returns
    -------
    relay.Function
        The annotated function taking ``data`` as its only parameter.
    """
    data = relay.var("data", shape=(10, 10))

    # Top of the diamond: one op whose result fans out to both branches,
    # each fan-out getting its own region end marker.
    top = relay.abs(compiler_begin(data, "test"))
    top_for_left = compiler_end(top, "test")
    top_for_right = compiler_end(top, "test")

    # Begin markers for both branches are created before either branch op.
    left_in = compiler_begin(top_for_left, "test")
    right_in = compiler_begin(top_for_right, "default")

    # Left branch runs on "test", right branch on "default".
    left_out = compiler_end(relay.nn.relu(left_in), "test")
    right_out = compiler_end(relay.tanh(right_in), "default")

    # Bottom of the diamond: join both branches inside a "test" region.
    bottom = relay.add(
        compiler_begin(left_out, "test"),
        compiler_begin(right_out, "test"),
    )
    return relay.Function([data], compiler_end(bottom, "test"))

In [3]:
# Build the annotated diamond graph, wrap it in an IRModule and display it
# before any merging has been applied.
expr = diamond_graph_fanouts()
tvm.IRModule.from_expr(expr).show()

In [4]:
# Run MergeCompilerRegions over the annotated expression and display the
# result for comparison with the unmerged graph.
opt_expr = run_opt_pass(expr, relay.transform.MergeCompilerRegions())
tvm.IRModule.from_expr(opt_expr).show()

## 计算图示例

参见 [RFC 5830](https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830)。蓝色节点是加法（目标：`test`），红色节点是减法（目标：`default`）。

原始计算图：

In [5]:
def annotated():
    """Build the annotated example graph as an IRModule.

    Every ``add`` is placed in its own "test" compiler region and every
    ``subtract`` in its own "default" region: each operand is wrapped in
    ``compiler_begin`` and each result in ``compiler_end``.

    Returns
    -------
    tvm.IRModule
        Module created with ``tvm.IRModule.from_expr`` from a function of
        the ten (10, 10) float32 inputs ``in_1`` ... ``in_10``.
    """

    def region_add(lhs, rhs):
        # Body of a "test" region: each operand gets its own begin marker.
        return relay.add(compiler_begin(lhs, "test"), compiler_begin(rhs, "test"))

    def region_sub(lhs, rhs):
        # Body of a "default" region: each operand gets its own begin marker.
        return relay.subtract(
            compiler_begin(lhs, "default"), compiler_begin(rhs, "default")
        )

    inputs = [
        relay.var(f"in_{idx}", shape=(10, 10), dtype="float32")
        for idx in range(1, 11)
    ]
    in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10 = inputs

    # Three chained additions on the "test" compiler.
    sum_12 = compiler_end(region_add(in_1, in_2), "test")
    sum_34 = compiler_end(region_add(in_3, in_4), "test")
    sum_1234 = compiler_end(region_add(sum_12, sum_34), "test")

    # Two chained subtractions on the "default" compiler.
    diff_56 = compiler_end(region_sub(in_5, in_6), "default")
    diff_7 = compiler_end(region_sub(in_7, diff_56), "default")

    # Join both chains; the joined value is closed twice, once per consumer.
    joined = region_add(sum_1234, diff_7)
    joined_for_sub = compiler_end(joined, "test")
    joined_for_add = compiler_end(joined, "test")

    # Fan out into one "default" subtraction and one "test" addition.
    sub_branch = region_sub(in_8, joined_for_sub)
    add_branch = compiler_end(region_add(in_9, joined_for_add), "test")
    sub_branch = compiler_end(sub_branch, "default")

    # Rejoin the two branches and finish with a last addition.
    rejoined = compiler_end(region_add(sub_branch, add_branch), "test")
    output = compiler_end(region_add(in_10, rejoined), "test")

    return tvm.IRModule.from_expr(relay.Function(inputs, output))

In [6]:
# Build the annotated example module and display it before merging.
mod = annotated()
mod.show()

优化之后的计算图：

In [7]:
# Apply MergeCompilerRegions, then InferType, to the module and display the
# merged result.
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.InferType()(mod)
mod.show()

## 测试条件语句

该测试验证受限区域能够在 if_else 控制流中成功传播。

```
    O = 目标支持
    X = 目标不支持

           O1 - - - |      O1 --|
            |       |               |
            X       |               X
            |       |                              |
    If cond ? O1: X | -->       +       +  If cond ? O1: X  +
            |       |                                           |
           O2 <- - -|                                          O2 <-|

```
    
避免 O1 合并到 O2。

In [8]:
# Compiler name used below both as the "target.<name>" op-attribute key and
# as the argument to AnnotateTarget.
target = "test_if_else_merge"

@tvm.ir.register_op_attr("sigmoid", f"target.{target}")
def sigmoid(expr):  # pylint: disable=unused-variable
    """Op-attribute callback registered on ``sigmoid`` under ``target.<target>``.

    ``expr`` is ignored; the callback unconditionally returns True.
    """
    return True

@tvm.ir.register_op_attr("erf", f"target.{target}")
def erf(expr):  # pylint: disable=unused-variable
    """Op-attribute callback registered on ``erf`` under ``target.<target>``.

    ``expr`` is ignored; the callback unconditionally returns True.
    """
    return True

@tvm.ir.register_op_attr("add", f"target.{target}")
def add(expr):  # pylint: disable=unused-variable
    """Op-attribute callback registered on ``add`` under ``target.<target>``.

    ``expr`` is ignored; the callback unconditionally returns True.
    """
    return True

In [9]:
def get_mod():
    """Build an un-annotated module containing an ``If`` expression.

    The function computes ``add0 = data + data`` and ``sub0 = add0 - data``,
    selects ``sigmoid(add0)`` or ``sigmoid(sub0)`` depending on whether
    ``sum(add0) == sum(sub0)``, applies ``erf`` to the selected value and
    finally adds ``add0`` to it.

    Returns
    -------
    tvm.IRModule
        Module created with ``tvm.IRModule.from_expr`` from that function.
    """
    data = relay.var("data", shape=(1, 32))
    doubled = relay.add(data, data)
    reduced = relay.subtract(doubled, data)

    # The condition and both branches all read the two values above.
    cond = relay.equal(relay.sum(doubled), relay.sum(reduced))
    selected = relay.If(cond, relay.sigmoid(doubled), relay.sigmoid(reduced))

    # `doubled` is reused after the If, alongside the erf of the If result.
    out = relay.add(doubled, relay.erf(selected))
    return tvm.IRModule.from_expr(relay.Function([data], out))

In [10]:
# Run annotate -> merge -> partition on a fresh module for both settings of
# annotate_non_call_ops. Only the result of the last iteration (False) is
# left in `partition` afterwards.
for annotate_non_call_ops in [True, False]:
    result = transform.AnnotateTarget(target, annotate_non_call_ops)(get_mod())
    merge = transform.MergeCompilerRegions()(result)
    # Make sure the partitioning completes without a segfault.
    partition = transform.PartitionGraph()(merge)
    

In [12]:
# Display the partitioned module produced by the last loop iteration.
partition.show()