重叠分区
from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
# Free variables feeding a single nn.batch_norm call (created in the same
# order as before: x, var, mean, beta, gamma).
x, var, mean, beta, gamma = (
    relay.var(name) for name in ("x", "var", "mean", "beta", "gamma")
)
BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)

# Index element 0 of the batch_norm result twice. Both TupleGetItem nodes
# hang off the *same* batch_norm call, which is what makes any two pattern
# matches rooted at them overlap.
T1, T2 = BN[0], BN[0]
add = T1 + T2

# Render the graph and dump it as a Relay IRModule.
viz_expr(add)
print(tvm.IRModule.from_expr(add))
def @main(%x, %gamma, %beta, %mean, %var) {
%0 = nn.batch_norm(%x, %gamma, %beta, %mean, %var);
%1 = %0.0;
%2 = %0.0;
add(%1, %2)
}
构建模式（匹配 `nn.batch_norm` 输出元组的第 0 个元素）：
# One independent wildcard per batch_norm operand, in call-argument order.
x, gamma, beta, moving_mean, moving_var = (wildcard() for _ in range(5))

# Pattern: an nn.batch_norm call on those operands ...
bn_node = is_op("nn.batch_norm")(x, gamma, beta, moving_mean, moving_var)
# ... of which we select tuple element 0.
tuple_get_item_node = TupleGetItemPattern(bn_node, 0)
对计算图进行分区（两个匹配共享同一个 `batch_norm`，彼此重叠，因此计算图保持不变）：
# Both TupleGetItem nodes in `add` index the same batch_norm call, so the two
# matches of `tuple_get_item_node` overlap. The assert checks that
# partition() hands back the graph unchanged rather than extracting either
# match into a function.
partitioned = tuple_get_item_node.partition(add)
assert partitioned == add
另一个例子（`conv2d` 的结果同时被 `relu` 和最后的加法使用，中间结果逃逸出匹配的子图，因此同样不会分区）：
# Pattern: nn.relu applied to an nn.conv2d whose two operands are wildcards.
pattern = is_op("nn.relu")(
    is_op("nn.conv2d")(wildcard(), wildcard())
)

# Graph: relu(conv2d(input, weight)) + conv2d(input, weight). The single
# conv2d node is consumed both by relu and by the final add.
x, w = relay.var("input"), relay.var("weight")
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
out = relu + conv2d

# conv2d's result is also used outside the matched relu(conv2d(...))
# subgraph (by the add); the assert checks that partition() returns `out`
# unchanged.
assert pattern.partition(out) == out