canonical_simplify()
#
参考:tvm/tests/python/arith/test_arith_canonical_simplify.py
%cd ..
import testing
import tvm
import tvm.testing
from tvm import te
class CanonicalChecker:
    """Small test helper around ``tvm.arith.Analyzer.canonical_simplify``.

    Each instance owns its own analyzer, exposed as ``analyzer`` so callers
    can register variable bounds on it before calling :meth:`verify`.
    """

    def __init__(self):
        self.analyzer = tvm.arith.Analyzer()

    def verify(self, data, expected):
        """Assert that ``data`` canonically simplifies to ``expected``.

        ``expected`` is passed through ``tvm.runtime.convert`` first, so plain
        Python values (e.g. ``0`` or ``False``) may be given. On mismatch an
        ``AssertionError`` is raised whose message shows the input, the
        simplified result and the converted expectation.
        """
        simplified = self.analyzer.canonical_simplify(data)
        reference = tvm.runtime.convert(expected)
        assert tvm.ir.structural_equal(
            simplified, reference
        ), "\ndata={}\nres={}\nexpected={}".format(data, simplified, reference)
canonical_simplify()
mul+sum#
# Sums of scaled variables, alone and under division/modulo by a constant.
ck = CanonicalChecker()
x, y, z = te.var("x"), te.var("y"), te.var("z")
tdiv, tmod = tvm.tir.truncdiv, tvm.tir.truncmod
fld, flm = tvm.te.floordiv, tvm.te.floormod

# (input expression, expected canonical form); checked in the listed order.
mul_sum_cases = [
    # plain mul + sum
    (2 + (3 * x + z + y + 1) * 4 + x, x * 13 + z * 4 + y * 4 + 6),
    (x * 3 - 4 * x + 1, 1 - x),
    (y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2),
    # truncdiv / truncmod
    (tdiv(x + y + x + y * 3, 2), y * 2 + x),
    (tmod(x + y + x + y * 3, 2), 0),
    # floordiv / floormod
    (flm(x + x + y * 3, 2), flm(y * 3, 2)),
    (fld(x + y + x + y * 3, 2), y * 2 + x),
    (flm(x + y + x + y * 3, 2), 0),
    (fld(x + x + y * 3, 2), fld(y * 3, 2) + x),
]
for case_input, case_expected in mul_sum_cases:
    ck.verify(case_input, case_expected)
canonical_simplify()
split_index#
# Split-index patterns: expressions built from div/mod pieces of the same
# operand are checked against their recombined form, first with truncated
# and then with floored division.
ck = CanonicalChecker()
x, y, z = te.var("x"), te.var("y"), te.var("z")
# truncdiv
tdiv = tvm.tir.truncdiv
tmod = tvm.tir.truncmod
# split div const
ck.verify(tdiv(x, 3) * 3 + tmod(x, 3), x)
ck.verify(tdiv(x, 6) * 6 + tmod(tdiv(x, 3), 2) * 3 + tmod(x, 3), x)
ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 4), tdiv(tmod(x, 16), 4))
ck.verify(tdiv(tmod(x, 2), 8), 0)
ck.verify(tdiv(tmod(x, 2), 7), 0)
ck.verify(tdiv(tdiv(tmod(x, 16), 2) * 2, 6), tdiv(tmod(x, 16), 6))
# split mod const
ck.verify(tmod((x * 8), 16), tmod(x, 2) * 8)
ck.verify(tmod(x * 8, 2), 0)
# simplify then fold
# From here on x and y carry the constant bound (0, 1000) in this analyzer.
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1000))
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1000))
ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
# complex fold
# z has no bound registered in this cell; only y is bounded here.
ck.verify(tdiv(z * 9 + y, 2) * 2 + tmod(z * 9 + y, 2), z * 9 + y)
# Re-register x and y with a bound that includes negative values.
# NOTE(review): the trailing True is presumably the "override" flag of
# Analyzer.update, needed to replace the earlier bound -- confirm.
ck.analyzer.update(x, tvm.arith.ConstIntBound(-100, 1000), True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-100, 1000), True)
ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y)
# floordiv
fld = tvm.te.floordiv
flm = tvm.te.floormod
# Expected to come back unchanged.
ck.verify(fld(x * 5, 2), fld(x * 5, 2))
ck.verify(fld(x, 3) * 3 + flm(x, 3), x)
ck.verify(fld(x, 6) * 6 + flm(fld(x, 3), 2) * 3 + flm(x, 3), x)
ck.verify(fld(fld(flm(x, 16), 2) * 2, 4), fld(flm(x, 16), 4))
ck.verify(fld(flm(x, 2), 8), 0)
ck.verify(fld(flm(x, 2), 7), 0)
ck.verify(fld(fld(flm(x, 16), 2) * 2, 6), fld(flm(x, 16), 6))
# Mixed truncated/floored expressions cannot be simplified unless we
# canonicalize into one mode, so the input is expected back unchanged.
ck.verify(tdiv(x, 6) * 2 + tmod(fld(x, 3), 2), tdiv(x, 6) * 2 + tmod(fld(x, 3), 2))
# Truncated modulo of a negated operand.
ck.verify(tmod(-x, 2), tmod(x, -2) * -1)
canonical_simplify()
div#
# Division of an affine expression by a constant, showing how the bound
# registered for x changes what can be simplified.
ck = CanonicalChecker()
x = te.var("x")
tdiv = tvm.tir.truncdiv
# truncdiv
ck.verify(tdiv(16 + 48 * x, 16), x * 3 + 1)
# (17+48*x)/16 is not simplifiable for arbitrary x because when 17+48*x<0
# (17+48*x)/16 != 1+3*x
ck.verify(tdiv(17 + 48 * x, 16), tdiv(x * 48 + 17, 16))
# However, when x >= 0, then 17+48*x >= 0 and (17+48*x)/16 can be simplified
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 10))
ck.verify(tdiv(17 + 48 * x, 16), x * 3 + 1)
# Trying expressions that are not simplifiable for any values of the variables
ck.verify(tdiv(17 + 47 * x, 16), tdiv(x * 47 + 17, 16))
# floordiv
fld = tvm.te.floordiv
# Replace the bound on x with one that includes negative values.
# NOTE(review): the trailing True is presumably the "override" flag of
# Analyzer.update -- confirm.
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 10000), True)
ck.verify(fld(16 + 48 * x, 16), x * 3 + 1)
# With floordiv the 17 + 48 * x case is expected to simplify even though x
# may be negative here.
ck.verify(fld(17 + 48 * x, 16), x * 3 + 1)
ck.verify(fld(17 + 47 * x, 16), fld(x * 47 + 17, 16))
canonical_simplify()
fp16_const_fold#
# Constant folding on float16 constants.
ck = CanonicalChecker()
zero = tvm.tir.const(0, "float16")
one = tvm.tir.const(1, "float16")
half = tvm.tir.const(0.5, "float16")

# (input expression, expected constant); checked in the listed order.
fp16_cases = [
    (zero + half, half),
    (half - zero, half),
    (zero * half, zero),
    (half * one, half),
    (half / one, half),
    (zero / half, zero),
]
for case_input, case_expected in fp16_cases:
    ck.verify(case_input, case_expected)
tvm.arith.Analyzer.canonical_simplify()
floormod#
# Nested floormod expressions and floormod of a negated operand.
ck = CanonicalChecker()
flm = tvm.te.floormod
x, y = te.var("x"), te.var("y")

# (input expression, expected canonical form); checked in the listed order.
floormod_cases = [
    (flm(flm((x * 4) + y - 466036, 24528) - 24512, 16), flm((x * 4) + y + 12, 16)),
    (flm(flm((x * 4), 16), 8), flm(x, 2) * 4),
    (flm(-x, 2), flm(x, -2) * -1),
]
for case_input, case_expected in floormod_cases:
    ck.verify(case_input, case_expected)
canonical_simplify()
canonical_mixed#
# Mixed cases: identical sub-expressions cancelling out, a comparison with
# constants on both sides, and nested truncated mod/div.
ck = CanonicalChecker()
x = te.var("x")
z = tvm.tir.const(3, "int32")
tdiv, tmod = tvm.tir.truncdiv, tvm.tir.truncmod
fld = tvm.te.floordiv

# (input expression, expected canonical form); checked in the listed order.
mixed_cases = [
    (tdiv(x, (z * z)) - tdiv(x, (z * z)), 0),
    (tdiv(x, (z + z)) - tdiv(x, (z + z)), 0),
    (x - 2 < 3, x < 5),
    (tvm.te.max(x, 1) - tvm.te.max(x, 1), 0),
    (tvm.te.min(x, 1) - tvm.te.min(x, 1), 0),
    (x * x - x * x, 0),
    (tmod(tdiv(tmod(x, 20), 2) * 2, 4), tdiv(tmod(x, 4), 2) * 2),
    # Same cancellation checks with floordiv.
    (fld(x, (z * z)) - fld(x, (z * z)), 0),
    (fld(x, (z + z)) - fld(x, (z + z)), 0),
]
for case_input, case_expected in mixed_cases:
    ck.verify(case_input, case_expected)
canonical_simplify()
reduce_combiner#
# Simplification of reductions built from custom commutative reducers.
ck = CanonicalChecker()
dummy = te.var("dummy")
comm_reducer = te.comm_reducer

# Plain product reducer with identity 1.
prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.tir.const(1, t0))

# Reducer whose combiner and identity both select between sum and product
# depending on the sign of `dummy`.
sum_or_prod = comm_reducer(
    lambda x, y: tvm.tir.Select(dummy < 0, x + y, x * y),
    lambda t0: tvm.tir.Select(dummy < 0, tvm.tir.const(0, t0), tvm.tir.const(1, t0)),
)

# Two-component reducer: component 0 is a sum, component 1 a product.
sum_and_prod = comm_reducer(
    lambda x, y: (x[0] + y[0], x[1] * y[1]),
    lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(5, t1) - tvm.tir.const(4, t1)),
)

# Five-component reducer whose components depend on one another.
some_reducer1 = comm_reducer(
    lambda x, y: (
        x[0] + y[0],
        x[0] + y[0] + x[1] + y[1],
        x[0] * y[2] + y[0] * x[2],
        x[1] + y[2],
        4.0,
    ),
    lambda t0, t1, t2, t3, t4: (
        tvm.tir.const(0, t0),
        tvm.tir.const(1, t1),
        tvm.tir.const(2, t2),
        tvm.tir.const(3, t3),
        tvm.tir.const(4, t4),
    ),
)

k = te.reduce_axis((0, 10), name="k")
A = te.placeholder((10,), name="A")

# Test that SimplifyCombiner makes use of vranges
ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, -4))
ck.verify(sum_or_prod(A[k], k), te.sum(A[k], k))
ck.verify(sum_or_prod(A[k], k, init=1), te.sum(A[k], k, init=1))

ck.analyzer.update(dummy, tvm.arith.ConstIntBound(5, 9), True)
ck.verify(sum_or_prod(A[k], k), prod(A[k], k))
ck.verify(sum_or_prod(A[k], k, init=1), prod(A[k], k, init=1))

ck.analyzer.update(dummy, tvm.arith.ConstIntBound(-10, 100), True)
ck.verify(sum_and_prod((A[k], A[10 - k]), k)[0], te.sum(A[k], k))
ck.verify(sum_and_prod((A[k], A[10 - k]), k)[1], prod(A[10 - k], k))

# Expected `source` entries of the simplified reduce for each selected
# component j of some_reducer1.
reference_simplified_sources = [
    [A[0]],
    [A[0], A[1]],
    [A[0], A[2]],
    [A[0], A[1], A[2], A[3]],
    [A[4]],
]
for j in range(5):
    # Here we use the j-th component of the result, so only it and the components it
    # depends on are left.
    simplified = ck.analyzer.canonical_simplify(
        some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j]
    )

    # Check that the remaining components are the expected ones.
    for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):
        assert tvm.ir.structural_equal(lhs, rhs)

# Test that components with side effects are not removed
# NOTE: this rebinds the global name `dummy`, which the sum_or_prod lambdas
# above also read; sum_or_prod is not called again after this point.
dummy = tvm.ir.GlobalVar("dummy")


def side_effect(*xs):
    """Wrap ``xs`` in an int32 call to the ``dummy`` global var."""
    return tvm.tir.Call("int32", dummy, xs)


# Component 1 contains the call, so the reduction is expected back unchanged.
ck.verify(
    sum_and_prod((A[k], side_effect(A[10 - k])), k)[0],
    sum_and_prod((A[k], side_effect(A[10 - k])), k)[0],
)
# The call sits in the selected component 0, so component 1 can be dropped.
ck.verify(sum_and_prod((side_effect(A[k]), A[10 - k]), k)[0], te.sum(side_effect(A[k]), k))
canonical_simplify()
reduce#
# Simplification of te.sum reductions.
ck = CanonicalChecker()
k = te.reduce_axis((0, 10), name="k")
j = te.reduce_axis((-5, 3), name="j")
A = te.placeholder((10,), name="A")

# (input expression, expected canonical form); checked in the listed order.
reduce_cases = [
    (te.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]), te.sum(k + j, [k, j])),
    (te.sum(A[3], []), A[3]),
    (te.sum(A[3], [], where=k > 12, init=1.0), tvm.tir.const(1.0, dtype="float32")),
    # The original note on this rule read "not typical, removed for now",
    # yet the check is still executed.
    (te.sum(te.div(k, 10), k), te.sum(tvm.tir.const(0, "int32"), k)),
]
for case_input, case_expected in reduce_cases:
    ck.verify(case_input, case_expected)
canonical_simplify()
if_then_else#
# Simplification inside if_then_else / Select, where the branch condition is
# available while simplifying the branch bodies.
ck = CanonicalChecker()
x = te.var("x")
y = te.var("y")
tdiv = tvm.tir.truncdiv
tmod = tvm.tir.truncmod
# simplification that takes condition into account.
res = tvm.tir.if_then_else(
(x * 4 + y) >= 466036,
tvm.tir.if_then_else(
24512 <= tmod(((x * 4) + y) - 466036, 24528),
tmod(tmod(((x * 4) + y) - 466036, 24528) - 24512, 16),
x,
),
y,
)
# Same as res, except the outer condition is written with y moved to the
# right-hand side.
res2 = tvm.tir.if_then_else(
(x * 4) >= 466036 - y,
tvm.tir.if_then_else(
24512 <= tmod(((x * 4) + y) - 466036, 24528),
tmod(tmod(((x * 4) + y) - 466036, 24528) - 24512, 16),
x,
),
y,
)
# Both spellings are expected to simplify to this one form.
expected = tvm.tir.if_then_else(
tvm.tir.LE(466036, (x * 4 + y)),
tvm.tir.if_then_else(
tvm.tir.LE(24512, tmod(((x * 4) + y) - 4, 24528)), tmod(((x * 4) + y) - 4, 16), x
),
y,
)
ck.verify(res, expected)
ck.verify(res2, expected)
# can only simplify if condition
res = tvm.tir.Select(tvm.tir.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3))
expected = tvm.tir.Select(tvm.tir.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3))
# The expected side is itself canonically simplified before the comparison.
ck.verify(res, ck.analyzer.canonical_simplify(expected))
# Inner condition tdiv(x, 3) > 2 under the outer condition x >= 10: the
# inner if_then_else is expected to collapse to x.
res = tvm.tir.Select(x >= 10, tvm.tir.if_then_else(tdiv(x, 3) > 2, x, 0), 0)
expected = tvm.tir.Select(x >= 10, x, 0)
ck.verify(res, ck.analyzer.canonical_simplify(expected))
# Inner condition tdiv(x, 3) < 2 under x >= 10: the whole Select is expected
# to simplify to 0.
res = tvm.tir.Select(x >= 10, tvm.tir.if_then_else(tdiv(x, 3) < 2, x, 0), 0)
ck.verify(res, 0)
canonical_simplify()
complex_cases#
# Larger expressions mixing several truncated div/mod terms of the same
# linear index, simplified under constant bounds on x and y.
ck = CanonicalChecker()
x = te.var("x")
y = te.var("y")
tdiv = tvm.tir.truncdiv
tmod = tvm.tir.truncmod
res2 = (
tdiv(tdiv(tmod(x * 128 + y, 1296), 36) * 2 + 1, 2) * 36
+ tdiv(tmod((x * 128) + y, 36) * 2 + 1, 2)
- tmod((x * 128) + y, 1296)
+ 1
)
# Bounds used for res2: x gets ConstIntBound(0, 5), y gets ConstIntBound(0, 127).
ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 5))
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 127))
ck.verify(res2, 1)
# Widen the bound on y to ConstIntBound(0, 1024) for the next expression.
# NOTE(review): the trailing True is presumably the "override" flag of
# Analyzer.update -- confirm.
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 1024), True)
res3 = (
tdiv(x * 1024 + y, 65536)
+ tdiv(tmod(x * 1024 + y, 65536), 256)
+ tdiv(tmod(x * 1024 + y, 256), 16)
+ tmod(x * 1024 + y, 16)
- tdiv(y, 256)
- tdiv(tmod(y, 256), 16)
- tmod(y, 16)
- (x * 4)
)
ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4))
canonical_simplify()
cast#
# Simplification of sums that sit underneath integer casts.
ck = CanonicalChecker()
tcast = tvm.tir.Cast
fld = tvm.te.floordiv
flm = tvm.te.floormod
# cast(i64, i + j + 1) - cast(i64, i)
# int32 operands cast up to int64.
i = te.var("i", dtype="int32")
j = te.var("j", dtype="int32")
res = tcast("int64", i + j + 1) - tcast("int64", i)
ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64"))
# cast(i32, i + j + 1) - cast(i32, i)
# int64 operands cast down to int32, with ConstIntBound(0, 10) on both.
i = te.var("i", dtype="int64")
j = te.var("j", dtype="int64")
ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))
ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
res = tcast("int32", i + j + 1) - tcast("int32", i)
ck.verify(res, tcast("int32", j) + 1)
# cast(i32, i + j - 100)
# Here i is bounded by ConstIntBound(0, 2**31 - 1); the expression is
# expected back unchanged (res is compared against itself).
i = te.var("i", dtype="int64")
j = te.var("j", dtype="int64")
ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2**31 - 1))
ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
res = tcast("int32", i + j - 100)
ck.verify(res, res)
# cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
# - cast(i32, flm(axis, 7i64) * 2i64)
axis = te.var("axis", dtype="int64")
ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))
res = (
tcast(
"int32",
flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64")
+ tvm.tir.const(1, "int64"),
)
+ tvm.tir.const(1, "int32")
- tcast("int32", flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64"))
)
# The two casts share the flm(axis, 7) * 2 term, leaving 1 + 1.
ck.verify(res, 2)
canonical_simplify()
normalize_min_value_expr#
# Equality comparisons against zero that involve the int32 minimum value,
# with the operands written in every order.
ck = CanonicalChecker()
x = te.var("x", "int32")

# (input comparison, expected canonical form); checked in the listed order.
min_value_cases = [
    (te.min_value("int32") - x == 0, x == te.min_value("int32")),
    (te.min_value("int32") + x == 0, False),
    (0 == te.min_value("int32") - x, x == te.min_value("int32")),
    (0 == te.min_value("int32") + x, False),
    (-x + te.min_value("int32") == 0, x == te.min_value("int32")),
    (x + te.min_value("int32") == 0, False),
    (0 == -x + te.min_value("int32"), x == te.min_value("int32")),
    (0 == x + te.min_value("int32"), False),
]
for case_input, case_expected in min_value_cases:
    ck.verify(case_input, case_expected)
canonical_simplify()
proddiv#
# Division and modulo where both operands are products of variables and
# constants: common factors are expected to cancel.
ck = CanonicalChecker()
flm = tvm.te.floormod
fld = tvm.te.floordiv
tdiv = tvm.te.truncdiv
tmod = tvm.te.truncmod
# z previously carried the name hint "y" (copy-paste slip), which made y and
# z print identically in CanonicalChecker failure messages.
x, y, z = te.var("x"), te.var("y"), te.var("z")

# floormod
ck.verify(flm(x * 32 * x, x), 0)
ck.verify(flm(z * x * 32 * x * y, x * z), 0)
ck.verify(flm(z * x * 32 * x * y, x * z * y * 8 * x), 0)
ck.verify(flm(z * x * 32 * (x * y), 6 * x * z), flm(x * y * 16, 3) * (x * z * 2))
ck.verify(flm(x * 32 * x, x * z), flm(x * 32, z) * x)

# truncmod
ck.verify(tmod(x * 32 * x, x), 0)
ck.verify(tmod(z * x * 32 * x * y, x * z), 0)
ck.verify(tmod(z * x * 32 * (x * y), 6 * x * z), tmod(x * y * 16, 3) * (x * z * 2))
ck.verify(tmod(x * 32 * x, x * z), tmod(x * 32, z) * x)

# floordiv
ck.verify(fld(x * 2 * x * z, 4 * x * x * x), fld(z, x * 2))
ck.verify(fld(x * (2 * y) * 3, 3 * y), x * 2)
ck.verify(fld(x * (2 * y) * 3, 3 * y * z), fld(x * 2, z))

# truncdiv
ck.verify(tdiv(x * 2 * x * z, 4 * x * x * x), tdiv(z, x * 2))
ck.verify(tdiv(x * (2 * y) * 3, 3 * y), x * 2)
ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z))
canonical_simplify()
floormod_two#
# floormod by two of a sum whose variable terms have even coefficients and
# whose constant part (1 + 2) is odd.
ck = CanonicalChecker()
flm = tvm.te.floormod
x, y = te.var("x"), te.var("y")
odd_sum = x * 10 + 1 + y * 2 + 2
ck.verify(flm(odd_sum, 2), 1)
canonical_simplify()
le#
# Simplification of `<` comparisons whose left side mixes a scaled variable
# with terms that have known ranges.
ck = CanonicalChecker()
# Case 1. Ignore the extra expr if it's smaller than the division number
x, y, z = te.var("x"), te.var("y"), te.var("z")
# y is bound to Range(0, 8) and z to Range(0, 2); x stays unbound.
ck.analyzer.bind(y, tvm.ir.Range(0, 8))
ck.analyzer.bind(z, tvm.ir.Range(0, 2))
ck.verify(x * 8 + y < 16, x < 2)
ck.verify(x * 8 + z * 4 < 16, x < 2)
# NOTE(review): identical to the previous check -- possibly a copy-paste
# duplicate of a different intended case; confirm.
ck.verify(x * 8 + z * 4 < 16, x < 2)
# TODO: Not sure why `-2 < x` gets converted to `x > -2`; use an explicit
# rewrite_simplify on the expected side here.
ck.verify(x * -8 + y < 16, ck.analyzer.rewrite_simplify(-2 < x))
ck.verify(x * -8 + z * 4 < 16, ck.analyzer.rewrite_simplify(-2 < x))
# With both y and z added, the comparison is expected back unchanged.
ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)
ck.verify(x * 8 + y - z < 16, x < 2)
# Comparison against a symbolic size variable is expected back unchanged.
n = te.size_var("n")
ck.verify(x * 8 + y < n, x * 8 + y < n)
# Case 2. Simplify the extra expr
x1, x2, ty, tx, vec = (
tvm.te.var("x1"),
tvm.te.var("x2"),
tvm.te.var("ty"),
tvm.te.var("tx"),
tvm.te.var("vec"),
)
ck.analyzer.bind(x1, tvm.ir.Range(0, 2))
ck.analyzer.bind(x2, tvm.ir.Range(0, 3))
ck.analyzer.bind(ty, tvm.ir.Range(0, 8))
ck.analyzer.bind(tx, tvm.ir.Range(0, 32))
ck.analyzer.bind(vec, tvm.ir.Range(0, 8))
ck.verify(
x1 * 5632 + (((x2 * 8 + ty) * 32 + tx) * 8 + vec) % 5632 < 11008,
x1 * 22 + (x2 * 8 + ty) % 22 < 43,
)
ck.verify(tx // 2 % 8 + vec < 8, tx % 16 // 2 + vec < 8)
# Case 3. No failure
# Fresh x, y, z; only y is bound (to Range(0, 1024)) for this check.
x, y, z = te.var("x"), te.var("y"), te.var("z")
ck.analyzer.bind(y, tvm.ir.Range(0, 1024))
ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0)