# 模糊主体的函数分区
from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
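借助支配者（dominator）分析，可以对具有“模糊主体”的函数进行重写和分区。下面的 `FunctionPattern` 只规定了函数有两个参数，函数体本身用 `wildcard()` 匹配任意表达式。从输出可以看到，只有“函数调用 + 第一个 `add`”被包进带 `PartitionedFromPattern` 属性的新函数，第二个 `add` 仍保留在外面。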
# Outer graph: call a local conv2d function on (x, w), then add the bias b twice.
x, w, b = (relay.var(name) for name in ("x", "w", "b"))

# Parameters and body of the inner (local) function.
x1, w1 = (relay.var(name) for name in ("x1", "w1"))
func = relay.Function([x1, w1], relay.nn.conv2d(x1, w1))
expr = func(x, w) + b + b

# Pattern: a call to any two-parameter function whose body is a wildcard
# (the "fuzzy" body), followed by a single add of some third operand.
wc_x, wc_w, wc_b, wc_x1, wc_w1 = (wildcard() for _ in range(5))
func_pattern = FunctionPattern([wc_x1, wc_w1], wildcard())
pattern = func_pattern(wc_x, wc_w) + wc_b

# Partition the graph: each match is wrapped into its own function.
new_expr = pattern.partition(expr)

# Show the original (unpartitioned) program for comparison.
print(tvm.IRModule.from_expr(expr))
def @main(%x, %w, %b) {
%0 = fn (%x1, %w1) {
nn.conv2d(%x1, %w1, padding=[0, 0, 0, 0])
};
%1 = %0(%x, %w);
%2 = add(%1, %b);
add(%2, %b)
}
# Print the partitioned program. The matched "function call + first add" is now
# wrapped in a function tagged PartitionedFromPattern; the second add stays outside.
print(tvm.IRModule.from_expr(new_expr))
def @main(%x, %w, %b) {
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_FunctionCall_add_") {
%0 = fn (%x1, %w1) {
nn.conv2d(%x1, %w1, padding=[0, 0, 0, 0])
};
%1 = %0(%FunctionVar_0_0, %FunctionVar_0_1);
add(%1, %FunctionVar_0_2)
};
%3 = %2(%x, %w, %b);
add(%3, %b)
}