rewrite_once

# rewrite_once

from testing import viz_expr # 可视化 relay
import numpy as np

import tvm
from tvm import relay
# from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *
from tvm.relay.testing import run_opt_pass

# # NB: 1 corresponds to the C++ enum that specifies this
# # we lose the type safety due to the Python/C++ calling
# # convention.
# K_ELEMWISE = 0
# K_BROADCAST = 1
class ConcatRewriter(DFPatternCallback):
    """Pattern callback that drops the last input of a ``concatenate`` call.

    Every invocation of :meth:`callback` rebuilds the matched call from all
    tuple fields except the last one.  When no field would remain after the
    drop, the first field of the tuple is returned in place of the call.
    """

    def __init__(self, rewrite_once):
        """Register the ``concatenate`` pattern.

        Parameters
        ----------
        rewrite_once
            Forwarded unchanged to ``DFPatternCallback.__init__``.
        """
        super().__init__(rewrite_once=rewrite_once)
        concat_op = is_op("concatenate")
        # The op pattern is called with ``None`` as its argument list.
        self.pattern = concat_op(None)

    def callback(self, pre, post, node_map):
        """Build the replacement for a matched ``concatenate`` call.

        Parameters
        ----------
        pre
            Unused here.
        post
            The matched call; its first argument is the tuple of inputs.
        node_map
            Unused here.

        Returns
        -------
        A new ``concatenate`` over every input except the last (always with
        ``axis=0``), or the first input alone when nothing else would be left.
        """
        fields = post.args[0]
        keep_count = len(fields) - 1
        if keep_count <= 0:
            # Nothing left to concatenate: hand back the first input itself.
            return fields[0]
        # Fields are read one index at a time, leaving out the last one.
        kept = [fields[idx] for idx in range(keep_count)]
        return relay.op.concatenate(relay.expr.Tuple(kept), axis=0)

ConcatRewriter 类递归地移除 concat 的参数,直到没有剩余的内容可以拼接。

# Three Relay variables that serve as the inputs of the concatenate.
x, y, z = (relay.var(name) for name in ("x", "y", "z"))
concat = relay.op.concatenate(
    relay.expr.Tuple([x, y, z]),
    axis=0,
)
print(tvm.IRModule.from_expr(concat))
def @main(%x, %y, %z) {
  %0 = (%x, %y, %z);
  concatenate(%0)
}

让重写器递归运行:

# rewrite_once=False for the concat callback.
out = rewrite(ConcatRewriter(rewrite_once=False), concat)
print(tvm.IRModule.from_expr(out))
def @main(%x) {
  %x
}

让重写器仅运行一次:

# rewrite_once=True for the concat callback.
out = rewrite(ConcatRewriter(rewrite_once=True), concat)
print(tvm.IRModule.from_expr(out))
def @main(%x, %y) {
  %0 = (%x, %y);
  concatenate(%0)
}
class OneMoreReluRewriter(DFPatternCallback):
    """Pattern callback that wraps a matched ``nn.softmax`` call in ``nn.relu``."""

    def __init__(self, rewrite_once):
        """Register the ``nn.softmax`` pattern.

        Parameters
        ----------
        rewrite_once
            Forwarded unchanged to ``DFPatternCallback.__init__``.
        """
        super().__init__(rewrite_once=rewrite_once)
        softmax_op = is_op("nn.softmax")
        # The op pattern is called with ``None`` as its argument list.
        self.pattern = softmax_op(None)

    def callback(self, pre, post, node_map):
        """Return ``nn.relu`` applied to the matched expression ``post``.

        ``pre`` and ``node_map`` are not used.
        """
        relu_out = relay.nn.relu(post)
        return relu_out

OneMoreReluRewriter 类在 nn.softmax 之后递归地添加 nn.relu 算子。

def before():
    """Return ``nn.softmax`` applied to the module-level ``concat`` expression.

    Graph before rewriting::

        x    y    z
        |    |    |
           concat
             |
          softmax
    """
    softmax_out = relay.nn.softmax(concat)
    return softmax_out
# Show the expression returned by before(), wrapped in an IRModule.
print(tvm.IRModule.from_expr(before()))
def @main(%x, %y, %z) {
  %0 = (%x, %y, %z);
  %1 = concatenate(%0);
  nn.softmax(%1)
}

运行 ConcatRewriter 一次,OneMoreReluRewriter 一次:

  x    y
  |    |
  concat
     |
  softmax
     |
   relu
# Both callbacks are constructed with rewrite_once=True.
out = rewrite(
    [
        OneMoreReluRewriter(rewrite_once=True),
        ConcatRewriter(rewrite_once=True),
    ],
    before(),
)
print(tvm.IRModule.from_expr(out))
def @main(%x, %y) {
  %0 = (%x, %y);
  %1 = concatenate(%0);
  %2 = nn.softmax(%1);
  nn.relu(%2)
}

递归运行 ConcatRewriter,运行 OneMoreReluRewriter 一次:

     x
     |
  softmax
     |
   relu
# The relu callback uses rewrite_once=True; the concat callback uses
# rewrite_once=False.
out = rewrite(
    [
        OneMoreReluRewriter(rewrite_once=True),
        ConcatRewriter(rewrite_once=False),
    ],
    before(),
)
print(tvm.IRModule.from_expr(out))
def @main(%x) {
  %0 = nn.softmax(%x);
  nn.relu(%0)
}