使用递归模式重写

使用递归模式重写

from testing import viz_expr # 可视化 relay
import numpy as np
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import *

构建计算图：`nn.dense` 之后串联 6 个 `cast`（float32 → float32，数值上不起作用，仅用于构造一条可被递归模式匹配的算子链）。

# Graph under test: a dense layer followed by a chain of six float32 -> float32
# casts. The casts are numerically no-ops; they only provide a chain of
# identical nodes for the recursive pattern to match.
data = relay.var("data", relay.TensorType((2, 8), "float32"))
dense_weight = relay.const(np.zeros((4, 8)))
feat = relay.nn.dense(data, dense_weight)
for _cast_step in range(5):
    feat = relay.cast(feat, "float32")
# The sixth cast is the graph output.
oup = relay.cast(feat, "float32")
viz_expr(oup)
../../../../_images/a35ce19e8670cf943004b0d6973eaeb334d8ed1409cb1f94e97d0a511813e0e5.svg
# Wrap the expression in an IRModule (it becomes @main) and print its text form.
print(str(tvm.IRModule.from_expr(oup)))
def @main(%data: Tensor[(2, 8), float32]) {
  %0 = nn.dense(%data, meta[relay.Constant][0], units=None);
  %1 = cast(%0, dtype="float32");
  %2 = cast(%1, dtype="float32");
  %3 = cast(%2, dtype="float32");
  %4 = cast(%3, dtype="float32");
  %5 = cast(%4, dtype="float32");
  cast(%5, dtype="float32")
}

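构建重写回调与递归模式：`TheRewrite` 用 `nn.relu` 包裹匹配到的表达式，`rewrite_once=True` 表示只执行一轮重写，不反复迭代。下面三个测试都先创建一个 `wildcard()`，再用 `redirect_to(...)` 让它指回包含它自身的模式，从而得到能匹配任意长度 `cast` 链（末端为 `nn.dense`）的递归模式。三者的递归分别经过调用参数、`|` 的左分支和 `|` 的右分支。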

class TheRewrite(DFPatternCallback):
    """Wrap every expression matched by ``pattern`` in a single ``nn.relu``.

    ``rewrite_once=True`` makes ``rewrite`` apply this callback in a single
    pass instead of repeating until the graph stops changing.
    """

    def __init__(self, pattern):
        super(TheRewrite, self).__init__(rewrite_once=True)
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        # ``post`` is the matched expression after rewriting; ``pre`` and
        # ``node_map`` are not needed for this transformation.
        wrapped = relay.nn.relu(post)
        return wrapped


# Expected result of each test below: the original graph with one relu on top.
expected = relay.nn.relu(oup)
def test_reset_call_args():
    """Recursion through a call argument.

    Pattern: ``P = cast(w)`` with ``w`` redirected to ``P | nn.dense(*, *)``,
    i.e. a cast whose input is either another such cast or a dense.
    """
    loop_back = wildcard()
    cast_pat = is_op("cast")(loop_back)
    # Close the cycle: the cast's argument may be another cast or the dense.
    loop_back.redirect_to(cast_pat | is_op("nn.dense")(wildcard(), wildcard()))

    result = rewrite(TheRewrite(cast_pat), oup)
    tvm.ir.assert_structural_equal(result, expected)

def test_reset_alt_left():
    """Recursion through the left operand of an alternation.

    Pattern: ``P = cast(w | nn.dense(*, *))`` with ``w`` redirected to ``P``.
    """
    loop_back = wildcard()
    cast_pat = is_op("cast")(loop_back | is_op("nn.dense")(wildcard(), wildcard()))
    # The left branch of the alternation now refers back to the cast pattern.
    loop_back.redirect_to(cast_pat)

    result = rewrite(TheRewrite(cast_pat), oup)
    tvm.ir.assert_structural_equal(result, expected)

def test_reset_alt_right():
    """Recursion through the right operand of an alternation.

    Pattern: ``P = cast(nn.dense(*, *) | w)`` with ``w`` redirected to ``P``.
    """
    loop_back = wildcard()
    cast_pat = is_op("cast")(is_op("nn.dense")(wildcard(), wildcard()) | loop_back)
    # The right branch of the alternation now refers back to the cast pattern.
    loop_back.redirect_to(cast_pat)

    result = rewrite(TheRewrite(cast_pat), oup)
    tvm.ir.assert_structural_equal(result, expected)

# Run the three recursive-pattern checks in their original order.
for _check in (test_reset_call_args, test_reset_alt_left, test_reset_alt_right):
    _check()