重写模式简介

重写模式简介

参考:DFPatternCallback

import testing
import numpy as np

import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *

替换加法为减法

# Two Relay variables used as operands of the expression to be rewritten.
x = relay.var("x")
y = relay.var("y")
# Patterns for a call to "add" / "subtract" whose two arguments are wildcards.
add_pattern = is_op("add")(wildcard(), wildcard())
sub_pattern = is_op("subtract")(wildcard(), wildcard())

class TestRewrite(DFPatternCallback):
    """Rewrite callback that replaces a matched ``add`` call with a subtraction.

    The pattern it matches is the module-level ``add_pattern``.
    """

    def __init__(self):
        # Zero-argument super() is bound to this class object. The explicit
        # two-argument form resolves the global name ``TestRewrite`` at call
        # time, which fails once that name is rebound to a different class.
        super().__init__()
        self.pattern = add_pattern

    def callback(self, pre, post, node_map):
        """Return ``post.args[0] - post.args[1]`` as the replacement expression.

        ``pre`` and ``node_map`` are part of the callback signature but are
        not used here.
        """
        lhs = post.args[0]
        rhs = post.args[1]
        return lhs - rhs

# Apply the callback to `x + y`, then check the result matches the subtract pattern.
out = rewrite(TestRewrite(), x + y)
assert sub_pattern.match(out)

重写函数

# Relay variables for this example: `x` and `w` are passed to a function call,
# `y` is added to that call's result further below.
x = relay.var("x")
w = relay.var("w")
y = relay.var("y")
# Patterns for a call to "add" / "subtract" whose two arguments are wildcards.
add_pattern = is_op("add")(wildcard(), wildcard())
sub_pattern = is_op("subtract")(wildcard(), wildcard())

class TestRewrite(DFPatternCallback):
    """Rewrite callback that replaces a matched ``add`` call with a subtraction.

    The pattern it matches is the module-level ``add_pattern``.
    """

    def __init__(self):
        # Zero-argument super() is bound to this class object. The explicit
        # two-argument form resolves the global name ``TestRewrite`` at call
        # time, which fails once that name is rebound to a different class.
        super().__init__()
        self.pattern = add_pattern

    def callback(self, pre, post, node_map):
        """Return ``post.args[0] - post.args[1]`` as the replacement expression.

        ``pre`` and ``node_map`` are part of the callback signature but are
        not used here.
        """
        lhs = post.args[0]
        rhs = post.args[1]
        return lhs - rhs

# Build a two-parameter Relay function whose body is relu(conv2d(input, weight)).
inpf = relay.var("input")
weightf = relay.var("weight")
func = relay.Function(
    [inpf, weightf], relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), attrs=None
)
# Rewrite `func(x, w) + y`: the left operand of the add is a call to `func`.
# The rewritten result is expected to match the subtract pattern.
out = rewrite(TestRewrite(), func(x, w) + y)
assert sub_pattern.match(out)
# --- Next example: replace a call to a "Composite" function with a plain add ---
# Function computing x + y, tagged with the attribute Composite="add".
x = relay.var("x")
y = relay.var("y")
f = relay.Function([x, y], x + y).with_attr("Composite", "add")

# Call the tagged function on two new variables and wrap that call in abs.
a = relay.var("a")
b = relay.var("b")
c = relay.Call(f, [a, b])
c_abs = relay.abs(c)

class TestRewrite(DFPatternCallback):
    """Rewrite callback that replaces a call to a ``Composite="add"`` callee.

    The pattern matches a two-argument call whose callee has the attribute
    ``Composite`` set to ``"add"``; the call is replaced by the sum of its
    two arguments.
    """

    def __init__(self):
        # Zero-argument super() is bound to this class object. The explicit
        # two-argument form resolves the global name ``TestRewrite`` at call
        # time, which fails once that name is rebound to a different class.
        super().__init__()
        callee = wildcard().has_attr({"Composite": "add"})
        self.pattern = callee(wildcard(), wildcard())

    def callback(self, pre, post, node_map):
        """Return ``post.args[0] + post.args[1]`` as the replacement expression.

        ``pre`` and ``node_map`` are part of the callback signature but are
        not used here.
        """
        lhs = post.args[0]
        rhs = post.args[1]
        return lhs + rhs

# Rewrite abs(f(a, b)); the result is expected to be abs applied to an "add" call.
out = rewrite(TestRewrite(), c_abs)
inlined_add_pattern = is_op("abs")(is_op("add")(wildcard(), wildcard()))
assert inlined_add_pattern.match(out)

重写嵌套

class PatternCallback(DFPatternCallback):
    """Rewrite callback for a caller-supplied pattern that changes nothing.

    Whatever the pattern matches, ``callback`` hands back ``post`` as is.
    """

    def __init__(self, pattern):
        super().__init__()
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        """Return ``post`` unmodified; ``pre`` and ``node_map`` are unused."""
        return post

def gen():
    """Build the nested-add Relay expression ``(x + (x + (y + y))) + (x + (y + y))``.

    The sub-expression ``x + (y + y)`` is a single node used twice.
    """
    var_x = relay.var("x")
    var_y = relay.var("y")
    doubled_y = relay.add(var_y, var_y)
    # `shared` feeds both the middle add and the top-level add.
    shared = relay.add(var_x, doubled_y)
    middle = relay.add(var_x, shared)
    return relay.add(middle, shared)

def pattern():
    """Build a nested ``add`` pattern over two wildcards.

    With wildcards ``p`` and ``q`` the pattern is
    ``add(add(p, q), add(add(p, q), p))``, where the inner ``add(p, q)``
    is one pattern object used in both places.
    """
    first = wildcard()
    second = wildcard()
    inner = is_op("add")(first, second)
    outer = is_op("add")(inner, first)
    return is_op("add")(inner, outer)

# Build the expression and the pattern, then run the callback over the expression.
out = gen()
pat = pattern()
new_out = rewrite(PatternCallback(pat), out)

# The callback returns `post` unchanged, so the result must be structurally
# equal to the input expression.
assert tvm.ir.structural_equal(out, new_out)