def algebraic_simplify(expr):
    """Strip identity and annihilator arithmetic from a Relay expression.

    The rewrite rules, applied through the dataflow-pattern ``rewrite`` API,
    are:

    * ``x + 0`` and ``x - 0`` -> ``x``
    * ``x * 1`` and ``x / 1`` -> ``x``
    * ``0 * x`` and ``0 / x`` -> the matched zero constant

    "0" and "1" each match either the integer constant (``relay.const(0)`` /
    ``relay.const(1)``) or the float constant (``0.0`` / ``1.0``).

    NOTE(review): ``0 * x`` is replaced by the scalar zero constant itself, so
    for a non-scalar ``x`` the rewritten expression would no longer carry
    ``x``'s shape. Confirm this is acceptable outside the scalar-variable tests.

    Parameters
    ----------
    expr : relay.Expr
        Expression to simplify.

    Returns
    -------
    relay.Expr
        The expression returned by ``rewrite`` after the rules are applied.
    """
    zero = is_expr(relay.const(0)) | is_expr(relay.const(0.0))
    one = is_expr(relay.const(1)) | is_expr(relay.const(1.0))

    class _ForwardOperand(DFPatternCallback):
        """Replace a pattern match with the node bound to ``self.x``."""

        def __init__(self, operand, make_pattern):
            super().__init__()
            # ``operand`` is the sub-pattern whose match survives the rewrite.
            self.x = operand
            self.pattern = make_pattern(operand)

        def callback(self, pre, post, node_map):
            return node_map[self.x][0]

    callbacks = [
        _ForwardOperand(wildcard(), lambda kept: kept + zero),  # x + 0 -> x
        _ForwardOperand(wildcard(), lambda kept: kept - zero),  # x - 0 -> x
        _ForwardOperand(wildcard(), lambda kept: kept * one),  # x * 1 -> x
        _ForwardOperand(wildcard(), lambda kept: kept / one),  # x / 1 -> x
        _ForwardOperand(zero, lambda kept: kept * wildcard()),  # 0 * x -> 0
        _ForwardOperand(zero, lambda kept: kept / wildcard()),  # 0 / x -> 0
    ]
    return rewrite(callbacks, expr)
def test_algebraic_simplify():
    """Check each single-step simplification rule, then a nested expression."""
    x = relay.Var("x")
    y = relay.Var("y")
    int_one, int_zero = relay.const(1), relay.const(0)
    flt_one, flt_zero = relay.const(1.0), relay.const(0.0)

    # Each pair is (input, expected output). ``==`` on Relay nodes is
    # presumably reference identity, so a rule must hand back the very
    # operand object it matched.
    expectations = [
        # additive identity, both operand orders
        (x + int_zero, x),
        (x + flt_zero, x),
        (int_zero + x, x),
        (flt_zero + x, x),
        # subtracting zero (right operand only)
        (x - int_zero, x),
        (x - flt_zero, x),
        # multiplicative identity, both operand orders
        (x * int_one, x),
        (x * flt_one, x),
        (int_one * x, x),
        (flt_one * x, x),
        # multiplying by zero collapses to the zero constant
        (x * int_zero, int_zero),
        (x * flt_zero, flt_zero),
        # dividing by one
        (x / int_one, x),
        (x / flt_one, x),
        # zero divided by anything collapses to the zero constant
        (int_zero / x, int_zero),
        (flt_zero / x, flt_zero),
    ]
    for before, expected in expectations:
        assert algebraic_simplify(before) == expected

    # The rules must also compose inside a nested expression.
    nested = (x + int_zero * y) / int_one + (y * int_one) - int_zero / x
    assert tvm.ir.structural_equal(algebraic_simplify(nested), x + y)