# 四元：重写支配者 (Part 4: rewriting dominator patterns)
import set_env
# import numpy as np

import tvm
from tvm import relay
# from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *
# from tvm.relay.testing import run_opt_pass
# NB: 1 corresponds to the C++ enum that specifies this
# op-pattern kind; we lose the type safety due to the Python/C++ calling
# convention.
K_ELEMWISE = 0
K_BROADCAST = 1

class DominatorRemovalCallback(DFPatternCallback):
    """Collapse a conv2d plus its dominated elementwise region.

    The pattern matches an ``nn.conv2d`` whose every consumer path runs
    through elementwise-tagged ops (or plain ``add``) before rejoining at
    a final ``add``; the callback replaces the whole matched region with
    just the conv2d call itself.
    """

    def __init__(self):
        super().__init__()
        # Wildcards for the conv2d operands, kept on self so the callback
        # can look their matched values up in node_map.
        self.inp = wildcard()
        self.weight = wildcard()
        conv2d_pat = is_op("nn.conv2d")(self.inp, self.weight)
        # Interior path ops: anything tagged K_ELEMWISE, or a plain add.
        path_pat = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(
            wildcard()
        ) | is_op("add")(wildcard(), wildcard())
        # Post-dominator where all paths out of the conv2d rejoin.
        rejoin_pat = is_op("add")(wildcard(), wildcard())
        self.pattern = dominates(conv2d_pat, path_pat, rejoin_pat)

    def callback(self, pre, post, node_map):
        """Rebuild and return a bare conv2d from the captured operands."""
        matched_input = node_map[self.inp][0]
        matched_weight = node_map[self.weight][0]
        return relay.op.nn.conv2d(matched_input, matched_weight)
# Build four chained conv2d regions, each followed by a dominated group
# of elementwise ops, then check the rewrite strips every group away.
inp = relay.var("input")
weight = relay.var("weight")

# Classic diamond: two branches out of the conv2d rejoin at an add.
conv = relay.op.nn.conv2d(inp, weight)
branch_a = relay.op.nn.relu(relay.op.nn.relu(conv))
branch_b = relay.op.nn.leaky_relu(conv, alpha=0)
out = branch_a + branch_b

# Deeper branch: one side of the diamond is three ops long.
conv = relay.op.nn.conv2d(out, weight)
branch_a = relay.op.tanh(relay.op.nn.relu(relay.op.nn.relu(conv)))
branch_b = relay.op.nn.leaky_relu(conv, alpha=0)
out = branch_a + branch_b

# Single branch: the diamond degenerates into a straight chain.
conv = relay.op.nn.conv2d(out, weight)
chain = relay.op.nn.relu(relay.op.nn.relu(conv))
out = chain + relay.op.tanh(chain)

# Fuzzy path / nested diamond: an add appears inside the dominated region.
conv = relay.op.nn.conv2d(out, weight)
fuzzy = relay.op.nn.relu(conv)
fuzzy = fuzzy + fuzzy
out = relay.op.tanh(fuzzy) + relay.op.nn.leaky_relu(conv, alpha=0)

# Expected result: the four conv2d ops chained directly together.
expected = inp
for _ in range(4):
    expected = relay.op.nn.conv2d(expected, weight)

assert tvm.ir.structural_equal(DominatorRemovalCallback().rewrite(out), expected)
tvm.IRModule.from_expr(expected).show()
# Expected printed output of the final `.show()` call:
#
# def @main(%input, %weight) {
#   %0 = nn.conv2d(%input, %weight, padding=[0, 0, 0, 0]);
#   %1 = nn.conv2d(%0, %weight, padding=[0, 0, 0, 0]);
#   %2 = nn.conv2d(%1, %weight, padding=[0, 0, 0, 0]);
#   nn.conv2d(%2, %weight, padding=[0, 0, 0, 0])
# }