Constant embedding partition

from testing import viz_expr # visualize Relay expressions
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *

def conv_bias_relu(x, w, b):
    conv2d = relay.op.nn.conv2d(x, w)
    bias_add = relay.op.nn.bias_add(conv2d, b)
    relu = relay.op.nn.relu(bias_add)
    return relu

Build the computation graphs (one with a variable weight, one with a constant weight):

x = relay.var("x")
w = relay.var("w")
wc = relay.const(1)
b = relay.var("b")

relu = conv_bias_relu(x, w, b)
reluc = conv_bias_relu(x, wc, b)
# viz_expr(relu)
print(tvm.IRModule.from_expr(relu))
def @main(%x, %w, %b) {
  %0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0]);
  %1 = nn.bias_add(%0, %b);
  nn.relu(%1)
}
print(tvm.IRModule.from_expr(reluc))
def @main(%x, %b) {
  %0 = nn.conv2d(%x, 1, padding=[0, 0, 0, 0]);
  %1 = nn.bias_add(%0, %b);
  nn.relu(%1)
}

Build the pattern (lifting of wildcard() matches):

pattern = is_op("nn.relu")(
    is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard())
)
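
As a quick sanity check before partitioning (a minimal sketch; DFPattern.match returns a bool), the wildcard() in the weight slot matches a variable as well as a constant:

assert pattern.match(relu)   # variable weight matches
assert pattern.match(reluc)  # constant weight also matches the wildcard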

Partition the computation graphs:

partitioned = pattern.partition(relu)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
  %2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
    %1 = nn.bias_add(%0, %FunctionVar_0_2);
    nn.relu(%1)
  };
  %2(%x, %w, %b)
}
partitioned = pattern.partition(reluc)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
  %2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
    %1 = nn.bias_add(%0, %FunctionVar_0_2);
    nn.relu(%1)
  };
  %2(%x, 1, %b)
}
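
The upstream unit test (tests/python/relay/test_dataflow_pattern.py) asserts the same behavior structurally; a sketch along those lines, reusing conv_bias_relu to build the expected lifted function:

xf, wf, bf = relay.var("x"), relay.var("w"), relay.var("b")
lifted_func = relay.Function([xf, wf, bf], conv_bias_relu(xf, wf, bf)).with_attr(
    "PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_"
)
assert tvm.ir.structural_equal(lifted_func(x, w, b), pattern.partition(relu))
# For wildcard matches the constant is not embedded; it is lifted,
# i.e. passed to the partitioned function as a call argument:
assert tvm.ir.structural_equal(lifted_func(x, wc, b), pattern.partition(reluc))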

Build the pattern (lifting of input matches):

pattern = is_op("nn.relu")(
    is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_var()), wildcard())
)
partitioned = pattern.partition(relu)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
  %2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
    %1 = nn.bias_add(%0, %FunctionVar_0_2);
    nn.relu(%1)
  };
  %2(%x, %w, %b)
}

Constants are not inputs: the weight slot now requires is_var(), so the constant-weight graph does not match and partitioning leaves it unchanged:

partitioned = pattern.partition(reluc)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
  %0 = nn.conv2d(%x, 1, padding=[0, 0, 0, 0]);
  %1 = nn.bias_add(%0, %b);
  nn.relu(%1)
}
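
Since nothing matched, partition simply returns the expression unchanged, which can be checked directly (a sketch):

assert tvm.ir.structural_equal(reluc, pattern.partition(reluc))  # no match, no partition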

Check the embedding of constant matches:

pattern = is_op("nn.relu")(
    is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_constant()), wildcard())
)
partitioned = pattern.partition(relu)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
  %0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0]);
  %1 = nn.bias_add(%0, %b);
  nn.relu(%1)
}
partitioned = pattern.partition(reluc)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
  %2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
    %0 = nn.conv2d(%FunctionVar_0_0, 1, padding=[0, 0, 0, 0]);
    %1 = nn.bias_add(%0, %FunctionVar_0_1);
    nn.relu(%1)
  };
  %2(%x, %b)
}
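
Here the matched constant is embedded in the body of the partitioned function instead of being lifted to a parameter; structurally (a sketch, again reusing conv_bias_relu):

xf, bf = relay.var("x"), relay.var("b")
embedded_func = relay.Function([xf, bf], conv_bias_relu(xf, wc, bf)).with_attr(
    "PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_"
)
assert tvm.ir.structural_equal(embedded_func(x, b), pattern.partition(reluc))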

Check the embedding of constant ExprPatterns:

pattern = is_op("nn.relu")(
    is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_expr(wc)), wildcard())
)
partitioned = pattern.partition(relu)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
  %0 = nn.conv2d(%x, %w, padding=[0, 0, 0, 0]);
  %1 = nn.bias_add(%0, %b);
  nn.relu(%1)
}
partitioned = pattern.partition(reluc)
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
  %2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
    %0 = nn.conv2d(%FunctionVar_0_0, 1, padding=[0, 0, 0, 0]);
    %1 = nn.bias_add(%0, %FunctionVar_0_1);
    nn.relu(%1)
  };
  %2(%x, %b)
}
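
Note that is_expr(wc) matches only that exact constant expression (ExprPattern compares by structural equality), so a different constant would not match; a hedged sketch:

other = conv_bias_relu(x, relay.const(2), b)
assert pattern.match(reluc)      # the exact constant wc matches
assert not pattern.match(other)  # a different constant does not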

Check the lifting/embedding of Alt matches:

pattern = is_op("nn.relu")(
    is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_var() | is_constant()), wildcard())
)
partitioned = pattern.partition(relu) # lifted
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
  %2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
    %1 = nn.bias_add(%0, %FunctionVar_0_2);
    nn.relu(%1)
  };
  %2(%x, %w, %b)
}
partitioned = pattern.partition(reluc) # embedded
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
  %2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
    %0 = nn.conv2d(%FunctionVar_0_0, 1, padding=[0, 0, 0, 0]);
    %1 = nn.bias_add(%0, %FunctionVar_0_1);
    nn.relu(%1)
  };
  %2(%x, %b)
}

Check that the lifting/embedding of Alt matches is consistent with the other ordering:

pattern = is_op("nn.relu")(
    is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_constant() | is_var()), wildcard())
)
partitioned = pattern.partition(relu) # lifted
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %w, %b) {
  %2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
    %0 = nn.conv2d(%FunctionVar_0_0, %FunctionVar_0_1, padding=[0, 0, 0, 0]);
    %1 = nn.bias_add(%0, %FunctionVar_0_2);
    nn.relu(%1)
  };
  %2(%x, %w, %b)
}
partitioned = pattern.partition(reluc) # embedded
print(tvm.IRModule.from_expr(partitioned))
def @main(%x, %b) {
  %2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, PartitionedFromPattern="nn.conv2d_nn.bias_add_nn.relu_") {
    %0 = nn.conv2d(%FunctionVar_0_0, 1, padding=[0, 0, 0, 0]);
    %1 = nn.bias_add(%0, %FunctionVar_0_1);
    nn.relu(%1)
  };
  %2(%x, %b)
}
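
As the outputs above show, the ordering of the Alt arms does not change the result: either ordering lifts the variable weight and embeds the constant one. A closing sketch based on those outputs:

p1 = is_op("nn.relu")(
    is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_var() | is_constant()), wildcard())
)
p2 = is_op("nn.relu")(
    is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), is_constant() | is_var()), wildcard())
)
assert tvm.ir.structural_equal(p1.partition(relu), p2.partition(relu))    # both lift
assert tvm.ir.structural_equal(p1.partition(reluc), p2.partition(reluc))  # both embed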