# 四、重写支配者 (Rewrite Dominator)
import set_env
# import numpy as np
import tvm
from tvm import relay
# from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *
# from tvm.relay.testing import run_opt_pass
# NB: these integers mirror the C++ enum behind the "TOpPattern" op attribute
# (element-wise = 0, broadcast = 1). We lose that enum's type safety because
# the value crosses the Python/C++ calling convention as a plain int.
# Integer codes compared against an op's "TOpPattern" attribute:
# element-wise ops use 0, broadcast ops use 1.
K_ELEMWISE, K_BROADCAST = 0, 1
class DominatorRemovalCallback(DFPatternCallback):
    """Collapse a conv2d-rooted dominator region back to the bare conv2d.

    The pattern is built with ``dominates(parent, path, child)``:

    * parent -- an ``nn.conv2d`` call on arbitrary input and weight;
    * path   -- every node between parent and child must be either a
      single-argument op whose ``TOpPattern`` attribute equals
      ``K_ELEMWISE`` or an ``add`` of two arbitrary expressions;
    * child  -- an ``add`` of two arbitrary expressions.

    On a match, the callback replaces the whole region with just the conv2d,
    which drops the element-wise branches hanging off it.
    """

    def __init__(self):
        super().__init__()
        self.inp = wildcard()
        self.weight = wildcard()
        # Keep the conv2d pattern on the instance so the callback can recover
        # the matched call and reuse its attributes.
        self.conv2d = is_op("nn.conv2d")(self.inp, self.weight)
        path = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
            wildcard()
        ) | is_op("add")(wildcard(), wildcard())
        reduction = is_op("add")(wildcard(), wildcard())
        self.pattern = dominates(self.conv2d, path, reduction)

    def callback(self, pre, post, node_map):
        """Return a conv2d of the matched input and weight.

        Parameters
        ----------
        pre, post :
            The matched root expression before/after rewriting its children
            (unused here).
        node_map :
            Maps each sub-pattern to the list of expressions it matched.

        Returns
        -------
        A ``relay.Call`` to the matched conv2d op with the matched input,
        weight and the original call's attributes.
        """
        conv = node_map[self.conv2d][0]
        inp = node_map[self.inp][0]
        weight = node_map[self.weight][0]
        # Rebuilding via relay.op.nn.conv2d(inp, weight) would reset any
        # non-default strides/padding/dilation/groups/layout to defaults;
        # reusing the matched call's op and attrs preserves them.
        return relay.Call(conv.op, [inp, weight], conv.attrs)
# Build four conv2d-rooted regions that the dominator pattern should collapse,
# chain them so each region feeds the next, and check that the rewrite leaves
# a plain chain of four conv2d calls.
def _classic_diamond(x, w):
    """Classic diamond: conv2d -> (relu -> relu | leaky_relu) -> add."""
    conv = relay.op.nn.conv2d(x, w)
    left = relay.op.nn.relu(relay.op.nn.relu(conv))
    right = relay.op.nn.leaky_relu(conv, alpha=0)
    return left + right


def _deeper_branch(x, w):
    """Deeper branch: conv2d -> (relu -> relu -> tanh | leaky_relu) -> add."""
    conv = relay.op.nn.conv2d(x, w)
    left = relay.op.tanh(relay.op.nn.relu(relay.op.nn.relu(conv)))
    right = relay.op.nn.leaky_relu(conv, alpha=0)
    return left + right


def _single_branch(x, w):
    """Single branch: conv2d -> relu -> relu, then add(relu, tanh(relu))."""
    conv = relay.op.nn.conv2d(x, w)
    relu = relay.op.nn.relu(relay.op.nn.relu(conv))
    return relu + relay.op.tanh(relu)


def _nested_diamond(x, w):
    """Fuzzy path / nested diamond.

    conv2d -> (relu -> add(relu, relu) -> tanh | leaky_relu) -> add.
    """
    conv = relay.op.nn.conv2d(x, w)
    relu = relay.op.nn.relu(conv)
    doubled = relu + relu
    left = relay.op.tanh(doubled)
    right = relay.op.nn.leaky_relu(conv, alpha=0)
    return left + right


inp = relay.var("input")
weight = relay.var("weight")

out = inp
for _build in (_classic_diamond, _deeper_branch, _single_branch, _nested_diamond):
    out = _build(out, weight)

# Expected result: every region collapsed to its conv2d.
one = relay.op.nn.conv2d(inp, weight)
two = relay.op.nn.conv2d(one, weight)
three = relay.op.nn.conv2d(two, weight)
four = relay.op.nn.conv2d(three, weight)

rewritten = DominatorRemovalCallback().rewrite(out)
# Explicit check: a bare `assert` is stripped under `python -O`, which would
# silently skip the only verification in this script.
if not tvm.ir.structural_equal(rewritten, four):
    raise AssertionError(
        "dominator rewrite did not collapse every region to a single conv2d"
    )
# Show what the rewriter actually produced (structurally equal to `four`).
tvm.IRModule.from_expr(rewritten).show()
# Expected output of the `.show()` call above:
#
# def @main(%input, %weight) {
#   %0 = nn.conv2d(%input, %weight, padding=[0, 0, 0, 0]);
#   %1 = nn.conv2d(%0, %weight, padding=[0, 0, 0, 0]);
#   %2 = nn.conv2d(%1, %weight, padding=[0, 0, 0, 0]);
#   nn.conv2d(%2, %weight, padding=[0, 0, 0, 0])
# }