simplify()#

%cd ..
import testing
/media/pc/data/lxw/ai/tvm-book/doc/read/arith
import tvm
import tvm.testing
from tvm import tir
from tvm.script import tir as T

simplify() reshape_flattened#

# Split a flattened index into its `// 12`, `% 12 // 4` and `% 4` pieces,
# recombine them, and assert that the analyzer simplifies the sum back to
# the flattened index itself.
ana = tvm.arith.Analyzer()

i0 = tir.Var("i0", "int64")
ana.bind(i0, tvm.ir.Range(0, 8))
i1 = tir.Var("i1", "int64")
ana.bind(i1, tvm.ir.Range(0, 3))

i_flattened = i0 * 3 + i1
recombined = (
    i_flattened // 12 * 12
    + i_flattened % 12 // 4 * 4
    + i_flattened % 4
)
assert tvm.ir.structural_equal(ana.simplify(recombined), i_flattened)

simplify() symbolic_comparison#

# Compare a tiled index `i0 * 32 + i1` against the padded extent
# `(n + 31) // 32 * 32`, where `i0` is bound to Range(0, (n + 31) // 32)
# and `i1` to Range(0, 32).
ana = tvm.arith.Analyzer()
PS = tvm.arith.ProofStrength

n = tvm.tir.SizeVar("n", "int64")
m = tvm.tir.SizeVar("m", "int64")
outer = (n + 31) // 32

i0 = tir.Var("i0", "int64")
ana.bind(i0, tvm.ir.Range(0, outer))
i1 = tir.Var("i1", "int64")
ana.bind(i1, tvm.ir.Range(0, 32))

tiled_index = i0 * 32 + i1
padded_extent = outer * 32

# The first assert states that DEFAULT strength does not prove the bound;
# every following assert uses SYMBOLIC_BOUND and states that it does.
assert not ana.can_prove(tiled_index < padded_extent, PS.DEFAULT)
assert ana.can_prove(tiled_index < padded_extent, PS.SYMBOLIC_BOUND)
assert ana.can_prove(tiled_index < padded_extent + m, PS.SYMBOLIC_BOUND)
assert ana.can_prove(tiled_index + 1 <= padded_extent, PS.SYMBOLIC_BOUND)
assert ana.can_prove(padded_extent >= tiled_index + 1, PS.SYMBOLIC_BOUND)
assert ana.can_prove(padded_extent >= tiled_index, PS.SYMBOLIC_BOUND)

simplify() vscale_comparison_with_sve_target#

# Each vscale comparison below is asserted provable by a fresh Analyzer while
# inside a `with tvm.target.Target(...)` block whose target string enables
# the `+sve` attribute.
SVE_TARGET = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"

vscale_predicates = [
    T.vscale() * 32 < T.vscale() * 64,
    T.vscale() * 2 * (T.vscale() * 2) >= T.vscale() * 4,
    (T.vscale() * 4 + 114) // (T.vscale() * 4) * (T.vscale() * 4) >= 115,
    64 % T.vscale() <= T.vscale(),
]

for expression in vscale_predicates:
    ana = tvm.arith.Analyzer()

    with tvm.target.Target(SVE_TARGET):
        assert ana.can_prove(expression)

simplify() vscale_comparison_without_sve_target#

ana = tvm.arith.Analyzer()  # target below lacks "+sve"; the recorded run fails this cell's assert
vs = tvm.tir.vscale()
with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"):
    assert ana.can_prove(vs * 32 < vs * 64)
[09:35:26] /media/pc/data/lxw/ai/tvm/src/arith/analyzer.cc:240: Warning: The expression contains scalable values. An attempt to prove by substituting with known values of vscale was not performed. This proof currently only supports AArch64 SVE targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[8], line 4
      2 vs = tvm.tir.vscale()
      3 with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"):
----> 4     assert ana.can_prove(vs * 32 < vs * 64)

AssertionError: 

simplify() vscale_non_comparison#

ana = tvm.arith.Analyzer()
vs = tvm.tir.vscale()
# `vs * 4` is not a comparison; the recorded run raises InternalError from can_prove.
with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
    ana.can_prove(vs * 4)
---------------------------------------------------------------------------
InternalError                             Traceback (most recent call last)
Cell In[9], line 5
      2 vs = tvm.tir.vscale()
      4 with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
----> 5     ana.can_prove(vs * 4)

File /media/pc/data/lxw/ai/tvm/python/tvm/arith/analyzer.py:247, in Analyzer.can_prove(self, expr, strength)
    231 def can_prove(self, expr, strength=ProofStrength.DEFAULT):
    232     """Check whether we can prove expr to be true.
    233 
    234     Parameters
   (...)
    245         The result.
    246     """
--> 247     return self._can_prove(expr, strength)

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_ctypes/packed_func.py:239, in PackedFuncBase.__call__(self, *args)
    227 ret_tcode = ctypes.c_int()
    228 if (
    229     _LIB.TVMFuncCall(
    230         self.handle,
   (...)
    237     != 0
    238 ):
--> 239     raise_last_ffi_error()
    240 _ = temp_args
    241 _ = args

File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py:481, in raise_last_ffi_error()
    475 # The exception PyObject may contain a large amount of state,
    476 # including all stack frames that may be inspected in a later
    477 # PDB post-mortem.  Therefore, we must make sure to remove the
    478 # underlying PyObject* from the C++ side after we retrieve it.
    479 _LIB.TVMDropLastPythonError()
--> 481 raise py_err

InternalError: Traceback (most recent call last):
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::arith::__mk_TVM0::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::arith::__mk_TVM0, tvm::runtime::TVMRetValue) const::{lambda(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)#1}::operator()(std::allocator<char>) const::{lambda(tvm::arith::__mk_TVM0, tvm::runtime::TVMRetValue)#11}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::arith::__mk_TVM0, tvm::runtime::TVMRetValue)
  1: tvm::arith::Analyzer::CanProve(tvm::PrimExpr const&, tvm::arith::ProofStrength)
  0: tvm::arith::CanProveVscaleExpressionFromKnownValues(tvm::arith::Analyzer*, tvm::PrimExpr const&, std::vector<unsigned int, std::allocator<unsigned int> > const&)
  File "/media/pc/data/lxw/ai/tvm/src/arith/scalable_expression.cc", line 82
InternalError: Check failed: (IsComparison(expr)) is false: Expected comparison but got: T.vscale() * 4

simplify() regression_simplify_inf_recursion#

# Regression case: simplifying this expression previously could recurse
# forever, because the comparison rewrite and the int-set analysis kept
# calling into each other.
ana = tvm.arith.Analyzer()
cond = tir.Var("cond", "int32")

lhs_i8 = tvm.tir.NE(cond, 0).astype("int8")
rhs_i8 = tvm.tir.NE(cond, 0).astype("int8")
res = (lhs_i8 - rhs_i8).astype("int32") == 0

# Kept as the final statement so the notebook cell displays the result.
ana.rewrite_simplify(res)
T.Cast("int32", T.Cast("int8", cond != 0) - T.Cast("int8", cond != 0)) == 0