# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# pylint: disable=redefined-builtin"""The base Relax operators."""fromtypingimportDict,Union,List,Tuple,Optional,Callableimporttvmimporttvm.runtimefromtvm.runtime.objectimportObjectfromtvm.runtimeimportObjectGenericfrom.import_ffi_apifrom..exprimportExpr,StringImm,ShapeExpr,Call,ExternFunc,GlobalVar,Varfrom..struct_infoimportStructInfo,TensorStructInfofrom...irimportPrimExprfrom..utilsimportargs_converterpy_print=print# pylint: disable=invalid-name
[文档]defregister_gradient(op_name:str,fgradient:Callable[[Var,Call,Var,"BlockBuilder"],List[Expr]]=None,level:int=10,):"""Register operator gradient function for a relax operator. Parameters ---------- op_name: str The name of the op. fgradient: function (orig_var: Var, orig_call: Call, output_grad: Var, ctx: BlockBuilder) -> partials: List[Expr] The gradient function being used. level: int The priority level """returntvm.ir.register_op_attr(op_name,"FPrimalGradient",fgradient,level)
[文档]defnull_value()->Call:"""Create a call node that represents a null value object. Returns ------- ret: Call The created call node. """return_ffi_api.null_value()# type: ignore
def_wrap_inline_arg_tuple(args)->Expr:"""Helper function to wrap argument tuple Normalize the arguments provided the functions that accept a tuple of arguments, and require the tuple of arguments to be written in-line. If the arguments provided are a single relax expression, and are not a reference to a relax tuple, then wrap them into an in-line relax Tuple. """if(isinstance(args,Expr)andnotisinstance(args,tvm.relax.Tuple)and(args.struct_info_isNoneornotisinstance(args.struct_info_,tvm.relax.TupleStructInfo))):returntvm.relax.Tuple([args])else:returnargs
[文档]@args_converter.autodefcall_tir(gvar:GlobalVar,args:Expr,out_sinfo:Union[TensorStructInfo,List[TensorStructInfo]],tir_vars:Optional[Union[ShapeExpr,Tuple[PrimExpr],List[PrimExpr]]]=None,)->Call:""" Call a tir.prim_func and return the output. Parameters ---------- gvar : GlobalVar The GlobalVar referring to a tir PrimFunc. args : Expr The input arguments. out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] The structure info of the call_tir output. It should be a single or a list of TensorStructInfo. Each one denotes the structure info of a returned tensor. tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used Returns ------- ret: Call A call node for the call_tir operator. """args=_wrap_inline_arg_tuple(args)ifnotisinstance(out_sinfo,list):out_sinfo=[out_sinfo]ifisinstance(tir_vars,(list,tuple)):tir_vars=ShapeExpr(tir_vars)return_ffi_api.call_tir(gvar,args,out_sinfo,tir_vars)# type: ignore
[文档]@args_converter.autodefcall_tir_with_grad(gvar:GlobalVar,args:Expr,out_sinfo:Union[TensorStructInfo,List[TensorStructInfo]],te_grad_name:str,te_grad_kwargs:Dict[str,Object]=None,tir_vars:Optional[Union[ShapeExpr,Tuple[PrimExpr],List[PrimExpr]]]=None,)->Call:""" Call a tir.prim_func and return the output. This intrinsic will bind a te gradient function (refered by te_grad_name) to the call_tir_with_grad node. The te gradient function will be called by the Gradient pass. Parameters ---------- gvar : GlobalVar The GlobalVar referring to a tir PrimFunc. args : Expr The input arguments. out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] The structure info of the call_tir_with_grad output. It should be a single or a list of TensorStructInfo. Each one denotes the structure info of a returned tensor. te_grad_name : str The registered name of the te gradient function associated with the call_tir_with_grad node. Must be provided as a keyword argument. te_grad_kwargs : Dict[str, Object], optional The keyword arguments passed to the te gradient function. Optionally provided as a keyword argument. Default: {}. tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used Returns ------- ret: Call A call node for the call_tir_with_grad operator. """args=_wrap_inline_arg_tuple(args)ifnotisinstance(out_sinfo,list):out_sinfo=[out_sinfo]ifisinstance(tir_vars,(list,tuple)):tir_vars=ShapeExpr(tir_vars)ifte_grad_kwargsisNone:te_grad_kwargs={}return_ffi_api.call_tir_with_grad(# type: ignoregvar,args,out_sinfo,te_grad_name,te_grad_kwargs,tir_vars)
[文档]@args_converter.autodefcall_tir_inplace(gvar:GlobalVar,args:Expr,inplace_indices:Union[int,List[int]],out_sinfo:Union[TensorStructInfo,List[TensorStructInfo]],tir_vars:Optional[Union[ShapeExpr,Tuple[PrimExpr],List[PrimExpr]]]=None,)->Call:""" Call a TIR PrimFunc and return the result, doing the specified computations in-place (based on the `inplace_indices` argument; outputs will alias the inputs selected by in-place indices). Warning: This operator is considered pure by the type system but actually mutates the arguments specified by `inplace_indices`. This operator should not be used directly, but rather should be inserted by passes that have checked whether it is safe to perform operations in-place (i.e., none of the arguments specified as an output is aliased or is live after calling call_tir_inplace). Direct calls to this operator should be done for testing purposes only. Parameters ---------- gvar : GlobalVar The GlobalVar referring to a TIR PrimFunc. args : Expr The input arguments. inplace_indices : Union[int, List[int]] Specify which arguments should be used for in-place computations. If `inplace_indices` is a single integer, it will be made into a singleton list. Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output will be an alias of `args[j]`. If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor. At least one member of `inplace_indices` must not be -1. out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] The structure info of the call_tir_inplace output. It should be a single `TensorStructInfo` or a list of `TensorStructInfo`. Each one denotes the structure info of a returned tensor. If a list of `TensorStructInfo` is given, the result will be a tuple of `TensorStructInfo`. tir_vars : Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] ShapeExpr representing a tuple of integers to unpack when calling func. Is null if not used Returns ------- ret: Call A call node for the call_tir operator. """args=_wrap_inline_arg_tuple(args)ifnotisinstance(inplace_indices,list):inplace_indices=[inplace_indices]ifnotisinstance(out_sinfo,list):out_sinfo=[out_sinfo]ifisinstance(tir_vars,(list,tuple)):tir_vars=ShapeExpr(tir_vars)return_ffi_api.call_tir_inplace(# type: ignoregvar,args,inplace_indices,out_sinfo,tir_vars,)
[文档]@args_converter.autodefcall_dps_packed(func:Union[str,Expr],args:Expr,out_sinfo:Union[TensorStructInfo,List[TensorStructInfo]],)->Call:""" Call a destination-passing-style packed function and return the output. Note: The called function is assumed to be _pure_ (other than modifying the designated output arguments). If the function _does_ result in other side effects, then the compiler may end up removing, reordering, or repeating those effects--no guarantees can be made. Parameters ---------- func : Union[str, Expr] The destination-passing-style function, can be ExternFunc. args : Expr The input arguments. out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] The structure info of the call_dps_packed output. It should be a single or a list of TensorStructInfo. Each one denotes the structure info of a returned tensor. Returns ------- ret: Call A call node for the call_dps_packed operator. """ifisinstance(func,str):func=ExternFunc(func)args=_wrap_inline_arg_tuple(args)ifnotisinstance(out_sinfo,list):out_sinfo=[out_sinfo]return_ffi_api.call_dps_packed(func,args,out_sinfo)# type: ignore
[文档]@args_converter.autodefcall_builtin_with_ctx(func:Union[str,Expr],args:Expr,*,sinfo_args:Optional[Union[StructInfo,List[StructInfo]]]=None,)->Call:"""Call a builtin function func. Parameters ---------- func : Expr The builtin function to be called. args : Expr The input arguments. sinfo_args: Optional[Union[StructInfo, List[StructInfo]]] The struct info arguments to the call node. Returns ------- ret: Call The created call node. """ifisinstance(func,str):func=ExternFunc(func)ifsinfo_argsisnotNoneandnotisinstance(sinfo_args,(list,tuple)):sinfo_args=[sinfo_args]return_ffi_api.call_builtin_with_ctx(# type: ignorefunc,args,sinfo_args,# type: ignore)
[文档]@args_converter.autodefmake_closure(func:Expr,args:Expr,)->Object:""" Create a closure with free variables and return the closure. Parameters ---------- func : Expr The closure, can be ExternFunc or PrimFunc. args : Expr The input arguments. Returns ------- ret: Object The VMClosure. """return_ffi_api.make_closure(func,args)# type: ignore
[文档]@args_converter.autodefinvoke_closure(closure:Expr,args:Expr,sinfo_args:Union[List[StructInfo],StructInfo],)->Call:""" Invoke a closure. Parameters ---------- closure : Expr The VMClosure object. args : Expr The input arguments. type_args: Union[List[StructInfo], StructInfo] The structure info arguments of the CallNode Returns ------- ret: Call A call to `invoke_closure`. """ifnotisinstance(sinfo_args,(list,tuple)):sinfo_args=[sinfo_args]return_ffi_api.invoke_closure(closure,args,sinfo_args)# type: ignore
defrender_object(val:tvm.Object)->str:""" Given a TVM Object, renders it in string form. Used for Relax printing and assertions. Parameters ---------- val: tvm.Object An object to render Returns ------- ret: str A string representing the value, ideally human-readable """ifisinstance(val,tvm.nd.NDArray):returnstr(val)ifisinstance(val,tvm.ir.Array):fields=", ".join([render_object(val[i])foriinrange(len(val))])returnf"({fields})"returnstr(val)@tvm.register_func("relax.run.shape_to_tensor")defrelax_shape_to_tensor(shape_tuple:tvm.runtime.ShapeTuple)->tvm.nd.NDArray:""" Takes a ShapeTuple and convert it to NDArray. Parameters ---------- shape_tuple: tvm.runtime.ShapeTuple Shape tuple that we want to convert to NDArray at runtime """returntvm.nd.array([int(v)forvinshape_tuple])@tvm.register_func("relax.run.print")defrelax_print(format_str:str,*format_args:tvm.Object)->None:""" Takes a list of values to print, formats with the given format string. If the format string is empty, simply prints. Call from TVM script like this: `relax.print(value1, value2, ..., valueN, format=format_str)` or `relax.print(value1, value2, ..., valueN) # format_str defaults to ""` Parameters ---------- format_str: str The last argument is a Python-style format string for printing the value format_args: List[Object] The values to print. """val_strs=map(render_object,format_args)ifformat_str=="":py_print(*val_strs)else:py_print(format_str.format(*val_strs))
[文档]defprint(*values:List[Expr],format:Union[str,Expr]="")->Expr:"""Print op to print the values Parameters ---------- values : List[Expr] The values to print. format: Union[str, Expr] The format string or StringImm. Returns ------- result : Expr A relax Call, which will print the value during runtime. """ifisinstance(format,str):format=StringImm(format)return_ffi_api.print(values,format)# type: ignore # pylint: disable=no-member
@tvm.register_func("relax.run.assert_op")defrelax_assert_op(condition:tvm.Object,format_str:str,*format_args:tvm.Object)->None:""" A variadic function. The first value serves as the assertion condition: If the condition is true, then the operator does nothing. If the condition is false, then the operator raises an assertion error. Arguments after the first value serve as format arguments for the error message; the last argument must be a format string for the error message (empty by default). If the format string is the empty string, then the error message will simply include a comma-separated list of the format arguments. The condition argument is not included in the format string. Parameters ---------- condition: tvm.Object The assertion condition. Must be a boolean scalar. format_str: str The last argument is a Python-style format string for printing the value format_args: List[tvm.Object] Values used for formatting the string. """ifnotisinstance(format_str,str):raiseValueError(f"The format string argument to assert must be a string, given {type(format_str)})")ifisinstance(condition,(bool,int)):val=conditionelifisinstance(condition,tvm.nd.NDArray):# may happen if the original program had unknown shape or dtype for the tensor's typedtype=condition.dtypeifdtype!="bool":raiseValueError(f"The condition must be a bool scalar, but given a {dtype} tensor")shape=condition.shapeiflen(shape)!=0:raiseValueError(f"The condition must be a scalar, but it has a shape of {shape}")val=condition.numpy()else:# should be guaranteed by the type systemraiseValueError(f"The condition for relax assert must be a bool, int, or NDArray, "f"but received a {type(condition)}.")ifnotval:error_message="Assertion Failed"ifformat_argsorformat_str!="":rendered=map(render_object,format_args)ifformat_str!="":error_message=format_str.format(*rendered)else:error_message=", ".join(rendered)raiseAssertionError(error_message)
[文档]defassert_op(condition:Union[Expr,PrimExpr],format_args:Optional[Union[Expr,List[Expr]]]=None,format:Union[str,Expr]="",)->Expr:""" Create a call to Relax's assert_op operation (`assert` is reserved in Python, so the name must be distinct). Parameters ---------- condition: Union[Expr, PrimExpr] The assertion condition. format_args: Optional[Union[Expr, List[Expr]]] Format arguments for the error message if the condition fails. format: Union[str, Expr] The format string or StringImm for the error message. Returns ------- result : Expr A Call to the Relax assert operation. """ifnotisinstance(condition,Expr):condition=tvm.relax.PrimValue(condition)ifformat_argsisNone:format_args=[]elifisinstance(format_args,Expr):format_args=[format_args]ifisinstance(format,str):format=StringImm(format)return_ffi_api.assert_op(condition,format_args,format)# type: ignore
[文档]defshape_of(expr:Expr)->Expr:"""Get shape of a tensor. Parameters ---------- expr : Expr The input Expr. Returns ------- result : Expr A relax Call, which gets the shape of the input """return_ffi_api.shape_of(expr)# type: ignore # pylint: disable=no-member
[文档]deftensor_to_shape(expr:Expr)->Expr:"""Convert tensor to shape expr. Parameters ---------- expr : Expr The input Expr Returns ------- result : Expr A relax Call, which transforms the tensor values to the shape """return_ffi_api.tensor_to_shape(expr)# type: ignore # pylint: disable=no-member
[文档]defshape_to_tensor(expr:Expr)->Expr:"""Convert shape to tensor expr. Parameters ---------- expr : Expr The input Expr Returns ------- result : Expr A relax Call, which transforms the shape values to the tensor """return_ffi_api.shape_to_tensor(expr)# type: ignore # pylint: disable=no-member
[文档]@args_converter.autodefcall_inplace_packed(func:Union[str,ExternFunc,GlobalVar],*args:Expr,inplace_indices:Union[int,List[int]],sinfo_args:Union[StructInfo,List[StructInfo]],)->Expr:""" Construct a call to a packed function that consumes some of its arguments "in-place" and returns the mutated arguments (aliased), but should be considered to be otherwise pure. The `inplace_indices` argument indicates which of the outputs are mutated arguments. The resulting call will have the same semantics as calling the packed function directly. Note: This should be used for cases when the user knows that calling the packed function with these arguments will **in reality** not cause any other side effects. If it is used for a call that **does** result in other side effects, then the compiler may end up removing, reordering, or repeating that call, with no guarantees made about any side effects from the callee. Warning: This operator as treated as pure by the type system even though it *is* performing side effects (mutating some arguments). It is therefore incumbent upon the user to ensure that it is being used safely (viz., that mutated arguments are not live after the mutation, that they do not alias values live after the mutation). Parameters ---------- func : Union[str, ExternFunc] The name (global symbol) for a PackedFunc or an ExternFunc node. args: Expr The arguments for the PackedFunc. inplace_indices : Union[int, List[int]] Specify which arguments should be used for in-place computations. If `inplace_indices` is a single integer, it will be made into a singleton list. Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output will be an alias of `args[j]`. If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor. At least one member of `inplace_indices` must not be -1. sinfo_args: Union[StructInfo, List[StructInfo]] The list of structure info arguments (giving the structural info for the returned value). Returns ------- result : Expr A Relax call, corresponding to `call_pure_packed(ExternFunc(func), args, DictAttrs(kwargs), sinfo_args)` """ifisinstance(func,ExternFunc):func=func.global_symbolop=ExternFunc(func)ifsinfo_argsisNone:raiseValueError("R.call_pure_packed is required to have type_args")ifisinstance(sinfo_args,tuple):# type: ignoresinfo_args=list(sinfo_args)elifnotisinstance(sinfo_args,list):sinfo_args=[sinfo_args]ifnotisinstance(inplace_indices,list):inplace_indices=[inplace_indices]return_ffi_api.call_inplace_packed(op,args,inplace_indices,sinfo_args)# type: ignore # pylint: disable=no-member
[文档]@args_converter.autodefcall_pure_packed(func:Union[str,ExternFunc,GlobalVar],*args:Expr,sinfo_args:Union[StructInfo,List[StructInfo]],)->Expr:""" Construct a call to a packed function that should be treated as pure, even though packed calls are normally not treated as pure. The resulting call will have the same semantics as calling the packed function directly. Note: This should be used for cases when the user knows that calling the packed function with these arguments will **in reality** not cause any side effects. If it is used for a call that **does** result in side effects, then the compiler may end up removing, reordering, or repeating that call, with no guarantees made about any side effects from the callee. Parameters ---------- func : Union[str, ExternFunc] The name (global symbol) for a PackedFunc or an ExternFunc node. args: Expr The arguments for the PackedFunc. sinfo_args: Union[StructInfo, List[StructInfo]] The list of structure info arguments (giving the structural info for the returned value). Returns ------- result : Expr A Relax call, corresponding to `call_pure_packed(ExternFunc(func), args, DictAttrs(kwargs), sinfo_args)` """ifisinstance(func,ExternFunc):func=func.global_symbolop=ExternFunc(func)ifsinfo_argsisNone:raiseValueError("R.call_pure_packed is required to have type_args")ifisinstance(sinfo_args,tuple):# type: ignoresinfo_args=list(sinfo_args)elifnotisinstance(sinfo_args,list):sinfo_args=[sinfo_args]sinfo_args=[sinfo()ifcallable(sinfo)elsesinfo.asobject()ifisinstance(sinfo,ObjectGeneric)elsesinfoforsinfoinsinfo_args]# note: if we need attributes, we can also take them herereturn_ffi_api.call_pure_packed(op,args,None,sinfo_args)# type: ignore # pylint: disable=no-member
[文档]@args_converter.autodefinvoke_pure_closure(closure:Expr,args:Expr,sinfo_args:Union[List[StructInfo],StructInfo],)->Call:""" Invoke a closure and indicate to the compiler that it is pure. Note: This should be used for cases when the user knows that calling the closure with these arguments will **in reality** not cause any side effects. If it is used for a call that _does_ result in side effects, then the compiler may end up removing, reordering, or repeating that call, with no guarantees made about any side effects from the callee. Parameters ---------- closure : Expr The VMClosure object. args : Expr The input arguments. type_args: Union[List[StructInfo], StructInfo] The structure info arguments of the CallNode Returns ------- ret: Call A call to `invoke_pure_closure`. """ifnotisinstance(sinfo_args,(list,tuple)):sinfo_args=[sinfo_args]return_ffi_api.invoke_pure_closure(closure,args,sinfo_args)# type: ignore
[文档]defto_vdevice(data,dst_vdevice)->Expr:"""Copy data to the destination device. This operator helps data transferring between difference devices for heterogeneous execution. Parameters ---------- data : Expr The tensor to be copied. dst_device : VDevice The destination device where the data is copied to. Returns ------- result : Expr The copied result. """return_ffi_api.to_vdevice(data,dst_vdevice)# type: ignore
[文档]defhint_on_device(data,dst_vdevice)->Expr:"""It provides a hint specifying the device on which the input data should be executed. This hint is utilized by RealizeVDevice to propagate the virtual device." Parameters ---------- data : Expr The tensor to be copied. dst_device : VDevice The destination device where the data is supposed to be executed. Returns ------- result : Expr The result. """return_ffi_api.hint_on_device(data,dst_vdevice)# type: ignore