模糊主体的函数分区

模糊主体的函数分区

from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *

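借助支配者（dominator）分析，可以对具有模糊主体的函数进行重写与分区。所谓“模糊主体”，是指 `FunctionPattern` 的函数体用 `wildcard()` 表示，因此不对函数内部的计算作任何约束。下例中，模式只匹配函数调用及其后的第一个 `add`，所以分区结果只包含这两部分，第二个 `add` 保留在分区之外。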

# Call-site variables: the two call arguments and the value added afterwards.
x, w, b = (relay.var(name) for name in ("x", "w", "b"))
# Parameters of the inner function that wraps the convolution.
x1, w1 = (relay.var(name) for name in ("x1", "w1"))
conv_body = relay.nn.conv2d(x1, w1)
func = relay.Function([x1, w1], conv_body)
# Call the function, then add ``b`` twice. The pattern below describes only
# the call followed by one ``add``, so the second ``add`` stays outside the
# partitioned region (see the printed result further down).
call = func(x, w)
expr = (call + b) + b

# Placeholders for the call arguments (x, w, b) and the function params (x1, w1).
wc_x, wc_w, wc_b, wc_x1, wc_w1 = (wildcard() for _ in range(5))
# The function body is itself a wildcard, so the pattern does not constrain
# what the matched function computes ("fuzzy body").
func_pattern = FunctionPattern([wc_x1, wc_w1], wildcard())
pattern = func_pattern(wc_x, wc_w) + wc_b
new_expr = pattern.partition(expr)
print(tvm.IRModule.from_expr(expr))
def @main(%x, %w, %b) {
  %0 = fn (%x1, %w1) {
    nn.conv2d(%x1, %w1, padding=[0, 0, 0, 0])
  };
  %1 = %0(%x, %w);
  %2 = add(%1, %b);
  add(%2, %b)
}
# Print the module after partitioning. The function call and the first add
# are wrapped in a new function tagged with ``PartitionedFromPattern``.
partitioned_mod = tvm.IRModule.from_expr(new_expr)
print(partitioned_mod)
def @main(%x, %w, %b) {
  %2 = fn (%FunctionVar_0_0, %FunctionVar_0_1, %FunctionVar_0_2, PartitionedFromPattern="nn.conv2d_FunctionCall_add_") {
    %0 = fn (%x1, %w1) {
      nn.conv2d(%x1, %w1, padding=[0, 0, 0, 0])
    };
    %1 = %0(%FunctionVar_0_0, %FunctionVar_0_1);
    add(%1, %FunctionVar_0_2)
  };
  %3 = %2(%x, %w, %b);
  add(%3, %b)
}