模式化简

模式化简#

import sys
from pathlib import Path
ROOT = Path(".").resolve().parents[3]
sys.path.extend([f"{ROOT}/tests", f"{ROOT}/src"])
# # from tools.tag_span import _create_span, _set_span, _verify_structural_equal_with_span
from tools.torch_utils import verify_model
import numpy as np

import tvm
from tvm import relay
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.dataflow_pattern import *
from tvm.relay.testing import run_opt_pass

# NB: 1 corresponds to the C++ enum that specifies this.
# We lose the type safety due to the Python/C++ calling
# convention.
# NOTE(review): neither constant is referenced in the portion of the file
# visible here -- confirm they are still needed.
K_ELEMWISE = 0
K_BROADCAST = 1
def algebraic_simplify(expr):
    """Strip algebraic no-ops from ``expr`` using dataflow-pattern rewriting.

    Six rewrite rules are registered, in this order:

    * ``x + 0``, ``x - 0``, ``x * 1`` and ``x / 1`` are replaced by the
      expression matched as ``x``;
    * ``0 * x`` and ``0 / x`` are replaced by the matched zero constant.

    The zero and one patterns accept both ``relay.const(0)`` / ``relay.const(1)``
    and their floating-point counterparts ``relay.const(0.0)`` /
    ``relay.const(1.0)``.

    NOTE(review): the test in this file also expects ``0 + x``, ``1 * x`` and
    ``x * 0`` to be simplified, which presumably relies on the pattern matcher
    treating ``+`` and ``*`` as commutative -- confirm against the matcher docs.

    Parameters
    ----------
    expr
        The expression passed to ``rewrite``.

    Returns
    -------
    The result of ``rewrite`` applied to the six callbacks and ``expr``.
    """
    zero_pat = is_expr(relay.const(0)) | is_expr(relay.const(0.0))
    one_pat = is_expr(relay.const(1)) | is_expr(relay.const(1.0))

    class KeepOperandCallback(DFPatternCallback):
        """Replace a matched binary op with the node bound to ``self.x``."""

        def __init__(self, kept, build_pattern):
            super().__init__()
            # ``kept`` is the sub-pattern whose match survives the rewrite;
            # ``build_pattern`` wraps it into the full pattern to look for.
            self.x = kept
            self.pattern = build_pattern(kept)

        def callback(self, pre, post, node_map):
            # node_map maps each sub-pattern to the nodes it matched;
            # the first node matched by ``self.x`` becomes the replacement.
            return node_map[self.x][0]

    # Registration order mirrors: add, sub, mul, div, mul-by-zero, zero-div.
    rules = [
        KeepOperandCallback(wildcard(), lambda operand: operand + zero_pat),
        KeepOperandCallback(wildcard(), lambda operand: operand - zero_pat),
        KeepOperandCallback(wildcard(), lambda operand: operand * one_pat),
        KeepOperandCallback(wildcard(), lambda operand: operand / one_pat),
        KeepOperandCallback(zero_pat, lambda nought: nought * wildcard()),
        KeepOperandCallback(zero_pat, lambda nought: nought / wildcard()),
    ]
    return rewrite(rules, expr)


def test_algebraic_simplify():
    """Exercise ``algebraic_simplify`` on single-rule inputs and one compound one.

    Each single-rule case pairs an input expression with the exact node the
    simplifier is expected to hand back (compared with ``==``). The final
    compound expression is compared with ``tvm.ir.structural_equal``.
    """
    x, y = relay.Var("x"), relay.Var("y")
    int_one, int_zero = relay.const(1), relay.const(0)
    float_one, float_zero = relay.const(1.0), relay.const(0.0)

    # (input expression, expected result), kept in the order the rules are
    # grouped: addition, subtraction, multiplication, division.
    single_rule_cases = [
        (x + int_zero, x),
        (x + float_zero, x),
        (int_zero + x, x),
        (float_zero + x, x),
        (x - int_zero, x),
        (x - float_zero, x),
        (x * int_one, x),
        (x * float_one, x),
        (int_one * x, x),
        (float_one * x, x),
        (x * int_zero, int_zero),
        (x * float_zero, float_zero),
        (x / int_one, x),
        (x / float_one, x),
        (int_zero / x, int_zero),
        (float_zero / x, float_zero),
    ]
    for before, expected in single_rule_cases:
        assert algebraic_simplify(before) == expected

    # Several rules firing inside one expression should leave just ``x + y``.
    compound = (x + int_zero * y) / int_one + (y * int_one) - int_zero / x
    assert tvm.ir.structural_equal(algebraic_simplify(compound), x + y)