# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Arithmetic data structure and utility"""
import enum
from typing import Union

import tvm._ffi
from tvm import tir, ir
from tvm.runtime import Object

from . import _ffi_api


class ProofStrength(enum.IntEnum):
    """Proof strength of the analysis.

    Passed as the ``strength`` argument of ``Analyzer.can_prove``.
    """

    DEFAULT = 0
    SYMBOLIC_BOUND = 1


class Extension(enum.Flag):
    """Extensions enabled for RewriteSimplifier.

    Each member occupies its own bit so members can be combined with ``|``.
    Values should match `RewriteSimplifier::Extensions`.
    """

    NoExtensions = 0
    TransitivelyProveInequalities = 0b0001
    ConvertBooleanToAndOfOrs = 0b0010
    ApplyConstraintsToBooleanBranches = 0b0100
    ComparisonOfProductAndSum = 0b1000


@tvm._ffi.register_object("arith.ModularSet")
class ModularSet(Object):
    """Represent range of (coeff * x + base) for x in Z

    Parameters
    ----------
    coeff : int
        The coefficient multiplying ``x``.

    base : int
        The offset added to ``coeff * x``.
    """

    def __init__(self, coeff, base):
        ctor = _ffi_api.ModularSet
        self.__init_handle_by_constructor__(ctor, coeff, base)


@tvm._ffi.register_object("arith.ConstIntBound")
class ConstIntBound(Object):
    """Represent constant integer bound

    Parameters
    ----------
    min_value : int
        The minimum value of the bound.

    max_value : int
        The maximum value of the bound.
    """

    # Largest signed 64-bit integer; serves as the "positive infinity" sentinel.
    POS_INF = 2**63 - 1
    # Symmetric "negative infinity" sentinel.
    NEG_INF = -POS_INF

    def __init__(self, min_value, max_value):
        ctor = _ffi_api.ConstIntBound
        self.__init_handle_by_constructor__(ctor, min_value, max_value)


class ConstraintScope:
    """Context manager that keeps a constraint active inside a ``with`` block.

    Parameters
    ----------
    fenter : function
        Zero-argument callable invoked when the ``with`` block is entered.
        Its return value is kept and called with no arguments when the block
        is exited, so it must return the function that leaves the context.

    Note
    ----
    Do not create object directly, use Analyzer.constraint_scope
    """

    def __init__(self, fenter):
        self._fexit = None
        self._fenter = fenter

    def __enter__(self):
        # Entering the context hands back the callback that undoes it.
        exit_callback = self._fenter()
        self._fexit = exit_callback

    def __exit__(self, ptype, value, trace):
        # Returns None, so any exception raised inside the block propagates.
        exit_callback = self._fexit
        exit_callback()
class Analyzer:
    """Integer arithmetic analyzer

    This is a stateful analyzer class that can
    be used to perform various symbolic integer analysis.
    """

    def __init__(self):
        # `CreateAnalyzer` returns a callable that is queried by operation name;
        # each lookup yields the FFI function implementing that operation, and the
        # public methods below simply forward to these stored functions.
        _mod = _ffi_api.CreateAnalyzer()
        self._const_int_bound = _mod("const_int_bound")
        self._const_int_bound_update = _mod("const_int_bound_update")
        self._bind = _mod("bind")
        self._modular_set = _mod("modular_set")
        self._simplify = _mod("Simplify")
        self._rewrite_simplify = _mod("rewrite_simplify")
        self._get_rewrite_simplify_stats = _mod("get_rewrite_simplify_stats")
        self._reset_rewrite_simplify_stats = _mod("reset_rewrite_simplify_stats")
        self._canonical_simplify = _mod("canonical_simplify")
        self._int_set = _mod("int_set")
        self._enter_constraint_context = _mod("enter_constraint_context")
        self._can_prove_equal = _mod("can_prove_equal")
        self._can_prove = _mod("can_prove")
        self._get_enabled_extensions = _mod("get_enabled_extensions")
        self._set_enabled_extensions = _mod("set_enabled_extensions")

    def const_int_bound(self, expr):
        """Find constant integer bound for expr.

        Parameters
        ----------
        expr : PrimExpr
            The expression.

        Returns
        -------
        bound : ConstIntBound
            The result bound
        """
        return self._const_int_bound(expr)

    def modular_set(self, expr):
        """Find a modular set that expr belongs to.

        Parameters
        ----------
        expr : PrimExpr
            The expression.

        Returns
        -------
        result : ModularSet
            The result.
        """
        return self._modular_set(expr)

    def simplify(self, expr, steps=2):
        """Simplify expression via both rewrite and canonicalization.

        Parameters
        ----------
        expr : PrimExpr
            The expression.

        steps : int
            The simplification runs in the order of
            rewrite_simplify (step 1) -> canonical_simplify (step 2) ->
            rewrite_simplify (step 3) -> canonical_simplify (step 4) -> ...
            param steps controls how many steps to run.
            Default is 2, i.e., rewrite_simplify + canonical_simplify.

        Returns
        -------
        result : Expr
            The result.
        """
        return self._simplify(expr, steps)

    def rewrite_simplify(self, expr):
        """Simplify expression via rewriting rules.

        Parameters
        ----------
        expr : PrimExpr
            The expression.

        Returns
        -------
        result : Expr
            The result.
        """
        return self._rewrite_simplify(expr)

    def canonical_simplify(self, expr):
        """Simplify expression via canonicalization.

        Parameters
        ----------
        expr : PrimExpr
            The expression.

        Returns
        -------
        result : Expr
            The result.
        """
        return self._canonical_simplify(expr)

    def int_set(self, expr, dom_map):
        """Compute a symbolic IntSet that covers expr for all values in dom_map.

        Parameters
        ----------
        expr : PrimExpr
            The expression.

        dom_map : Dict[tvm.tir.Var, tvm.arith.IntSet]
            The domain for variables to be relaxed.

        Returns
        -------
        result : IntSet
            The result.
        """
        return self._int_set(expr, dom_map)

    def can_prove(self, expr, strength=ProofStrength.DEFAULT) -> bool:
        """Check whether we can prove expr to be true.

        Parameters
        ----------
        expr : PrimExpr
            The expression.

        strength : ProofStrength
            The proof strength

        Returns
        -------
        result : bool
            Whether expr could be proven to be true.
        """
        return self._can_prove(expr, strength)

    def bind(self, var: tir.Var, expr: Union[tir.PrimExpr, ir.Range]):
        """Bind a variable to the expression.

        Parameters
        ----------
        var : tvm.tir.Var
            The variable.

        expr : Union[tir.PrimExpr, ir.Range]
            The expression or the range to bind to.
        """
        return self._bind(var, expr)

    def constraint_scope(self, constraint):
        """Create a constraint scope.

        Parameters
        ----------
        constraint : PrimExpr
            The constraint expression.

        Returns
        -------
        scope : ConstraintScope
            The constraint scope

        Examples
        --------
        .. code-block:: python

          x = te.var("x")
          analyzer = tvm.arith.Analyzer()
          with analyzer.constraint_scope(x % 3 == 0):
              # constraint in effect
              assert analyzer.modular_set(x).coeff == 3
          # constraint no longer in effect
          assert analyzer.modular_set(x).coeff != 3
        """

        # Deferred so the constraint is only entered when the `with` block
        # starts; the value returned here is what ConstraintScope calls on exit.
        def _fenter():
            return self._enter_constraint_context(constraint)

        return ConstraintScope(_fenter)

    def update(self, var, info, override=False):
        """Update information about var

        Parameters
        ----------
        var : tvm.tir.Var
            The variable.

        info : tvm.Object
            Related information. Only ConstIntBound is currently supported.

        override : bool
            Whether allow override.

        Raises
        ------
        TypeError
            If info is not a ConstIntBound.
        """
        if isinstance(info, ConstIntBound):
            self._const_int_bound_update(var, info, override)
        else:
            raise TypeError("Do not know how to handle type {}".format(type(info)))

    def can_prove_equal(self, lhs: tir.PrimExpr, rhs: tir.PrimExpr) -> bool:
        """Whether we can prove that lhs == rhs

        Parameters
        ----------
        lhs : PrimExpr
            The left-hand side of the comparison

        rhs : PrimExpr
            The right-hand side of the comparison

        Returns
        -------
        result : bool
            Whether we can prove that lhs == rhs
        """
        return self._can_prove_equal(lhs, rhs)

    @property
    def enabled_extensions(self) -> Extension:
        """Return the currently enabled extensions"""
        value = self._get_enabled_extensions()
        return Extension(value)

    @enabled_extensions.setter
    def enabled_extensions(self, flags: Union[int, Extension]):
        """Enable extensions for the analyzer

        Parameters
        ----------
        flags : Union[int, Extension]
            The extensions to enable. Converted through ``Extension`` first,
            so only combinations of known extension bits are accepted.
        """
        flags = Extension(flags).value
        self._set_enabled_extensions(flags)