# 重写模式简介 (Introduction to rewrite patterns)
import testing
import numpy as np
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
# 替换加法为减法 (Replace addition with subtraction)
# Two free Relay variables, plus the patterns used to drive and check the rewrite.
x = relay.var("x")
y = relay.var("y")
# Matches any `add` call, whatever its two operands are.
add_pattern = is_op("add")(wildcard(), wildcard())
# Matches any `subtract` call; used to verify the rewritten expression.
sub_pattern = is_op("subtract")(wildcard(), wildcard())


class TestRewrite(DFPatternCallback):
    """Rewrite every expression matched by ``add_pattern`` into a subtraction."""

    def __init__(self):
        super().__init__()
        self.pattern = add_pattern

    def callback(self, pre, post, node_map):
        """Return the replacement for a matched node.

        Only ``post`` is used: its first and second arguments become the
        left and right operands of the subtraction.
        """
        return post.args[0] - post.args[1]


# Rewrite `x + y`; the result is expected to be a `subtract` call.
out = rewrite(TestRewrite(), x + y)
assert sub_pattern.match(out)
# 重写函数 (Rewriting expressions that contain function calls)
# Operands of the outer expression `func(x, w) + y` built below.
x = relay.var("x")
w = relay.var("w")
y = relay.var("y")
add_pattern = is_op("add")(wildcard(), wildcard())
sub_pattern = is_op("subtract")(wildcard(), wildcard())


class TestRewrite(DFPatternCallback):
    """Rewrite every expression matched by ``add_pattern`` into a subtraction."""

    def __init__(self):
        super().__init__()
        self.pattern = add_pattern

    def callback(self, pre, post, node_map):
        # Reuse the matched node's two arguments as the subtraction operands.
        return post.args[0] - post.args[1]


# A Relay function computing relu(conv2d(input, weight)); its body has no `add`.
inpf = relay.var("input")
weightf = relay.var("weight")
func = relay.Function(
    [inpf, weightf], relay.op.nn.relu(relay.op.nn.conv2d(inpf, weightf)), attrs=None
)
# The only `add` is the outer one, whose first operand is a call to `func`;
# the rewrite must still turn it into a top-level `subtract`.
out = rewrite(TestRewrite(), func(x, w) + y)
assert sub_pattern.match(out)
# A function tagged with the attribute Composite="add"; its body is x + y.
x = relay.var("x")
y = relay.var("y")
f = relay.Function([x, y], x + y).with_attr("Composite", "add")
a = relay.var("a")
b = relay.var("b")
# abs(f(a, b)): a call to the composite function, wrapped in `abs`.
c = relay.Call(f, [a, b])
c_abs = relay.abs(c)


class TestRewrite(DFPatternCallback):
    """Inline calls to a function carrying the attribute ``Composite="add"``.

    The pattern matches a two-argument call whose callee has that attribute;
    the callback replaces the whole call with ``arg0 + arg1``.
    """

    def __init__(self):
        super().__init__()
        self.pattern = wildcard().has_attr({"Composite": "add"})(wildcard(), wildcard())

    def callback(self, pre, post, node_map):
        return post.args[0] + post.args[1]


out = rewrite(TestRewrite(), c_abs)
# After inlining, the argument of `abs` should be a plain `add` op call
# rather than a call to `f`.
inlined_add_pattern = is_op("abs")(is_op("add")(wildcard(), wildcard()))
assert inlined_add_pattern.match(out)
# 重写嵌套 (Nested rewrites)
class PatternCallback(DFPatternCallback):
    """Callback for an arbitrary ``pattern`` that returns ``post`` unchanged."""

    def __init__(self, pattern):
        super().__init__()
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        return post


def gen():
    """Build ``(x + (x + (y + y))) + (x + (y + y))``.

    The sub-expressions ``y + y`` and ``x + (y + y)`` are shared objects,
    referenced from more than one place in the graph.
    """
    x = relay.var("x")
    y = relay.var("y")
    y_add = relay.add(y, y)
    n0 = relay.add(x, y_add)
    n1 = relay.add(x, n0)
    return relay.add(n1, n0)


def pattern():
    """Return the nested pattern ``add(n0, add(n0, a))`` where ``n0 = add(a, b)``."""
    a = wildcard()
    b = wildcard()
    n0 = is_op("add")(a, b)
    n1 = is_op("add")(n0, a)
    return is_op("add")(n0, n1)


out = gen()
pat = pattern()
# The callback returns `post` as-is, so the rewrite must leave the expression
# structurally equal to the input.
new_out = rewrite(PatternCallback(pat), out)
assert tvm.ir.structural_equal(out, new_out)