合并编译器区域#
参考:tvm/tests/python/relay/test_pass_merge_compiler_regions.py
`compiler_begin()` 和 `compiler_end()` 函数的主要作用是在 Relay 表达式中标记区域的开始和结束,该区域将由指定的编译器处理。这两个函数通常用于优化和调度,允许不同的编译器处理不同的表达式区域。
import tvm
from tvm import relay
import tvm.relay.transform as transform
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.testing import run_opt_pass
处理菱形区域#
本节检验菱形计算图中存在的数据依赖关系是否被合并过程正确处理。
O = 目标支持
X = 目标不支持
    O              O
   / \            / \
  O   X   -->    O + + X
   \ /            \ /
    O              O
注意不能简单地将三个支持的算子合并在一起,否则两个子图都会依赖于另一个。
def diamond_graph_fanouts():
    """Build a diamond-shaped Relay function with compiler annotations.

    An ``abs`` op (annotated for "test") fans out into a ``relu`` branch
    (annotated for "test") and a ``tanh`` branch (annotated for
    "default"); an ``add`` (annotated for "test") joins both branches.

    Returns
    -------
    relay.Function
        Function of a single (10, 10) input named "data" whose body is the
        final ``compiler_end`` annotation.
    """
    test, default = "test", "default"
    data = relay.var("data", shape=(10, 10))

    # Top of the diamond: one op, closed by two separate end annotations
    # (one per consumer).
    top = relay.abs(compiler_begin(data, test))
    top_for_left, top_for_right = compiler_end(top, test), compiler_end(top, test)

    # Open both branches before building their ops.
    left_in = compiler_begin(top_for_left, test)
    right_in = compiler_begin(top_for_right, default)
    left_out = compiler_end(relay.nn.relu(left_in), test)
    right_out = compiler_end(relay.tanh(right_in), default)

    # Bottom of the diamond: join the two branches.
    joined = relay.add(compiler_begin(left_out, test), compiler_begin(right_out, test))
    return relay.Function([data], compiler_end(joined, test))
# Build the annotated diamond function and print it wrapped in an IRModule.
expr = diamond_graph_fanouts()
tvm.IRModule.from_expr(expr).show()
Show code cell output
def @main(%data: Tensor[(10, 10), float32]) {
%0 = annotation.compiler_begin(%data, compiler="test");
%1 = abs(%0);
%2 = annotation.compiler_end(%1, compiler="test");
%3 = annotation.compiler_begin(%2, compiler="test");
%4 = nn.relu(%3);
%5 = annotation.compiler_end(%4, compiler="test");
%6 = annotation.compiler_end(%1, compiler="test");
%7 = annotation.compiler_begin(%6, compiler="default");
%8 = tanh(%7);
%9 = annotation.compiler_end(%8, compiler="default");
%10 = annotation.compiler_begin(%5, compiler="test");
%11 = annotation.compiler_begin(%9, compiler="test");
%12 = add(%10, %11);
annotation.compiler_end(%12, compiler="test")
}
# Apply MergeCompilerRegions to the annotated function via run_opt_pass
# and print the resulting expression wrapped in an IRModule.
opt_expr = run_opt_pass(expr, relay.transform.MergeCompilerRegions())
tvm.IRModule.from_expr(opt_expr).show()
Show code cell output
def @main(%data: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
%0 = annotation.compiler_begin(%data, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%1 = abs(%0) /* ty=Tensor[(10, 10), float32] */;
%2 = nn.relu(%1) /* ty=Tensor[(10, 10), float32] */;
%3 = annotation.compiler_end(%2, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%4 = annotation.compiler_end(%1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%5 = annotation.compiler_begin(%4, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%6 = tanh(%5) /* ty=Tensor[(10, 10), float32] */;
%7 = annotation.compiler_end(%6, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%8 = annotation.compiler_begin(%3, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%9 = annotation.compiler_begin(%7, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%10 = add(%8, %9) /* ty=Tensor[(10, 10), float32] */;
annotation.compiler_end(%10, compiler="test") /* ty=Tensor[(10, 10), float32] */
}
计算图示例#
参见 RFC 5830。蓝色节点是加法(目标:`test`),红色节点是减法(目标:`default`)。
原始计算图:
Show code cell content
def annotated():
    """Build the annotated example graph as an IRModule.

    The graph takes ten (10, 10) float32 inputs named ``in_1`` .. ``in_10``.
    Every ``add`` is wrapped in "test" compiler annotations and every
    ``subtract`` in "default" compiler annotations.

    Returns
    -------
    tvm.IRModule
        Module created from the function over all ten inputs.
    """
    test, default = "test", "default"
    inputs = [
        relay.var(f"in_{idx}", shape=(10, 10), dtype="float32") for idx in range(1, 11)
    ]
    in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10 = inputs

    # Two independent adds ("test") feeding a third add ("test").
    lhs_a, rhs_a = compiler_begin(in_1, test), compiler_begin(in_2, test)
    lhs_b, rhs_b = compiler_begin(in_3, test), compiler_begin(in_4, test)
    add_a = relay.add(lhs_a, rhs_a)
    add_b = relay.add(lhs_b, rhs_b)
    add_a_out = compiler_end(add_a, test)
    add_b_out = compiler_end(add_b, test)
    add_ab = relay.add(compiler_begin(add_a_out, test), compiler_begin(add_b_out, test))
    add_ab_out = compiler_end(add_ab, test)

    # Chain of two subtracts ("default").
    sub_a = relay.subtract(compiler_begin(in_5, default), compiler_begin(in_6, default))
    sub_b_lhs = compiler_begin(in_7, default)
    sub_b_rhs = compiler_begin(compiler_end(sub_a, default), default)
    sub_b = relay.subtract(sub_b_lhs, sub_b_rhs)
    sub_b_out = compiler_end(sub_b, default)

    # Add joining both chains; its output is closed twice, once per consumer.
    join = relay.add(compiler_begin(add_ab_out, test), compiler_begin(sub_b_out, test))
    join_for_sub = compiler_end(join, test)
    join_for_add = compiler_end(join, test)

    # The fan-out feeds one subtract ("default") and one add ("test").
    sub_c = relay.subtract(
        compiler_begin(in_8, default), compiler_begin(join_for_sub, default)
    )
    add_c = relay.add(compiler_begin(in_9, test), compiler_begin(join_for_add, test))
    add_c_out = compiler_end(add_c, test)
    sub_c_out = compiler_end(sub_c, default)

    # Re-join the two consumers, then add the last input.
    add_d = relay.add(compiler_begin(sub_c_out, test), compiler_begin(add_c_out, test))
    add_d_out = compiler_end(add_d, test)
    final_add = relay.add(compiler_begin(in_10, test), compiler_begin(add_d_out, test))
    body = compiler_end(final_add, test)

    return tvm.IRModule.from_expr(relay.Function(inputs, body))
# Build the annotated example module and print it.
mod = annotated()
mod.show()
Show code cell output
def @main(%in_1: Tensor[(10, 10), float32], %in_2: Tensor[(10, 10), float32], %in_3: Tensor[(10, 10), float32], %in_4: Tensor[(10, 10), float32], %in_5: Tensor[(10, 10), float32], %in_6: Tensor[(10, 10), float32], %in_7: Tensor[(10, 10), float32], %in_8: Tensor[(10, 10), float32], %in_9: Tensor[(10, 10), float32], %in_10: Tensor[(10, 10), float32]) {
%0 = annotation.compiler_begin(%in_1, compiler="test");
%1 = annotation.compiler_begin(%in_2, compiler="test");
%2 = add(%0, %1);
%3 = annotation.compiler_end(%2, compiler="test");
%4 = annotation.compiler_begin(%in_3, compiler="test");
%5 = annotation.compiler_begin(%in_4, compiler="test");
%6 = add(%4, %5);
%7 = annotation.compiler_end(%6, compiler="test");
%8 = annotation.compiler_begin(%3, compiler="test");
%9 = annotation.compiler_begin(%7, compiler="test");
%10 = add(%8, %9);
%11 = annotation.compiler_end(%10, compiler="test");
%12 = annotation.compiler_begin(%in_5, compiler="default");
%13 = annotation.compiler_begin(%in_6, compiler="default");
%14 = subtract(%12, %13);
%15 = annotation.compiler_end(%14, compiler="default");
%16 = annotation.compiler_begin(%in_7, compiler="default");
%17 = annotation.compiler_begin(%15, compiler="default");
%18 = subtract(%16, %17);
%19 = annotation.compiler_end(%18, compiler="default");
%20 = annotation.compiler_begin(%11, compiler="test");
%21 = annotation.compiler_begin(%19, compiler="test");
%22 = add(%20, %21);
%23 = annotation.compiler_end(%22, compiler="test");
%24 = annotation.compiler_begin(%in_8, compiler="default");
%25 = annotation.compiler_begin(%23, compiler="default");
%26 = subtract(%24, %25);
%27 = annotation.compiler_end(%26, compiler="default");
%28 = annotation.compiler_end(%22, compiler="test");
%29 = annotation.compiler_begin(%in_9, compiler="test");
%30 = annotation.compiler_begin(%28, compiler="test");
%31 = add(%29, %30);
%32 = annotation.compiler_end(%31, compiler="test");
%33 = annotation.compiler_begin(%27, compiler="test");
%34 = annotation.compiler_begin(%32, compiler="test");
%35 = add(%33, %34);
%36 = annotation.compiler_end(%35, compiler="test");
%37 = annotation.compiler_begin(%in_10, compiler="test");
%38 = annotation.compiler_begin(%36, compiler="test");
%39 = add(%37, %38);
annotation.compiler_end(%39, compiler="test")
}
优化之后的计算图:
# Merge the annotated regions, then run type inference on the result
# before printing it.
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.InferType()(mod)
mod.show()
Show code cell output
def @main(%in_1: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_2: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_3: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_4: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_5: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_6: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_7: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_8: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_9: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */, %in_10: Tensor[(10, 10), float32] /* ty=Tensor[(10, 10), float32] */) -> Tensor[(10, 10), float32] {
%0 = annotation.compiler_begin(%in_1, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%1 = annotation.compiler_begin(%in_2, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%2 = annotation.compiler_begin(%in_3, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%3 = annotation.compiler_begin(%in_4, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%4 = add(%0, %1) /* ty=Tensor[(10, 10), float32] */;
%5 = add(%2, %3) /* ty=Tensor[(10, 10), float32] */;
%6 = annotation.compiler_begin(%in_5, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%7 = annotation.compiler_begin(%in_6, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%8 = annotation.compiler_begin(%in_7, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%9 = subtract(%6, %7) /* ty=Tensor[(10, 10), float32] */;
%10 = subtract(%8, %9) /* ty=Tensor[(10, 10), float32] */;
%11 = annotation.compiler_end(%10, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%12 = add(%4, %5) /* ty=Tensor[(10, 10), float32] */;
%13 = annotation.compiler_begin(%11, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%14 = add(%12, %13) /* ty=Tensor[(10, 10), float32] */;
%15 = annotation.compiler_end(%14, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%16 = annotation.compiler_begin(%in_8, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%17 = annotation.compiler_begin(%15, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%18 = subtract(%16, %17) /* ty=Tensor[(10, 10), float32] */;
%19 = annotation.compiler_end(%18, compiler="default") /* ty=Tensor[(10, 10), float32] */;
%20 = annotation.compiler_begin(%in_9, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%21 = add(%20, %14) /* ty=Tensor[(10, 10), float32] */;
%22 = annotation.compiler_end(%21, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%23 = annotation.compiler_begin(%19, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%24 = annotation.compiler_begin(%22, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%25 = annotation.compiler_begin(%in_10, compiler="test") /* ty=Tensor[(10, 10), float32] */;
%26 = add(%23, %24) /* ty=Tensor[(10, 10), float32] */;
%27 = add(%25, %26) /* ty=Tensor[(10, 10), float32] */;
annotation.compiler_end(%27, compiler="test") /* ty=Tensor[(10, 10), float32] */
}
测试条件语句#
本测试验证受限区域能够在 if_else 控制流中正确传播。
O = 目标支持
X = 目标不支持
O1 - - - | O1 --|
| | |
X | X
| | |
If cond ? O1: X | --> + + If cond ? O1: X +
| | |
O2 <- - -| O2 <-|
避免 O1 合并到 O2。
target = "test_if_else_merge"


@tvm.ir.register_op_attr("sigmoid", f"target.{target}")
def sigmoid(expr):  # pylint: disable=unused-variable
    """Predicate registered on ``sigmoid`` under the ``target.<target>`` key.

    Always returns True regardless of ``expr``.
    NOTE(review): presumably read by AnnotateTarget as "op is supported by
    this target" -- confirm against the TVM BYOC documentation.
    """
    return True


@tvm.ir.register_op_attr("erf", f"target.{target}")
def erf(expr):  # pylint: disable=unused-variable
    """Predicate registered on ``erf`` under the ``target.<target>`` key.

    Always returns True regardless of ``expr``.
    """
    return True


@tvm.ir.register_op_attr("add", f"target.{target}")
def add(expr):  # pylint: disable=unused-variable
    """Predicate registered on ``add`` under the ``target.<target>`` key.

    Always returns True regardless of ``expr``.
    """
    return True
def get_mod():
    """Build the if/else test module.

    The graph computes ``add(data, data)`` and ``subtract(add, data)``,
    compares their sums, and selects ``sigmoid`` of one or the other
    through a ``relay.If``. The selected value goes through ``erf`` and is
    added back to the first ``add``.

    Returns
    -------
    tvm.IRModule
        Module created from the function over the single (1, 32) input.
    """
    data = relay.var("data", shape=(1, 32))
    doubled = relay.add(data, data)
    diff = relay.subtract(doubled, data)

    # The condition compares the reductions of both intermediate tensors.
    cond = relay.equal(relay.sum(doubled), relay.sum(diff))
    selected = relay.If(cond, relay.sigmoid(doubled), relay.sigmoid(diff))

    # `doubled` is used both before and after the If expression.
    total = relay.add(doubled, relay.erf(selected))
    return tvm.IRModule.from_expr(relay.Function([data], total))
# Run the annotate -> merge -> partition pipeline once for each value of
# AnnotateTarget's second argument.
for annotate_non_call_ops in (True, False):
    # Each pass consumes the module produced by the previous one.
    result = transform.AnnotateTarget(target, annotate_non_call_ops)(get_mod())
    merge = transform.MergeCompilerRegions()(result)
    # Make sure partitioning completes without a segmentation fault.
    partition = transform.PartitionGraph()(merge)
# Print the partitioned module.
partition.show()
def @main(%data: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */) -> Tensor[(1, 32), float32] {
%0 = @tvmgen_default_test_if_else_merge_main_0(%data) /* ty=(Tensor[(1, 32), float32], Tensor[(1, 32), float32]) */;
%1 = %0.0 /* ty=Tensor[(1, 32), float32] */;
%2 = subtract(%1, %data) /* ty=Tensor[(1, 32), float32] */;
%3 = sum(%1) /* ty=float32 */;
%4 = sum(%2) /* ty=float32 */;
%5 = equal(%3, %4) /* ty=bool */;
%6 = if (%5) {
%0.1 /* ty=Tensor[(1, 32), float32] */
} else {
@tvmgen_default_test_if_else_merge_main_3(%2) /* ty=Tensor[(1, 32), float32] */
};
@tvmgen_default_test_if_else_merge_main_2(%1, %6) /* ty=Tensor[(1, 32), float32] */
}
def @tvmgen_default_test_if_else_merge_main_0(%test_if_else_merge_0_i0: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, Compiler="test_if_else_merge", Primitive=1, Inline=1, global_symbol="tvmgen_default_test_if_else_merge_main_0") -> (Tensor[(1, 32), float32], Tensor[(1, 32), float32]) {
%7 = add(%test_if_else_merge_0_i0, %test_if_else_merge_0_i0) /* ty=Tensor[(1, 32), float32] */;
%8 = sigmoid(%7) /* ty=Tensor[(1, 32), float32] */;
(%7, %8) /* ty=(Tensor[(1, 32), float32], Tensor[(1, 32), float32]) */
}
def @tvmgen_default_test_if_else_merge_main_2(%test_if_else_merge_2_i0: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, %test_if_else_merge_2_i1: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, Compiler="test_if_else_merge", Primitive=1, Inline=1, global_symbol="tvmgen_default_test_if_else_merge_main_2") -> Tensor[(1, 32), float32] {
%9 = erf(%test_if_else_merge_2_i1) /* ty=Tensor[(1, 32), float32] */;
add(%test_if_else_merge_2_i0, %9) /* ty=Tensor[(1, 32), float32] */
}
def @tvmgen_default_test_if_else_merge_main_3(%test_if_else_merge_3_i0: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, Compiler="test_if_else_merge", Primitive=1, Inline=1, global_symbol="tvmgen_default_test_if_else_merge_main_3") -> Tensor[(1, 32), float32] {
sigmoid(%test_if_else_merge_3_i0) /* ty=Tensor[(1, 32), float32] */
}