重写 double

重写 double

from testing import viz_expr # 可视化 relay
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *
def test_match_match():
    """Check that ``rewrite`` leaves a function containing ``relay.Match`` intact.

    A callback that would turn every ``add`` call into ``a - b`` is run over
    the Prelude's ``tensor_concatenate_int64`` function, and the test asserts
    the rewritten result is structurally equal to the function it started from.
    """
    # Matches a call to `add` with any two operands.
    add_of_anything = is_op("add")(wildcard(), wildcard())

    class TestRewrite(DFPatternCallback):
        """Pattern callback that replaces a matched `a + b` with `a - b`."""

        def __init__(self):
            super().__init__()
            self.pattern = add_of_anything

        def callback(self, pre, post, node_map):
            lhs = post.args[0]
            rhs = post.args[1]
            return lhs - rhs

    module = tvm.IRModule({})
    # Constructing the Prelude fills the (initially empty) module with its
    # definitions; the Prelude object itself is not needed afterwards.
    tvm.relay.prelude.Prelude(module)
    concat_fn = module["tensor_concatenate_int64"]
    # Run the rewrite over IR that includes relay.Match nodes.
    rewritten = rewrite(TestRewrite(), concat_fn)
    assert tvm.ir.structural_equal(concat_fn, rewritten)