# 合并编译器区域
参考:tvm/tests/python/relay/test_pass_merge_compiler_regions.py
`compiler_begin()` 和 `compiler_end()` 函数的主要作用是在 Relay 表达式中标记区域的开始和结束,该区域将交由指定的编译器处理。这两个函数通常用于优化和调度,使不同的编译器可以分别处理表达式中的不同区域。
import tvm
from tvm import relay
import tvm.relay.transform as transform
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.testing import run_opt_pass
## 处理菱形区域
检验合并过程能否正确处理菱形计算图中存在的数据依赖关系。
O = 目标支持的算子
X = 目标不支持的算子

```text
   O              O
  / \            / \
 O   X   -->    O + + X
  \ /            \ /
   O              O
```
注意:不能简单地将三个受支持的算子合并到同一个区域中,否则两个子图将相互依赖。
def diamond_graph_fanouts():
    """Return the pre-annotated diamond graph that is fed to ``MergeCompilerRegions``.

    The diamond is ``abs -> {relu, tanh} -> add`` over a single ``(10, 10)``
    input variable named ``data``. ``abs``, ``relu`` and ``add`` are wrapped in
    ``"test"`` compiler annotations, while ``tanh`` is wrapped in ``"default"``
    annotations. The ``abs`` result fans out through two distinct
    ``compiler_end`` nodes, one for each consumer branch.

    Returns:
        relay.Function: the annotated function with the single parameter ``data``.
    """
    data = relay.var("data", shape=(10, 10))

    # Top of the diamond: abs, annotated for "test". Its output gets one
    # compiler_end per consumer (fan-out).
    top = relay.abs(compiler_begin(data, "test"))
    top_end_left = compiler_end(top, "test")
    top_end_right = compiler_end(top, "test")

    # Region entries for the two branches of the diamond.
    left_begin = compiler_begin(top_end_left, "test")
    right_begin = compiler_begin(top_end_right, "default")

    # Left branch: relu, annotated for "test".
    left_end = compiler_end(relay.nn.relu(left_begin), "test")
    # Right branch: tanh, annotated for "default" (the unsupported op "X").
    right_end = compiler_end(relay.tanh(right_begin), "default")

    # Bottom of the diamond: add, annotated for "test", joining both branches.
    bottom = relay.add(
        compiler_begin(left_end, "test"),
        compiler_begin(right_end, "test"),
    )
    return relay.Function([data], compiler_end(bottom, "test"))
# Build the annotated diamond graph and print it before merging.
expr = diamond_graph_fanouts()
tvm.IRModule.from_expr(expr).show()
# Apply the MergeCompilerRegions pass to the expression via run_opt_pass
# (from tvm.relay.testing) and print the merged result for comparison.
opt_expr = run_opt_pass(expr, relay.transform.MergeCompilerRegions())
tvm.IRModule.from_expr(opt_expr).show()
## 计算图示例
参见 TVM 社区 RFC 5830(改进的图划分算法)中的示例图:蓝色节点是加法(目标:`test`),红色节点是减法(目标:`default`)。
原始计算图:
# NOTE(review): `annotated` is not defined anywhere in this section. It is
# presumably defined in a cell not shown here and builds the annotated RFC
# example graph as an IRModule (later cells call `.show()` on it and run
# module passes over it) — confirm; otherwise this line raises NameError.
mod = annotated()
# Print the annotated graph before merging.
mod.show()
优化之后的计算图:
# Apply the MergeCompilerRegions pass to the whole module.
mod = relay.transform.MergeCompilerRegions()(mod)
# Re-run type inference on the rewritten module, presumably so the printed
# IR carries inferred types.
mod = relay.transform.InferType()(mod)
mod.show()
## 测试条件语句
此测试检验限制区域能否在 if-else 控制流中正确传播。
O = 目标支持的算子
X = 目标不支持的算子

```text
O1 - - - |                O1 --|
|        |                |
X        |                X
|        |                |
If cond ? O1: X |  -->  + + If cond ? O1: X +
|        |                |
O2 <- - -|                O2 <-|
```
该测试确保 O1 不会被合并到 O2 所在的区域中。
# Name of the external compiler target used by the if/else example.
target = "test_if_else_merge"

# Register a `target.<target>` attribute returning True for sigmoid, erf and
# add. AnnotateTarget(target) further down presumably consults these
# attributes to decide which ops the target supports; subtract, sum and
# equal are not registered here.


@tvm.ir.register_op_attr("sigmoid", f"target.{target}")
def sigmoid(expr):  # pylint: disable=unused-variable
    """Report every ``sigmoid`` call as supported by ``target``."""
    return True


@tvm.ir.register_op_attr("erf", f"target.{target}")
def erf(expr):  # pylint: disable=unused-variable
    """Report every ``erf`` call as supported by ``target``."""
    return True


@tvm.ir.register_op_attr("add", f"target.{target}")
def add(expr):  # pylint: disable=unused-variable
    """Report every ``add`` call as supported by ``target``."""
    return True
def get_mod():
    """Build the if/else example module.

    The ``main`` function computes, for one ``(1, 32)`` input ``data``::

        doubled = data + data
        diff    = doubled - data
        cond    = sum(doubled) == sum(diff)
        branch  = sigmoid(doubled) if cond else sigmoid(diff)
        result  = doubled + erf(branch)

    Returns:
        tvm.IRModule: module wrapping the function above.
    """
    data = relay.var("data", shape=(1, 32))
    doubled = relay.add(data, data)
    diff = relay.subtract(doubled, data)
    cond = relay.equal(relay.sum(doubled), relay.sum(diff))
    # Both branches are sigmoid; only the condition selects which input is used.
    branch = relay.If(cond, relay.sigmoid(doubled), relay.sigmoid(diff))
    result = relay.add(doubled, relay.erf(branch))
    return tvm.IRModule.from_expr(relay.Function([data], result))
# Annotate, merge and partition the if/else module once for each value of
# the flag passed as AnnotateTarget's second positional argument.
for annotate_non_call_ops in [True, False]:
    result = transform.AnnotateTarget(target, annotate_non_call_ops)(get_mod())
    merge = transform.MergeCompilerRegions()(result)
    # Make sure partitioning completes without a segfault.
    partition = transform.PartitionGraph()(merge)
# Only the partition from the last iteration (annotate_non_call_ops=False)
# is printed.
partition.show()
def @main(%data: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */) -> Tensor[(1, 32), float32] {
%0 = @tvmgen_default_test_if_else_merge_main_0(%data) /* ty=(Tensor[(1, 32), float32], Tensor[(1, 32), float32]) */;
%1 = %0.0 /* ty=Tensor[(1, 32), float32] */;
%2 = subtract(%1, %data) /* ty=Tensor[(1, 32), float32] */;
%3 = sum(%1) /* ty=float32 */;
%4 = sum(%2) /* ty=float32 */;
%5 = equal(%3, %4) /* ty=bool */;
%6 = if (%5) {
%0.1 /* ty=Tensor[(1, 32), float32] */
} else {
@tvmgen_default_test_if_else_merge_main_3(%2) /* ty=Tensor[(1, 32), float32] */
};
@tvmgen_default_test_if_else_merge_main_2(%1, %6) /* ty=Tensor[(1, 32), float32] */
}
def @tvmgen_default_test_if_else_merge_main_0(%test_if_else_merge_0_i0: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, Compiler="test_if_else_merge", Primitive=1, Inline=1, global_symbol="tvmgen_default_test_if_else_merge_main_0") -> (Tensor[(1, 32), float32], Tensor[(1, 32), float32]) {
%7 = add(%test_if_else_merge_0_i0, %test_if_else_merge_0_i0) /* ty=Tensor[(1, 32), float32] */;
%8 = sigmoid(%7) /* ty=Tensor[(1, 32), float32] */;
(%7, %8) /* ty=(Tensor[(1, 32), float32], Tensor[(1, 32), float32]) */
}
def @tvmgen_default_test_if_else_merge_main_2(%test_if_else_merge_2_i0: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, %test_if_else_merge_2_i1: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, Compiler="test_if_else_merge", Primitive=1, Inline=1, global_symbol="tvmgen_default_test_if_else_merge_main_2") -> Tensor[(1, 32), float32] {
%9 = erf(%test_if_else_merge_2_i1) /* ty=Tensor[(1, 32), float32] */;
add(%test_if_else_merge_2_i0, %9) /* ty=Tensor[(1, 32), float32] */
}
def @tvmgen_default_test_if_else_merge_main_3(%test_if_else_merge_3_i0: Tensor[(1, 32), float32] /* ty=Tensor[(1, 32), float32] */, Compiler="test_if_else_merge", Primitive=1, Inline=1, global_symbol="tvmgen_default_test_if_else_merge_main_3") -> Tensor[(1, 32), float32] {
sigmoid(%test_if_else_merge_3_i0) /* ty=Tensor[(1, 32), float32] */
}