simplify()
#
# Change to the parent directory before importing; presumably this makes a
# project-local `testing` helper module importable (TODO confirm where
# `testing` lives and whether this cell still needs it).
%cd ..
import testing
/media/pc/data/lxw/ai/tvm-book/doc/read/arith
import tvm
import tvm.testing
from tvm import tir
from tvm.script import tir as T
simplify()
reshape_flattened#
# Check that splitting a flat index into (// 12, % 12 // 4, % 4) pieces and
# recombining them simplifies back to the original flat index.
ana = tvm.arith.Analyzer()
i0, i1 = tir.Var("i0", "int64"), tir.Var("i1", "int64")
for var, upper in ((i0, 8), (i1, 3)):
    ana.bind(var, tvm.ir.Range(0, upper))

# i0 in [0, 8) and i1 in [0, 3), so the flat index lies in [0, 24).
i_flattened = i0 * 3 + i1
recomposed = i_flattened // 12 * 12 + i_flattened % 12 // 4 * 4 + i_flattened % 4
assert tvm.ir.structural_equal(ana.simplify(recomposed), i_flattened)
simplify()
symbolic_comparison#
# Compare a tiled index against its symbolic, 32-padded extent at two proof
# strengths.
ana = tvm.arith.Analyzer()
i0 = tir.Var("i0", "int64")
i1 = tir.Var("i1", "int64")
n = tir.SizeVar("n", "int64")
m = tir.SizeVar("m", "int64")

# Number of 32-wide tiles needed to cover n (ceil-div via floordiv).
outer = (n + 31) // 32
ana.bind(i0, tvm.ir.Range(0, outer))
ana.bind(i1, tvm.ir.Range(0, 32))
PS = tvm.arith.ProofStrength

# Index inside the tiled loop nest, and the padded extent it must stay below.
flat = i0 * 32 + i1
padded = (n + 31) // 32 * 32

# The default strength cannot prove the bound ...
assert not ana.can_prove(flat < padded, PS.DEFAULT)
# ... whereas SYMBOLIC_BOUND proves it along with several variants.
assert ana.can_prove(flat < padded, PS.SYMBOLIC_BOUND)
assert ana.can_prove(flat < padded + m, PS.SYMBOLIC_BOUND)
assert ana.can_prove(flat + 1 <= padded, PS.SYMBOLIC_BOUND)
assert ana.can_prove(padded >= flat + 1, PS.SYMBOLIC_BOUND)
assert ana.can_prove(padded >= flat, PS.SYMBOLIC_BOUND)
simplify()
vscale_comparison_with_sve_target#
# Each comparison involves vscale (a scalable-vector value). With an AArch64
# target that enables SVE, the analyzer is expected to prove all of them.
# A fresh Analyzer per expression keeps the cases independent.
for expression in [
    T.vscale() * 32 < T.vscale() * 64,
    T.vscale() * 2 * (T.vscale() * 2) >= T.vscale() * 4,
    (T.vscale() * 4 + 114) // (T.vscale() * 4) * (T.vscale() * 4) >= 115,
    64 % T.vscale() <= T.vscale(),
]:
    ana = tvm.arith.Analyzer()
    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
        # Include the expression in the message: a bare assert inside a loop
        # does not say which case failed.
        assert ana.can_prove(expression), f"could not prove: {expression}"
simplify()
vscale_comparison_without_sve_target#
# Without "+sve" in the target, the analyzer does not attempt to substitute
# known vscale values, so it cannot prove the comparison. TVM logs a warning
# about the skipped substitution.
ana = tvm.arith.Analyzer()
vs = tvm.tir.vscale()
with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"):
    # Assert the expected negative result directly. A positive assert that is
    # known to fail would raise AssertionError and stop the notebook run.
    assert not ana.can_prove(vs * 32 < vs * 64)
[09:35:26] /media/pc/data/lxw/ai/tvm/src/arith/analyzer.cc:240: Warning: The expression contains scalable values. An attempt to prove by substituting with known values of vscale was not performed. This proof currently only supports AArch64 SVE targets, but the target was llvm -keys=arm_cpu,cpu -mtriple=aarch64-linux-gnu
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[8], line 4
2 vs = tvm.tir.vscale()
3 with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu"):
----> 4 assert ana.can_prove(vs * 32 < vs * 64)
AssertionError:
simplify()
vscale_non_comparison#
# With an SVE target, can_prove tries to substitute known vscale values, and
# that path requires a comparison. A bare arithmetic expression such as
# `vs * 4` is rejected with an internal check failure. Catch and verify that
# error so later cells still run.
ana = tvm.arith.Analyzer()
vs = tvm.tir.vscale()
try:
    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
        ana.can_prove(vs * 4)
except tvm.TVMError as err:
    assert "Expected comparison but got: T.vscale() * 4" in str(err), str(err)
    print(f"caught {type(err).__name__} as expected for a non-comparison expression")
else:
    raise AssertionError("can_prove(vs * 4) should reject a non-comparison expression")
---------------------------------------------------------------------------
InternalError Traceback (most recent call last)
Cell In[9], line 5
2 vs = tvm.tir.vscale()
4 with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
----> 5 ana.can_prove(vs * 4)
File /media/pc/data/lxw/ai/tvm/python/tvm/arith/analyzer.py:247, in Analyzer.can_prove(self, expr, strength)
231 def can_prove(self, expr, strength=ProofStrength.DEFAULT):
232 """Check whether we can prove expr to be true.
233
234 Parameters
(...)
245 The result.
246 """
--> 247 return self._can_prove(expr, strength)
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_ctypes/packed_func.py:239, in PackedFuncBase.__call__(self, *args)
227 ret_tcode = ctypes.c_int()
228 if (
229 _LIB.TVMFuncCall(
230 self.handle,
(...)
237 != 0
238 ):
--> 239 raise_last_ffi_error()
240 _ = temp_args
241 _ = args
File /media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py:481, in raise_last_ffi_error()
475 # The exception PyObject may contain a large amount of state,
476 # including all stack frames that may be inspected in a later
477 # PDB post-mortem. Therefore, we must make sure to remove the
478 # underlying PyObject* from the C++ side after we retrieve it.
479 _LIB.TVMDropLastPythonError()
--> 481 raise py_err
InternalError: Traceback (most recent call last):
2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::arith::__mk_TVM0::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::arith::__mk_TVM0, tvm::runtime::TVMRetValue) const::{lambda(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)#1}::operator()(std::allocator<char>) const::{lambda(tvm::arith::__mk_TVM0, tvm::runtime::TVMRetValue)#11}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::arith::__mk_TVM0, tvm::runtime::TVMRetValue)
1: tvm::arith::Analyzer::CanProve(tvm::PrimExpr const&, tvm::arith::ProofStrength)
0: tvm::arith::CanProveVscaleExpressionFromKnownValues(tvm::arith::Analyzer*, tvm::PrimExpr const&, std::vector<unsigned int, std::allocator<unsigned int> > const&)
File "/media/pc/data/lxw/ai/tvm/src/arith/scalable_expression.cc", line 82
InternalError: Check failed: (IsComparison(expr)) is false: Expected comparison but got: T.vscale() * 4
simplify()
regression_simplify_inf_recursion#
# Regression check: an earlier version could loop forever here, because the
# comparison analysis and the int-set analysis kept calling each other
# recursively. Simplifying this expression must simply terminate.
ana = tvm.arith.Analyzer()
cond = tir.Var("cond", "int32")
# Build both operands separately (structurally identical, distinct objects),
# exactly as in the original regression case.
lhs = tvm.tir.NE(cond, 0).astype("int8")
rhs = tvm.tir.NE(cond, 0).astype("int8")
res = (lhs - rhs).astype("int32") == 0
ana.rewrite_simplify(res)
T.Cast("int32", T.Cast("int8", cond != 0) - T.Cast("int8", cond != 0)) == 0