重写模糊主体的函数

重写模糊主体的函数

from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *

借助支配者(dominator)分析,可以重写函数体为模糊模式(例如用通配符匹配任意函数体)的函数。

# Free variables of the outer expression.
x, w, b = (relay.var(name) for name in ("x", "w", "b"))
# Parameters of the inner function.
x1, w1 = (relay.var(name) for name in ("x1", "w1"))
# Inner function: a conv2d applied to its own two parameters.
func = relay.Function(
    [x1, w1],
    relay.nn.conv2d(x1, w1),
)
# Call the function on (x, w), then add b twice.
expr = func(x, w) + b + b
# Show the text form of the function body on its own.
print(relay.nn.conv2d(x1, w1))
free_var %x1;
free_var %w1;
nn.conv2d(%x1, %w1, padding=[0, 0, 0, 0])
# Five independent wildcard patterns: call arguments (wc_x, wc_w), the added
# operand (wc_b), and the function parameters (wc_x1, wc_w1).
wc_x, wc_w, wc_b, wc_x1, wc_w1 = (wildcard() for _ in range(5))

# Function pattern whose body is itself a wildcard (the "fuzzy" body).
func_pattern = FunctionPattern(
    [wc_x1, wc_w1],
    wildcard(),
)
# A call of that function pattern on (wc_x, wc_w), followed by `+ wc_b`.
pattern = func_pattern(wc_x, wc_w) + wc_b
class TestRewrite(DFPatternCallback):
    """Pattern callback that replaces a match with ``x + w``.

    The pattern to match is the module-level ``pattern``; the replacement is
    built from the module-level variables ``x`` and ``w``.
    """

    def __init__(self):
        """Initialise the base callback and register the module-level pattern."""
        super().__init__()
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        """Return the replacement expression.

        ``pre``, ``post`` and ``node_map`` are accepted but not inspected:
        the replacement is always ``x + w``.
        """
        replacement = x + w
        return replacement

# Apply the callback to the expression built above.
callback_instance = TestRewrite()
out = rewrite(callback_instance, expr)
# The function call plus the first `+ b` is expected to collapse to `x + w`,
# leaving the second `+ b` in place.
expected_expr = x + w + b
assert tvm.ir.structural_equal(out, expected_expr)
# Print the original (pre-rewrite) expression as a module.
original_mod = tvm.IRModule.from_expr(expr)
print(original_mod)
def @main(%x, %w, %b) {
  %0 = fn (%x1, %w1) {
    nn.conv2d(%x1, %w1, padding=[0, 0, 0, 0])
  };
  %1 = %0(%x, %w);
  %2 = add(%1, %b);
  add(%2, %b)
}
# Print the rewritten expression as a module for comparison.
rewritten_mod = tvm.IRModule.from_expr(out)
print(rewritten_mod)
def @main(%x, %w, %b) {
  %0 = add(%x, %w);
  add(%0, %b)
}