合并编译器区域#
菱形计算图#
定义如下数据依赖:
     O                 O
    / \               / \
   O   X    -->      O + + X
    \ /               \ /
     O                 O
其中 O 表示 target 支持的算子,X 表示 target 不支持的算子。
注意,不能简单地把这三个受支持的算子合并到同一个区域中:X 的输入来自顶部的 O,输出又流向底部的 O,若三个 O 位于同一区域,该区域与 X 所在的区域就会互相依赖(形成环)。
import tvm
from tvm import relay
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.testing import run_opt_pass
def diamond_graph_fanouts():
    """Return the hand-annotated diamond-shaped Relay function.

    ``abs`` (compiler "test") fans out through two separate ``compiler_end``
    nodes: one branch runs ``nn.relu`` on "test", the other runs ``tanh`` on
    "default". An ``add`` on "test" joins both branches and produces the
    function output. Every edge carries its own begin/end pair, i.e. the
    graph is in its state *before* any regions are merged.
    """
    data = relay.var("data", shape=(10, 10))

    # Top of the diamond: abs, closed twice so each consumer has its own edge.
    top = relay.abs(compiler_begin(data, "test"))
    to_left = compiler_end(top, "test")
    to_right = compiler_end(top, "test")

    # Left branch stays on "test"; right branch is the unsupported op ("default").
    left = compiler_end(relay.nn.relu(compiler_begin(to_left, "test")), "test")
    right = compiler_end(relay.tanh(compiler_begin(to_right, "default")), "default")

    # Bottom of the diamond joins both branches back on "test".
    bottom = relay.add(compiler_begin(left, "test"), compiler_begin(right, "test"))
    return relay.Function([data], compiler_end(bottom, "test"))
# Wrap the annotated diamond in a module and show it before merging.
mod = tvm.IRModule.from_expr(diamond_graph_fanouts())
print(f"合并之前:{mod}")

# After MergeCompilerRegions the printout shows abs and relu sharing one
# "test" region, while tanh ("default") and the final add keep their own.
mod = relay.transform.MergeCompilerRegions()(mod)
print(f"合并之后:{mod}")
合并之前:def @main(%data: Tensor[(10, 10), float32]) {
%0 = annotation.compiler_begin(%data, compiler="test");
%1 = abs(%0);
%2 = annotation.compiler_end(%1, compiler="test");
%3 = annotation.compiler_begin(%2, compiler="test");
%4 = nn.relu(%3);
%5 = annotation.compiler_end(%4, compiler="test");
%6 = annotation.compiler_end(%1, compiler="test");
%7 = annotation.compiler_begin(%6, compiler="default");
%8 = tanh(%7);
%9 = annotation.compiler_end(%8, compiler="default");
%10 = annotation.compiler_begin(%5, compiler="test");
%11 = annotation.compiler_begin(%9, compiler="test");
%12 = add(%10, %11);
annotation.compiler_end(%12, compiler="test")
}
合并之后:def @main(%data: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
%0 = annotation.compiler_begin(%data, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%1 = abs(%0) /* ty=Tensor[(10, 10), float32] */;
%2 = nn.relu(%1) /* ty=Tensor[(10, 10), float32] */;
%3 = annotation.compiler_end(%2, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%4 = annotation.compiler_end(%1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%5 = annotation.compiler_begin(%4, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%6 = tanh(%5) /* ty=Tensor[(10, 10), float32] */;
%7 = annotation.compiler_end(%6, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%8 = annotation.compiler_begin(%3, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%9 = annotation.compiler_begin(%7, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%10 = add(%8, %9) /* ty=Tensor[(10, 10), float32] */;
annotation.compiler_end(%10, compiler="test") /* ty=Tensor[(10, 10), float32] */
}
if-else 测试#
          O1 - - - |               O1 --|
          |        |               |
          X        |               X
          |        |               |
 If cond ? O1: X   |  -->   +  +  If cond ? O1: X  +
          |        |               |
          O2 <- - -|               O2 <-|
# Name of the external compiler used by the if-else example below.
target = "test_if_else_merge"

# Register the "target.<name>" op attribute for sigmoid, erf and add. Each
# checker returns True unconditionally, so every call of these ops is reported
# as supported by ``target`` when AnnotateTarget(target, ...) runs below.


@tvm.ir.register_op_attr("sigmoid", "target." + target)
def sigmoid(expr):  # pylint: disable=unused-variable
    """Accept every ``sigmoid`` call for ``target``; ``expr`` is ignored."""
    return True


@tvm.ir.register_op_attr("erf", "target." + target)
def erf(expr):  # pylint: disable=unused-variable
    """Accept every ``erf`` call for ``target``; ``expr`` is ignored."""
    return True


@tvm.ir.register_op_attr("add", "target." + target)
def add(expr):  # pylint: disable=unused-variable
    """Accept every ``add`` call for ``target``; ``expr`` is ignored."""
    return True
def get_mod():
    """Return a fresh IRModule whose ``main`` contains an ``If`` expression.

    Dataflow: ``add0 = data + data`` and ``sub0 = add0 - data``; the ``If``
    picks ``sigmoid(add0)`` when ``sum(add0) == sum(sub0)`` and
    ``sigmoid(sub0)`` otherwise; ``erf`` is applied to the selected value and
    the result is added back to ``add0``. ``add0`` is a single shared node
    used by the subtract, the condition, the true branch and the final add.
    """
    data = relay.var("data", shape=(1, 32))
    doubled = relay.add(data, data)
    diff = relay.subtract(doubled, data)

    cond = relay.equal(relay.sum(doubled), relay.sum(diff))
    chosen = relay.If(cond, relay.sigmoid(doubled), relay.sigmoid(diff))

    body = relay.add(doubled, relay.erf(chosen))
    return tvm.IRModule.from_expr(relay.Function([data], body))
# Run annotate -> merge -> partition with both values of AnnotateTarget's
# second argument. The partitioned modules are not inspected: the only check
# is that PartitionGraph runs to completion (originally written to guard
# against a segmentation fault on this graph).
for annotate_non_call_ops in (True, False):
    result = relay.transform.AnnotateTarget(target, annotate_non_call_ops)(
        get_mod()
    )
    merge = relay.transform.MergeCompilerRegions()(result)
    partition = relay.transform.PartitionGraph()(merge)
合并计算图示例#
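参考:TVM 社区关于改进 Relay 图划分算法的 RFC(讨论帖 #5830,https://discuss.tvm.apache.org/t/relay-improved-graph-partitioning-algorithm/5830),并非 IETF 的 RFC 5830。示例中所有 add 算子都标注为 test 目标,所有 subtract 算子都标注为 default 目标。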
def annotated():
    """Return an IRModule holding the hand-annotated example graph.

    Every ``add`` is wrapped for the "test" compiler and every ``subtract``
    for "default", with a separate ``compiler_begin``/``compiler_end`` pair
    on each edge, i.e. the state before any regions are merged. ``main``
    takes ten (10, 10) float32 inputs named ``in_1`` ... ``in_10``.
    """
    params = [
        relay.var(f"in_{i}", shape=(10, 10), dtype="float32") for i in range(1, 11)
    ]
    in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10 = params

    def begin(expr, compiler="test"):
        # A fresh compiler_begin node for one edge entering a region.
        return compiler_begin(expr, compiler)

    def end(expr, compiler="test"):
        # A fresh compiler_end node for one edge leaving a region.
        return compiler_end(expr, compiler)

    # "test": (in_1 + in_2) + (in_3 + in_4)
    node0 = relay.add(begin(in_1), begin(in_2))
    node1 = relay.add(begin(in_3), begin(in_4))
    node2 = relay.add(begin(end(node0)), begin(end(node1)))

    # "default": in_7 - (in_5 - in_6)
    node3 = relay.subtract(begin(in_5, "default"), begin(in_6, "default"))
    node4 = relay.subtract(
        begin(in_7, "default"), begin(end(node3, "default"), "default")
    )

    # "test": join both chains. node5 fans out to a subtract ("default") and
    # an add ("test"), each through its own compiler_end node.
    node5 = relay.add(begin(end(node2)), begin(end(node4, "default")))
    node6 = relay.subtract(begin(in_8, "default"), begin(end(node5), "default"))
    node7 = relay.add(begin(in_9), begin(end(node5)))

    # "test": recombine the two fan-out paths, then add in_10.
    node8 = relay.add(begin(end(node6, "default")), begin(end(node7)))
    node9 = relay.add(begin(in_10), begin(end(node8)))

    func = relay.Function(params, end(node9))
    return tvm.IRModule.from_expr(func)
# Merge the example graph's compiler regions, re-run type inference, and show
# the resulting module.
mod = relay.transform.InferType()(relay.transform.MergeCompilerRegions()(annotated()))
print(mod)
def @main(%in_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_4: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_5: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_6: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_7: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_8: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_9: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_10: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
%0 = annotation.compiler_begin(%in_1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%1 = annotation.compiler_begin(%in_2, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%2 = annotation.compiler_begin(%in_3, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%3 = annotation.compiler_begin(%in_4, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%4 = add(%0, %1) /* ty=Tensor[(10, 10), float32] */;
%5 = add(%2, %3) /* ty=Tensor[(10, 10), float32] */;
%6 = annotation.compiler_begin(%in_5, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%7 = annotation.compiler_begin(%in_6, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%8 = annotation.compiler_begin(%in_7, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%9 = subtract(%6, %7) /* ty=Tensor[(10, 10), float32] */;
%10 = subtract(%8, %9) /* ty=Tensor[(10, 10), float32] */;
%11 = annotation.compiler_end(%10, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%12 = add(%4, %5) /* ty=Tensor[(10, 10), float32] */;
%13 = annotation.compiler_begin(%11, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%14 = add(%12, %13) /* ty=Tensor[(10, 10), float32] */;
%15 = annotation.compiler_end(%14, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%16 = annotation.compiler_begin(%in_8, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%17 = annotation.compiler_begin(%15, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%18 = subtract(%16, %17) /* ty=Tensor[(10, 10), float32] */;
%19 = annotation.compiler_end(%18, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%20 = annotation.compiler_begin(%in_9, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%21 = add(%20, %14) /* ty=Tensor[(10, 10), float32] */;
%22 = annotation.compiler_end(%21, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%23 = annotation.compiler_begin(%19, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%24 = annotation.compiler_begin(%22, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%25 = annotation.compiler_begin(%in_10, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%26 = add(%23, %24) /* ty=Tensor[(10, 10), float32] */;
%27 = add(%25, %26) /* ty=Tensor[(10, 10), float32] */;
annotation.compiler_end(%27, compiler="test") /* ty=Tensor[(10, 10), float32] */
}