# 重写 double (rewrite "double")
from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
def test_match_match():
    """Check that pattern rewriting works on Relay IR containing ``relay.Match``.

    A callback that replaces every matched ``add`` call with a subtraction of
    the same two operands is applied to the Prelude function
    ``tensor_concatenate_int64``. The test asserts that the rewritten function
    is structurally equal to the original.
    """
    # Matches a call to the "add" operator with any two operands.
    add_pattern = is_op("add")(wildcard(), wildcard())

    class TestRewrite(DFPatternCallback):
        """Rewrite each matched ``a + b`` into ``a - b``."""

        def __init__(self):
            super().__init__()
            self.pattern = add_pattern

        def callback(self, pre, post, node_map):
            # Replace the matched addition with a subtraction of the same
            # operands; ``pre`` and ``node_map`` are unused here.
            return post.args[0] - post.args[1]

    mod = tvm.IRModule({})
    # Prelude(mod) is called for its side effect on ``mod``: the lookup below
    # relies on it having registered ``tensor_concatenate_int64``.
    tvm.relay.prelude.Prelude(mod)
    func = mod["tensor_concatenate_int64"]

    # Apply rewrite on IR including relay.Match.
    out = rewrite(TestRewrite(), func)
    # NOTE(review): equality presumably holds because this function contains
    # no ``add`` calls, so the callback never fires — confirm.
    assert tvm.ir.structural_equal(func, out)