# 常量嵌入（constant embedding）分区
from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
def conv_bias_relu(x, w, b):
    """Build the Relay expression ``relu(bias_add(conv2d(x, w), b))``.

    This conv2d -> bias_add -> relu chain is the subgraph that the
    dataflow patterns in this walkthrough try to match and partition.

    Args:
        x: Relay expression used as the conv2d data input.
        w: Relay expression used as the conv2d weight (a ``relay.var`` or a
            ``relay.const`` in this walkthrough).
        b: Relay expression passed as the bias to ``nn.bias_add``.

    Returns:
        The ``nn.relu`` call at the root of the subgraph.
    """
    conv2d = relay.op.nn.conv2d(x, w)
    bias_add = relay.op.nn.bias_add(conv2d, b)
    relu = relay.op.nn.relu(bias_add)
    return relu
构建计算图:
# Free variables for data / weight / bias, plus a scalar constant `wc` that
# replaces the weight in the constant-weight variant of the graph.
x, w, wc, b = relay.var("x"), relay.var("w"), relay.const(1), relay.var("b")

# The same conv2d -> bias_add -> relu subgraph, built once with the variable
# weight `w` and once with the constant weight `wc`.
relu, reluc = conv_bias_relu(x, w, b), conv_bias_relu(x, wc, b)

# Uncomment to render the graph with the `viz_expr` helper imported above.
# viz_expr(relu)
print(tvm.IRModule.from_expr(relu))
def @main(%x, %w, %b) {
%0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %b);
nn.relu(%1)
}
# Same subgraph built with the constant weight `wc`; since `w` is not used
# here, only %x and %b remain as @main parameters in the printed module.
print(tvm.IRModule.from_expr(reluc))
def @main(%x, %b) {
%0 = nn.conv2d(%x, 1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %b);
nn.relu(%1)
}
构建模式(检查 wildcard() 匹配的提升):
# wildcard() in every operand slot: the conv2d weight matches whether it is a
# variable or a constant, and every matched operand is lifted to a parameter
# of the partitioned function (see the two outputs below).
pattern = is_op("nn.relu")(
is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard())
)
分区计算图(partition):
# Partition the variable-weight graph: the matched subgraph becomes a function
# with three parameters, called with (%x, %w, %b).
partitioned = pattern.partition(relu)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
%0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %FunctionVar_0_2);
nn.relu(%1)
};
%2(%x, %w, %b)
}
# Partition the constant-weight graph: the constant `1` is lifted, i.e. passed
# as a call argument instead of staying inside the function body.
partitioned = pattern.partition(reluc)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
%0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %FunctionVar_0_2);
nn.relu(%1)
};
%2(%x, 1, %b)
}
构建模式(检查 input 匹配的提升):
# is_var() in the weight slot: only a Relay variable weight can match.
pattern = is_op("nn.relu")(
is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_var()), wildcard())
)
# `w` is a variable, so the graph matches and %w is lifted to a parameter.
partitioned = pattern.partition(relu)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
%0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %FunctionVar_0_2);
nn.relu(%1)
};
%2(%x, %w, %b)
}
常量不是输入(常量不匹配 is_var()，因此 reluc 不会被分区):
# The constant weight does not satisfy is_var(), so nothing is partitioned and
# the printed module is the original graph.
partitioned = pattern.partition(reluc)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
%0 = nn.conv2d(%x, 1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %b);
nn.relu(%1)
}
检查常量匹配的嵌入:
# is_constant() in the weight slot: only a constant weight can match.
pattern = is_op("nn.relu")(
is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_constant()), wildcard())
)
# `w` is a variable, so is_constant() fails and the graph is left unpartitioned.
partitioned = pattern.partition(relu)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
%0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %b);
nn.relu(%1)
}
# The constant matches and is embedded: `1` stays inside the function body,
# which now takes only two parameters (data and bias).
partitioned = pattern.partition(reluc)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
%0 = nn.conv2d(%FunctionVar_0_0, 1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %FunctionVar_0_1);
nn.relu(%1)
};
%2(%x, %b)
}
检查常量 ExprPatterns 的嵌入:
# is_expr(wc) in the weight slot: the weight must be the constant expression `wc`.
pattern = is_op("nn.relu")(
is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_expr(wc)), wildcard())
)
# The weight here is `w`, not `wc`, so there is no match and no partition.
partitioned = pattern.partition(relu)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
%0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %b);
nn.relu(%1)
}
# The weight is `wc`: the graph matches and the constant `1` is embedded in the
# function body, leaving two parameters (data and bias).
partitioned = pattern.partition(reluc)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
%0 = nn.conv2d(%FunctionVar_0_0, 1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %FunctionVar_0_1);
nn.relu(%1)
};
%2(%x, %b)
}
检查 Alt 匹配的提升/嵌入:
# Alternation in the weight slot: either a variable or a constant may match.
pattern = is_op("nn.relu")(
is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_var() | is_constant()), wildcard())
)
# Variable weight, matched through the is_var() branch: %w is lifted to a
# parameter of the partitioned function.
partitioned = pattern.partition(relu) # lifted
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
%0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %FunctionVar_0_2);
nn.relu(%1)
};
%2(%x, %w, %b)
}
# Constant weight, matched through the is_constant() branch: `1` is embedded
# in the function body rather than passed as an argument.
partitioned = pattern.partition(reluc) # embedded
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
%0 = nn.conv2d(%FunctionVar_0_0, 1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %FunctionVar_0_1);
nn.relu(%1)
};
%2(%x, %b)
}
检查交换 Alt 分支顺序后，提升/嵌入的结果是否保持一致:
# Same alternation with the branches swapped (constant first), to check that
# lifting/embedding does not depend on branch order.
pattern = is_op("nn.relu")(
is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_constant() | is_var()), wildcard())
)
# Variable weight: still lifted, same output as with is_var() | is_constant().
partitioned = pattern.partition(relu) # lifted
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
%0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %FunctionVar_0_2);
nn.relu(%1)
};
%2(%x, %w, %b)
}
# Constant weight: still embedded, same output as with is_var() | is_constant().
partitioned = pattern.partition(reluc) # embedded
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
%2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
%0 = nn.conv2d(%FunctionVar_0_0, 1, padding=[0, 0, 0, 0]);
%1 = nn.bias_add(%0, %FunctionVar_0_1);
nn.relu(%1)
};
%2(%x, %b)
}