重写具有模糊函数体的函数
from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
借助支配者（dominator）分析，可以重写函数体为模糊模式的函数。
# Free variables: x and w are the call arguments, b is the operand added afterwards.
x = relay.var("x")
w = relay.var("w")
b = relay.var("b")
# Separate variables that serve as the parameters of the inner function.
x1 = relay.var("x1")
w1 = relay.var("w1")
# Function whose body is a single conv2d over its two parameters.
func = relay.Function([x1, w1], relay.nn.conv2d(x1, w1))
# Call the function on (x, w), then add b twice.
expr = func(x, w) + b + b
# Print a conv2d expression built the same way as the function body.
print(relay.nn.conv2d(x1, w1))
free_var %x1;
free_var %w1;
nn.conv2d(%x1, %w1, padding=[0, 0, 0, 0])
# One wildcard per position in the pattern: two call arguments, the added
# operand, and the two function parameters.
wc_x = wildcard()
wc_w = wildcard()
wc_b = wildcard()
wc_x1 = wildcard()
wc_w1 = wildcard()
# Function pattern with two parameters; the body is itself a wildcard,
# i.e. the "fuzzy body" this example is about.
func_pattern = FunctionPattern([wc_x1, wc_w1], wildcard())
# A call of such a function on two arguments, followed by an addition.
pattern = func_pattern(wc_x, wc_w) + wc_b
class TestRewrite(DFPatternCallback):
    """Rewrite callback that replaces a match of the module-level ``pattern``.

    The replacement is always ``x + w``, built from the module-level
    variables ``x`` and ``w``; the matched sub-expression is not inspected.
    """

    def __init__(self):
        """Initialise the base callback and register the pattern to match."""
        super(TestRewrite, self).__init__()
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        """Return the expression that replaces a matched sub-expression.

        ``pre``, ``post`` and ``node_map`` are supplied by the rewriter and
        are deliberately unused here.
        """
        replacement = x + w
        return replacement
# Apply the callback to the expression built above.
out = rewrite(TestRewrite(), expr)
# Expected result: the inner `func(x, w) + b` has become `x + w`, and the
# outer `+ b` is kept.
assert tvm.ir.structural_equal(out, x + w + b)
# Print the original expression wrapped in an IRModule for comparison.
print(tvm.IRModule.from_expr(expr))
def @main(%x, %w, %b) {
%0 = fn (%x1, %w1) {
nn.conv2d(%x1, %w1, padding=[0, 0, 0, 0])
};
%1 = %0(%x, %w);
%2 = add(%1, %b);
add(%2, %b)
}
# Print the rewritten expression wrapped in an IRModule.
print(tvm.IRModule.from_expr(out))
def @main(%x, %w, %b) {
%0 = add(%x, %w);
add(%0, %b)
}