# rewrite_once
from testing import viz_expr  # visualize relay expressions (local helper)
import numpy as np
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
from tvm.relay.testing import run_opt_pass
class ConcatRewriter(DFPatternCallback):
    def __init__(self, rewrite_once):
        super().__init__(rewrite_once=rewrite_once)
        self.pattern = is_op("concatenate")(None)

    def callback(self, pre, post, node_map):
        concat_args = post.args[0]
        # Remove the last argument of the concatenate
        new_args = [concat_args[i] for i in range(len(concat_args) - 1)]
        if new_args:
            return relay.op.concatenate(relay.expr.Tuple(new_args), axis=0)
        else:
            return concat_args[0]
The ConcatRewriter class recursively removes the arguments of concatenate until nothing is left to concatenate.
x = relay.var("x")
y = relay.var("y")
z = relay.var("z")
concat = relay.op.concatenate(relay.expr.Tuple([x, y, z]), axis=0)
print(tvm.IRModule.from_expr(concat))
def @main(%x, %y, %z) {
%0 = (%x, %y, %z);
concatenate(%0)
}
Let the rewriter run recursively:
out = rewrite(ConcatRewriter(False), concat)
print(tvm.IRModule.from_expr(out))
def @main(%x) {
%x
}
Let the rewriter run only once:
out = rewrite(ConcatRewriter(True), concat)
print(tvm.IRModule.from_expr(out))
def @main(%x, %y) {
%0 = (%x, %y);
concatenate(%0)
}
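To see how often the callback actually fires in each mode, a small counting subclass helps. This is a minimal sketch; CountingConcatRewriter is an illustrative helper, not part of the TVM API, and the expected counts follow from the fixed-point semantics described above.

class CountingConcatRewriter(ConcatRewriter):
    def __init__(self, rewrite_once):
        super().__init__(rewrite_once)
        self.count = 0  # number of times the callback has fired

    def callback(self, pre, post, node_map):
        self.count += 1
        return super().callback(pre, post, node_map)

recursive = CountingConcatRewriter(False)
rewrite(recursive, concat)
once = CountingConcatRewriter(True)
rewrite(once, concat)
# Recursive mode re-runs until a fixed point (one firing per dropped
# argument); single-pass mode should fire exactly once here.
print(recursive.count, once.count)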
class OneMoreReluRewriter(DFPatternCallback):
    def __init__(self, rewrite_once):
        super().__init__(rewrite_once=rewrite_once)
        self.pattern = is_op("nn.softmax")(None)

    def callback(self, pre, post, node_map):
        return relay.nn.relu(post)
The OneMoreReluRewriter class adds an nn.relu operator after nn.softmax.
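Because the callback wraps the matched nn.softmax in a fresh nn.relu, a recursive run would match the softmax again on every pass and keep stacking relus without reaching a fixed point, so this rewriter is only used with rewrite_once=True. A minimal sketch on a bare softmax (the variable a is illustrative):

sm = relay.nn.softmax(relay.var("a"))
# Expected to print a single nn.relu on top of the softmax
print(tvm.IRModule.from_expr(rewrite(OneMoreReluRewriter(True), sm)))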
def before():
    # Before:
    #  x  y  z
    #  |  |  |
    #   concat
    #      |
    #   softmax
    return relay.nn.softmax(concat)
print(tvm.IRModule.from_expr(before()))
def @main(%x, %y, %z) {
%0 = (%x, %y, %z);
%1 = concatenate(%0);
nn.softmax(%1)
}
Run ConcatRewriter once and OneMoreReluRewriter once:
x    y
|    |
 concat
    |
 softmax
    |
  relu
out = rewrite(
    [OneMoreReluRewriter(True), ConcatRewriter(True)],
    before(),
)
print(tvm.IRModule.from_expr(out))
def @main(%x, %y) {
%0 = (%x, %y);
%1 = concatenate(%0);
%2 = nn.softmax(%1);
nn.relu(%2)
}
Run ConcatRewriter recursively and OneMoreReluRewriter once:
   x
   |
softmax
   |
 relu
out = rewrite(
    [OneMoreReluRewriter(True), ConcatRewriter(False)],
    before(),
)
print(tvm.IRModule.from_expr(out))
def @main(%x) {
%0 = nn.softmax(%x);
nn.relu(%0)
}
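None of the callbacks above use node_map, which maps each sub-pattern to the expressions it matched. A minimal sketch (FuseDoubleRelu and the variable w are illustrative) that uses a named wildcard() sub-pattern to collapse relu(relu(x)) into relu(x):

class FuseDoubleRelu(DFPatternCallback):
    def __init__(self):
        super().__init__(rewrite_once=False)
        self.inner = wildcard()  # named sub-pattern for the innermost input
        self.pattern = is_op("nn.relu")(is_op("nn.relu")(self.inner))

    def callback(self, pre, post, node_map):
        # node_map[pattern] lists every expression that sub-pattern matched
        return relay.nn.relu(node_map[self.inner][0])

double_relu = relay.nn.relu(relay.nn.relu(relay.var("w")))
print(tvm.IRModule.from_expr(rewrite(FuseDoubleRelu(), double_relu)))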