# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals
"""Utility functions for Relax"""
import itertools
import string
from typing import Tuple as typing_Tuple
from typing import Any, Callable, List, Dict, Optional

import tvm
from .. import tir
from ..tir import PrimExpr
from ..runtime import String, convert_to_object
from . import _ffi_api
from .expr import Tuple as rx_Tuple
from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
from ..te import Tensor as te_Tensor, create_prim_func
from ..ir import Array, Attrs, Type, Map, VDevice
from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo

# Re-export `args_converter` here for backwards compatibility
from .type_converter import args_converter  # pylint: disable=unused-import


def metadata_partitioner(rx_txt: str) -> List[str]:
    """Extract Relax program and metadata section.

    The metadata section is located as the first balanced ``{ ... }`` region
    of the text; everything after it (minus the trailing character) is taken
    to be the Relax program.

    Parameters
    ----------
    rx_txt : str
        The input relax text.

    Returns
    -------
    output : List[str]
        The result list of partitioned text, the first element
        is the relax program, and the second is metadata section.

    Raises
    ------
    ValueError
        If no balanced curly-brace metadata section is found.
    """
    partitions = []
    left_curly = 0
    meta_start = 0
    meta_end = 0
    for i, char in enumerate(rx_txt):
        if char == "{":
            # Remember where the metadata section opens.
            if meta_start == 0:
                meta_start = i
            left_curly += 1
        elif char == "}":
            left_curly -= 1
            if left_curly == 0:
                # All opened braces are closed: metadata section ends here.
                meta_end = i + 1
                break
    if meta_end == 0:
        raise ValueError("The metadata section was not found.")
    metadata = rx_txt[meta_start:meta_end]
    # NOTE(review): the trailing character of the text is intentionally
    # dropped here (slice ends at -1), matching the original behavior.
    rx_program = rx_txt[meta_end:-1]

    partitions.append(rx_program)
    partitions.append(metadata)
    return partitions
def convert_to_expr(value: Any) -> Expr:
    """Helper function to convert the input to Expr, which follows the rules:
    1. Return the input itself if it's already a `relax.Expr`;
    2. Return `relax.PrimValue` if the input is a `PrimExpr`;
    3. Return `relax.StringImm` if the input is `tvm.String` or `str`;
    4. Return `relax.Tuple` if the input is a tuple/list of `Expr`.

    Notes
    -----
    1. `tvm.tir.StringImm` is not allowed because of ambiguity,
       which can be either `relax.StringImm` or `relax.PrimValue`.

    Raises
    ------
    TypeError
        If the input cannot be converted to a `relax.Expr`, or is a
        `tir.StringImm` (ambiguous, see Notes).
    """
    # Python scalars are wrapped as 64-bit PrimValues up front, before the
    # generic TVM object conversion.
    if isinstance(value, int):
        return PrimValue(tir.IntImm("int64", value))
    if isinstance(value, float):
        return PrimValue(tir.FloatImm("float64", value))
    tvm_value = convert_to_object(value)
    # Case 1
    if isinstance(tvm_value, Expr):  # type: ignore
        return tvm_value
    # Note 1
    if isinstance(tvm_value, tir.StringImm):
        raise TypeError(
            "Cannot convert `tir.StringImm` to `relax.Expr` because of ambiguity,"
            "which can be either `relax.StringImm` or `relax.PrimValue` "
        )
    # Case 2
    if isinstance(tvm_value, PrimExpr):
        return PrimValue(value)
    # Case 3
    if isinstance(tvm_value, String):
        return StringImm(value)
    # Case 4
    if isinstance(value, (tuple, list)):
        # `convert_to_expr` ensures that all elements are `Expr` if no exception raises
        return rx_Tuple([convert_to_expr(v) for v in value])
    raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`")
def copy_with_new_vars(func: Function) -> Function:
    """Copy the given function. All variables that are bound inside the original
    function would be copied to satisfy the restriction in the well-formed check:
    Variables in Relax must be bound exactly once. This also ensures that both the
    function and its copy can be inserted into the same IRModule, and be asserted on
    the structural equality against IRModule created by TVMScript.

    Parameters
    ----------
    func : Function
        The relax function to copy.

    Returns
    -------
    ret : Function
        The copied function.
    """
    return _ffi_api.CopyWithNewVars(func)  # type: ignore


def gen_call_tir_inputs(
    func: Callable, *args: Any, **kwargs: Any
) -> typing_Tuple[tir.PrimFunc, Expr, List[TensorStructInfo], Optional[ShapeExpr]]:
    """Generate the inputs for call_tir according to the te function.
    This function converts arguments from relax expression to te tensor,
    The callback func should return a te tensor or a list of te tensors.

    Parameters
    ----------
    func : Callable
        A function that returns a te tensor or a list of te tensors.

    args : Any, optional
        arguments passed to the function.

    kwargs : Any, optional
        The keyword arguments passed to the function.
        Note that the keyword args 'primfunc_attrs' is reserved for passing func
        attributes to be added to the PrimFunc that gets created.

    Returns
    -------
    ret : Tuple[tir.PrimFunc, Expr, List[TensorStructInfo], Optional[ShapeExpr]]
        ret contains the inputs for call_tir, including a tir prim_func, args,
        out_sinfo, and tir_vars.
    """
    # Maps TIR vars on the Relax side to fresh vars used inside the PrimFunc,
    # so the generated PrimFunc has variables independent of the caller.
    tir_var_map: Dict[tir.Var, tir.PrimExpr] = {}
    call_tir_args = []
    create_primfunc_args = []
    # extra list of tir expression arguments
    # that are not covered by Tensor
    extra_tir_args_list = []

    def _copy_undefined_var(expr: tir.PrimExpr):
        """Record a fresh copy for every TIR var in `expr` not yet mapped."""

        def _visit_expr(e: tir.PrimExpr):
            if isinstance(e, tir.Var) and e not in tir_var_map:
                new_var = tir.Var(e.name, e.dtype)
                tir_var_map[e] = new_var

        tir.stmt_functor.post_order_visit(expr, _visit_expr)

    def _convert_te_arg(te_args: Any) -> Any:
        """Helper function used to convert Relax expressions to TE tensor.

        In the common case, the type of te_args is a Relax expression and is
        converted into a TE tensor.
        If te_args is a nested or recursive datatype (i.e list, dict, tvm.ir.Map,
        tvm.ir.Array), we recursive and convert any value of type Relax expression
        into a TE tensor.
        Common values of type int, float, and str are preserved.

        In dynamic shape cases, the passed in arguments may contain TIR variable.
        For example, the argument can be a Relax Var with TensorStructInfo, which
        has symbolic shape, or the argument can be a ShapeExpr with symbolic variables.
        To make the PrimFunc generated has independent variables with the caller
        Relax function, we will substitute the TIR variables in the input arguments
        with fresh ones, which is done by maintaining a TIR variable mapping.

        Parameters
        ----------
        te_args : Any
            Argument to convert to TE

        tir_var_map : Dict[tir.Var, tir.PrimExpr]
            The TIR variable mapping, which maps TIR variables on the Relax function
            side to the new set of variables used on the PrimFunc side.

        Returns
        -------
        ret : (Any, [tvm.te.Tensor])
            A tuple of the converted te_args, and a list of te tensors for each
            converted Relax expression
        """

        def _convert_te_arg_helper(arg):
            if isinstance(arg, Expr):  # type: ignore
                if isinstance(arg.struct_info, TensorStructInfo):
                    assert isinstance(
                        arg.struct_info.shape, ShapeExpr
                    ), "emit_te now only supports Tensor that has ShapeExpr shape"
                    for shape_value in arg.struct_info.shape.values:
                        _copy_undefined_var(shape_value)

                    # Pick a readable parameter name: the var's own hint when
                    # available, otherwise A, B, C, ... then tensor_input_N.
                    n_args = len(create_primfunc_args)
                    if isinstance(arg, tvm.relax.Var):
                        name = arg.name_hint
                    elif n_args < len(string.ascii_uppercase):
                        name = string.ascii_uppercase[n_args]
                    else:
                        name = f"tensor_input_{n_args}"

                    te_arg = te_tensor(arg, tir_var_map, name)
                    call_tir_args.append(arg)
                    create_primfunc_args.append(te_arg)
                    return te_arg
                if isinstance(arg.struct_info, ShapeStructInfo):
                    assert isinstance(
                        arg, ShapeExpr
                    ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
                    return [_convert_te_arg_helper(val) for val in arg.values]
                if isinstance(arg.struct_info, PrimStructInfo):
                    if arg.struct_info.value is None:
                        # Scalars without a known value become fresh TIR
                        # parameters (named a, b, c, ... then scalar_input_N).
                        n_args = len(create_primfunc_args)
                        if isinstance(arg, tvm.relax.Var):
                            name = arg.name_hint
                        elif n_args < len(string.ascii_lowercase):
                            name = string.ascii_lowercase[n_args]
                        else:
                            name = f"scalar_input_{n_args}"

                        tir_param = tir.Var(name, arg.struct_info.dtype)
                        call_tir_args.append(arg)
                        create_primfunc_args.append(tir_param)
                        return tir_param
                    else:
                        # Known value: convert the value expression instead.
                        return _convert_te_arg_helper(arg.struct_info.value)
            elif isinstance(arg, (list, Array)):
                return [_convert_te_arg_helper(x) for x in arg]
            elif isinstance(arg, tuple):
                return tuple(_convert_te_arg_helper(x) for x in arg)
            elif isinstance(arg, (dict, Map)):
                for key in arg:
                    assert isinstance(
                        key, str
                    ), "emit_te only supports dict with string as the key currently"
                return {k: _convert_te_arg_helper(arg[k]) for k in arg}
            elif isinstance(arg, tir.PrimExpr):
                _copy_undefined_var(arg)
                new_arg = tir.stmt_functor.substitute(arg, tir_var_map)
                extra_tir_args_list.append(new_arg)
                return new_arg
            elif isinstance(arg, (int, float, str, Type, Attrs)) or arg is None:
                return arg
            raise TypeError("not supported type in emit_te: {}".format(type(arg)))

        new_arg = _convert_te_arg_helper(te_args)
        return new_arg

    def _get_unbound_tir_vars(
        args: List[te_Tensor], extra_tir_args: List[PrimExpr]
    ) -> List[tir.Var]:
        """get unbound TIR vars (i.e TIR vars used in the shape but is not
        itself a dimension of a shape)"""
        bound_vars = set()
        used_vars = set()

        def _populate_bound_vars(expr):
            if isinstance(expr, te_Tensor):
                for dim in expr.shape:
                    _populate_bound_vars(dim)
            elif isinstance(expr, tir.Var):
                bound_vars.add(expr)

        def _populate_used_vars(expr):
            if isinstance(expr, te_Tensor):
                for dim in expr.shape:
                    _populate_used_vars(dim)
            elif isinstance(expr, tir.PrimExpr):
                used_vars.update(tir.analysis.undefined_vars(expr))

        for arg in itertools.chain(args, extra_tir_args):
            _populate_used_vars(arg)
        for arg in args:
            _populate_bound_vars(arg)

        diff = used_vars - bound_vars
        return list(diff)

    def _get_vdevice(arg: Any) -> Optional[VDevice]:
        """get the virtual device from arguments."""
        vdevice = None
        if isinstance(arg, Expr):  # type: ignore
            if isinstance(arg.struct_info, TensorStructInfo):
                vdevice = arg.struct_info.vdevice
        elif isinstance(arg, (list, Array, tuple)):
            for x in arg:
                vdevice = _get_vdevice(x)
                if vdevice is not None:
                    return vdevice
        elif isinstance(arg, (dict, Map)):
            for k in arg:
                vdevice = _get_vdevice(arg[k])
                if vdevice is not None:
                    return vdevice
        return vdevice

    def _shape_with_old_tir_var(
        shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var, tir.PrimExpr]
    ):
        """Rewrite a shape back in terms of the caller's original TIR vars."""
        return ShapeExpr(
            [tir.stmt_functor.substitute(value, tir_var_inverse_map) for value in shape_values]
        )

    primfunc_attrs = kwargs.pop("primfunc_attrs", None)

    te_args = _convert_te_arg(args)
    te_kwargs = _convert_te_arg(kwargs)

    te_out = func(*te_args, **te_kwargs)
    assert isinstance(te_out, te_Tensor) or (
        isinstance(te_out, (tuple, list, Array)) and all(isinstance(t, te_Tensor) for t in te_out)
    ), "only support te.tensor or tuple/list/Array of te.tensor as function output"

    outs = [te_out] if isinstance(te_out, te_Tensor) else list(te_out)
    unbound_tir_vars = _get_unbound_tir_vars([*create_primfunc_args, *outs], extra_tir_args_list)

    inputs = [*create_primfunc_args] + outs + unbound_tir_vars
    tir_func = create_prim_func(inputs, "int64")

    if primfunc_attrs:
        tir_func = tir_func.with_attrs(primfunc_attrs)

    tir_func = tir_func.without_attr("global_symbol")

    # Invert the TIR variable mapping, to convert the output shape back
    # with old set of variables.
    tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}

    output_sinfo = [
        TensorStructInfo(
            _shape_with_old_tir_var(out.shape, tir_var_inverse_map),
            out.dtype,
            _get_vdevice(args),
        )
        for out in outs
    ]

    tir_vars = None
    if len(unbound_tir_vars) > 0:
        tir_vars = _shape_with_old_tir_var(unbound_tir_vars, tir_var_inverse_map)

    return (tir_func, call_tir_args, output_sinfo, tir_vars)