# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License."""The core infra for nn.Module, which includes the following pieces:- Tensor, a wrapper on top of relax.Expr whose struct_info is a TensorStructInfo, providing more convenient access shape and dtype information. Tensor is always symbolic and not bound to any concrete values.- Parameter, a special tensor which could be bound or not bound to concrete values.- Module, a container of nn.Parameters and sub nn.Modules.- Effect, a non-user-facing class that encloses potential side effects, for example, IO, impure external function callings, inplace mutation, etc."""fromcollectionsimportOrderedDictfromtypingimport(TYPE_CHECKING,Any,Callable,Dict,Iterator,List,Optional,Sequence,Tuple,Union,)importnumpyasnp# type: ignorefromtvmimporttirfromtvm.irimportIRModulefromtvm.ir.transformimportPassfromtvm.runtimeimportDevice,NDArrayfromtvm.runtimeimportdeviceasas_devicefromtvm.runtimeimportndarrayfromtvm.runtime.relax_vmimportVirtualMachinefromtvm.targetimportTargetfrom....importrelaxasrxfrom...block_builderimportBlockBuilderfrom...struct_infoimport(ObjectStructInfo,ShapeStructInfo,TensorStructInfo,TupleStructInfo,)from._tensor_opimport_TensorOpfrom.subroutineimportSubroutineMixinifTYPE_CHECKING:importtorch# type: 
ignorefrom.importspecas_specfrom.externimportExternModule_DEFAULT_DTYPE="float32"
def get_default_dtype() -> str:
    """Return the default parameter dtype used when none is specified.

    Initially "float32"; can be overridden via :func:`set_default_dtype`.

    Returns
    -------
    dtype : str
        The current default dtype
    """
    return _DEFAULT_DTYPE
def set_default_dtype(dtype: str) -> None:
    """Override the module-wide default parameter dtype.

    Parameters
    ----------
    dtype : str
        The dtype to use as the new default
    """
    global _DEFAULT_DTYPE  # pylint: disable=global-statement
    _DEFAULT_DTYPE = dtype
class Tensor(_TensorOp):
    """A wrapper on top of relax.Expr whose struct_info is a TensorStructInfo, providing more
    convenient access to shape and dtype information. A Tensor is always symbolic and never bound
    to any concrete values. Shape and dtype inference is done eagerly upon tensor creation, i.e.
    when operators are applied on tensors, the shape and dtype information is already available.
    """

    # The underlying relax expression; its struct_info is always a TensorStructInfo
    # whose shape is fully known (possibly in terms of symbolic tir variables).
    _expr: rx.Expr

    def __init__(self, *, _expr: rx.Expr) -> None:
        """Private constructor. Tensor is never supposed to be constructed directly by users."""
        # Validate that the expression carries complete tensor shape/dtype information.
        assert _expr.struct_info_ is not None
        assert isinstance(_expr.struct_info, TensorStructInfo)
        assert _expr.struct_info.ndim != -1
        assert _expr.struct_info.shape is not None
        assert _expr.struct_info.shape.struct_info_ is not None
        assert isinstance(_expr.struct_info.shape.struct_info, ShapeStructInfo)
        assert _expr.struct_info.shape.struct_info.values is not None
        self._expr = _expr

    @staticmethod
    def from_const(data) -> "Tensor":
        """Construct a tensor from numpy constants."""
        return Tensor(_expr=rx.const(data))

    @staticmethod
    def from_scalar(data: Union[int, float], dtype: str) -> "Tensor":
        """Construct a tensor from a scalar with dtype specified."""
        return Tensor(_expr=rx.const(data, dtype=dtype))

    @staticmethod
    def from_struct_info(struct_info: rx.TensorStructInfo, name: str = "tensor") -> "Tensor":
        """Construct a nn.Tensor from relax TensorStructInfo"""
        return Tensor(
            _expr=rx.Var(
                name_hint=name,
                struct_info=struct_info,
            )
        )

    @staticmethod
    def placeholder(
        shape: Sequence[Union[int, str, tir.PrimExpr]],
        dtype: str,
        name: str = "tensor",
    ) -> "Tensor":
        """Create a placeholder tensor with given shape and dtype. A placeholder tensor should
        never be created directly by users in usual cases, and the only exception is to indicate
        the shape/dtype of return values of an external function.

        If shape is a string `name`, we create a symbolic shape `tvm.tir.Var(name, "int64")`.
        """
        dims: List[Union[int, tir.PrimExpr]] = []
        for dim in shape:
            if isinstance(dim, (int, tir.IntImm)):
                # Static dimension: normalize to a plain python int.
                static = int(dim)
                assert static >= 0
                dims.append(static)
            elif isinstance(dim, str):
                # A string names a fresh symbolic dimension.
                dims.append(tir.Var(dim, "int64"))
            elif isinstance(dim, tir.PrimExpr):
                assert dim.dtype == "int64"
                dims.append(dim)
            else:
                raise TypeError(f"Invalid shape: {shape}")
        return Tensor(
            _expr=rx.Var(
                name_hint=name,
                struct_info=TensorStructInfo(
                    shape=dims,  # type: ignore[arg-type]
                    dtype=dtype,
                ),
            )
        )

    @property
    def shape(self) -> List[Union[int, tir.PrimExpr]]:
        """Returns the shape of the tensor as a list of integers.

        An integer can be a python int or tvm.tir.PrimExpr, depending on whether the shape is
        fully static, for example, [1, 2, tvm.tir.Var("n")] is a valid shape where the last
        dimension is dynamic while the first two dimensions are always static constants.

        Returns
        -------
        shape : List[Union[int, tir.PrimExpr]]
            The shape of the tensor
        """
        shape_info: ShapeStructInfo = self._expr.struct_info.shape.struct_info
        # Collapse IntImm nodes to plain ints; leave symbolic dims as PrimExpr.
        return [dim.value if isinstance(dim, tir.IntImm) else dim for dim in shape_info.values]

    @property
    def ndim(self) -> int:
        """Returns the number of dimensions of the tensor.

        Returns
        -------
        ndim : int
            The number of dimensions of the tensor
        """
        return self._expr.struct_info.ndim

    @property
    def dtype(self) -> str:
        """Returns the data type of the tensor.

        Returns
        -------
        dtype : str
            The data type of the tensor
        """
        return self._expr.struct_info.dtype

    def __repr__(self) -> str:
        return f'Tensor({self.shape}, "{self.dtype}")'
class Parameter(Tensor):
    """A parameter represents the weight of a neural network layer. It is a special tensor which
    could be bound or not bound to concrete values. If a parameter is bound to a concrete value,
    it is called a bound parameter, otherwise it is called an unbound parameter.
    """

    # Concrete weight value; None while the parameter is unbound.
    _data: Optional[NDArray]
    # Free-form metadata attached to this parameter.
    attrs: Dict[str, Any]

    def __init__(
        self,
        shape: Sequence[Union[int, str, tir.PrimExpr]],
        dtype: Optional[str] = None,
    ) -> None:
        """Create a parameter with given shape and dtype. The parameter is not bound to any
        concrete values.

        Parameters
        ----------
        shape : Sequence[Union[int, str, tir.PrimExpr]]
            The shape of the parameter. If it is a string `name`, we create a symbolic shape
            `tvm.tir.Var(name, "int64")`.
        dtype : Optional[str]
            The data type of the parameter. If not specified, the default dtype will be used.
        """
        super().__init__(
            _expr=Tensor.placeholder(
                shape,
                dtype=get_default_dtype() if dtype is None else dtype,
                name="param",
            )._expr
        )
        self._data = None
        self.attrs = OrderedDict()

    @property
    def data(self) -> Optional[NDArray]:
        """The concrete value of the parameter as a tvm.runtime.NDArray when it is bound,
        otherwise None."""
        return self._data

    @data.setter
    def data(self, data: Union[None, NDArray, np.ndarray, "torch.Tensor"]) -> None:
        """Bind the parameter to a concrete value, or unbind it when given None.

        Accepted inputs:
        - None: unbind the parameter from concrete values
        - tvm.runtime.NDArray
        - numpy.ndarray
        - torch.Tensor and any other DLPack-compliant tensors
        """
        if data is None:
            self._data = data
            return
        # Try to do zero-copy if possible
        if isinstance(data, NDArray):
            value = data
        elif isinstance(data, np.ndarray):
            value = ndarray.array(data)
        elif hasattr(data, "__dlpack__"):
            value = _from_dlpack(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")
        # Enforce that the bound value matches the declared shape and dtype exactly.
        if value.shape != tuple(self.shape):
            raise ValueError(f"Shape mismatch: expected {tuple(self.shape)}, got {value.shape}")
        if value.dtype != self.dtype:
            raise ValueError(f"Dtype mismatch: expected {self.dtype}, got {value.dtype}")
        self._data = value

    def to(self, dtype: Optional[str] = None) -> None:  # pylint: disable=invalid-name
        """Change the dtype of the parameter if it is not bound to any concrete data"""
        if dtype is None:
            return
        if self._data is not None:
            raise ValueError(
                "Changing the dtype of a Parameter that has been bound to concrete "
                "data is not recommended. It might lead to potential precision loss "
                "or other unexpected behaviors"
            )
        self._expr = Tensor.placeholder(  # pylint: disable=protected-access
            self.shape, dtype=dtype, name="param"
        )._expr
class Object:
    """A wrapper on top of relax.Expr whose struct_info is the base ObjectStructInfo (rather than
    any of its subclasses). Object effectively represents non-tensor frontend components such as
    KV caches.
    """

    # The underlying relax variable; always carries ObjectStructInfo.
    _expr: rx.Var

    def __init__(self, *, _expr: rx.Expr, _name: str) -> None:
        """Private constructor. Object is never supposed to be constructed directly by users."""
        if isinstance(_expr, rx.Var):
            var = _expr
        else:
            # Non-variable expressions are emitted into the current block first.
            var = BlockBuilder.current().emit(_expr, _name)
        self._expr = var
        assert isinstance(self._expr.struct_info, ObjectStructInfo)
class Effect:
    """Effect is a special non-user facing type that is used to represent operations with side
    effects, for example, print. It is used to represent the output of a computation.
    """

    def emit_init(self, name_hint: str, builder: BlockBuilder) -> List[rx.DataflowVar]:
        """Emit the initialization of the effect. Called by the compiler to initialize the
        effect. Subclasses must override this."""
        raise NotImplementedError

    def create(self, name_hint: str) -> List[rx.Var]:
        """Create the implicit inputs to a relax.Function that represents the side effect.
        Subclasses must override this."""
        raise NotImplementedError

    def set_state(self, state_vars: List[rx.Var]) -> None:
        """Set the variables that represent the effect. Subclasses must override this."""
        raise NotImplementedError

    def finalize(self) -> List[rx.Var]:
        """Finalize the effect as the implicit return value of a relax.Function.
        Subclasses must override this."""
        raise NotImplementedError

    def to(self, dtype: Optional[str] = None) -> None:  # pylint: disable=invalid-name
        """Convert the effect to a specific dtype. Usually a no-op for most effects."""
class Module(SubroutineMixin):
    """Base class for neural network components. Subclass it to build your models. Modules can
    nest within each other in a tree structure using regular attribute assignment."""

    def named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
        """Iterate over the parameters of this module and its sub-modules, yielding the
        qualified parameter name together with the Parameter itself.

        Parameters
        ----------
        prefix : str
            Prefix to prepend to all parameter names.

        Yields
        ------
        (str, Parameter) - Tuple containing the name and parameter
        """
        yield from _attribute_finder(
            self, prefix, condition_yield=lambda attr: isinstance(attr, Parameter)
        )

    def parameters(self) -> Iterator[Parameter]:
        """Iterate over module parameters, yielding only the Parameter values.

        Yields
        ------
        Parameter - The module's parameter
        """
        for _name, param in self.named_parameters():
            yield param

    def state_dict(
        self, *, prefix: str = "", destination: Optional[Dict[str, Parameter]] = None
    ) -> Dict[str, Parameter]:
        """Return a dictionary containing references to the whole state of the module.

        Parameters
        ----------
        prefix : str
            Prefix to prepend to all parameter names.
        destination : Optional[Dict[str, Parameter]]
            Dictionary to which state will be saved. If None, a new dictionary is created.

        Returns
        -------
        dict : Dict[str, Parameter]
            a dictionary containing a whole state of the module
        """
        result = OrderedDict() if destination is None else destination
        for name, param in _attribute_finder(
            self, prefix, condition_yield=lambda attr: isinstance(attr, Parameter)
        ):
            result[name] = param
        return result

    def load_state_dict(
        self, state_dict: Dict[str, Parameter], strict: bool = True
    ) -> Tuple[List[str], List[str]]:
        """Copy parameters and buffers from `state_dict` into this module and its descendants.

        If `strict` is True, the keys in `state_dict` must exactly match the keys returned by
        this module's `state_dict()` function.

        Parameters
        ----------
        state_dict : Dict[str, Parameter]
            A dictionary containing a whole state of the module

        strict : bool = True
            Whether to strictly enforce that the keys in `state_dict` match the keys returned
            by this module's `state_dict()` function.

        Returns
        -------
        (missing_keys, unexpected_keys) : Tuple[List[str], List[str]]
            A tuple of two lists: the missing keys and the unexpected keys.
        """
        own_params = self.state_dict()
        unexpected_keys: List[str] = []
        for key, incoming in state_dict.items():
            if key not in own_params:
                unexpected_keys.append(key)
            else:
                if incoming.data is None:
                    raise ValueError(f"Parameter {key} is not set to any concrete tensor")
                # Pop so that whatever remains afterwards is exactly the missing set.
                own_params.pop(key).data = incoming.data
        missing_keys = list(own_params.keys())
        if strict and (missing_keys or unexpected_keys):
            raise KeyError(f"Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}")
        return missing_keys, unexpected_keys

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the module on the given inputs by delegating to its `forward` method."""
        if hasattr(self, "forward"):
            return self.forward(*args, **kwargs)  # pylint: disable=no-member
        raise NotImplementedError(f"Module {type(self)} does not have a `forward` method")

    def to(self, dtype: Optional[str] = None) -> None:  # pylint: disable=invalid-name
        """Convert the module to a specific dtype recursively: every attribute exposing a
        callable `to` is converted, and a string-valued `dtype` attribute (if any) is updated."""
        for attr in self.__dict__.values():
            converter = getattr(attr, "to", None)
            if callable(converter):
                converter(dtype=dtype)
        if dtype is not None and isinstance(getattr(self, "dtype", None), str):
            self.dtype = dtype  # pylint: disable=attribute-defined-outside-init

    def export_tvm(
        self,
        spec: "_spec.ModuleSpecType",
        debug: bool = False,
        allow_extern: bool = False,
    ) -> Union[
        Tuple[
            IRModule,
            List[Tuple[str, Parameter]],
        ],
        Tuple[
            IRModule,
            List[Tuple[str, Parameter]],
            List["ExternModule"],
        ],
    ]:
        """Export the module to TVM IRModule and parameters

        Parameters
        ----------
        spec : _spec.ModuleSpecType
            A dictionary mapping each input name to a specification that defines the inputs
            shape and dtype.
        debug : bool
            If set to True, then the exported module will support effects. This enables things
            like printing in the graph.

        Returns
        -------
        irmodule : tvm.ir.IRModule
            The converted tvm IR representation of the model.
        params : List[Tuple[str, Parameter]]
            A list of Parameters corresponding to the weights of the model.
        ext_mods : List[nn.ExternModule]
            A list of ExternModules that are used in the model.
        """
        # pylint: disable=import-outside-toplevel
        from . import spec as _spec
        from .exporter import Exporter

        # pylint: enable=import-outside-toplevel
        module_spec = _spec.ModuleSpec.from_raw(spec, self)
        mod, params, ext_mods = Exporter(debug=debug).build(module_spec)
        if allow_extern:
            return mod, params, ext_mods
        if ext_mods:
            raise ValueError(
                "`ExternModule`(s) exist when they are not allowed. "
                "Turn on flag `allow_extern` to allow."
            )
        return mod, params

    def jit(  # pylint: disable=too-many-arguments
        self,
        spec: "_spec.ModuleSpec",
        device: Union[str, Device] = "cpu",
        pipeline: Union[None, str, Pass] = "default_build",
        out_format: str = "torch",
        debug: bool = False,
    ) -> Any:
        """Just-in-time compilation of a nn.model to an executable"""
        # pylint: disable=import-outside-toplevel
        from ...transform import AttachExternModules
        from ...vm_build import build as relax_build
        from . import spec as _spec
        from .exporter import Exporter

        # pylint: enable=import-outside-toplevel
        device = as_device(device)
        module_spec = _spec.ModuleSpec.from_raw(spec, self)
        mod, params, ext_mods = Exporter(debug=debug).build(module_spec)
        mod = AttachExternModules(ext_mods)(mod)  # pylint: disable=not-callable
        vm = VirtualMachine(  # pylint: disable=invalid-name
            relax_build(
                mod,
                target=Target.from_device(device),
                relax_pipeline=pipeline,
            ),
            device,
        )
        ndarray_params = _param_to_ndarray(params, device)

        if out_format == "torch":
            from . import torch  # pylint: disable=import-outside-toplevel

            return torch.TorchModule(spec=module_spec, params=ndarray_params, vm=vm)

        raise ValueError(f"Unknown out_format: {out_format}")
class ModuleList(Module):
    """Holds submodules in a list."""

    def __init__(self, modules: List[Module]):
        """Create a ModuleList wrapping the given list of sub-modules."""
        self.modules = modules

    def __iter__(self):
        """Iterate over the contained sub-modules."""
        return iter(self.modules)

    def __getitem__(self, idx: int) -> Module:
        """Return the sub-module at position `idx`."""
        return self.modules[idx]

    def __setitem__(self, idx: int, module: Module) -> None:
        """Replace the sub-module at position `idx`."""
        self.modules[idx] = module

    def __len__(self):
        """Return the number of contained sub-modules."""
        return len(self.modules)

    def append(self, module: Module):
        """Add a module to the end of the ModuleList"""
        self.modules.append(module)

    def to(self, dtype: Optional[str] = None) -> None:  # pylint: disable=invalid-name
        """Convert all contained sub-modules to the given dtype recursively.

        This override is required for correctness: the generic `Module.to` only visits
        attributes in `__dict__` that expose a callable `to`, and a ModuleList keeps its
        children inside a plain Python list (`self.modules`), which has no `to` method —
        so without this override the children would silently be skipped during dtype
        conversion.
        """
        for module in self.modules:
            module.to(dtype=dtype)

    def forward(self, x):  # pylint: disable=invalid-name
        """Feed-forward pass of the module"""
        for module in self.modules:
            x = module(x)
        return x
def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]:
    """Wrap the given relax.Expr, emit it using the current BlockBuilder, and automatically
    handle nested cases if the expr represents a Tuple.

    Parameters
    ----------
    expr : relax.Expr
        The Expr to be wrapped.

    name : str
        Name hint.

    Returns
    -------
    result : Union[Tensor, Tuple[Tensor]]
        The computed result.
    """
    if not isinstance(expr, rx.DataflowVar):
        expr = BlockBuilder.current().emit(expr, name)
    sinfo = expr.struct_info_
    if isinstance(sinfo, TensorStructInfo):
        return Tensor(_expr=expr)
    if isinstance(sinfo, TupleStructInfo):
        # Recurse into each tuple field, naming elements "name.0", "name.1", ...
        return tuple(
            wrap_nested(  # type: ignore
                rx.TupleGetItem(expr, i),
                name=f"{name}.{i}",
            )
            for i in range(len(sinfo.fields))
        )
    raise TypeError(f"Unsupported return type: {expr.struct_info_}")
def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any], bool]):
    """Recursively walk `root` and yield (qualified_name, attribute) pairs for every
    attribute that satisfies `condition_yield`."""
    if isinstance(root, ModuleList):
        # A ModuleList keeps its children in a plain list, so recurse by index.
        for index, child in enumerate(root):
            yield from _attribute_finder(child, prefix + f"{index}.", condition_yield)
        return
    for attr_name, attr in root.__dict__.items():
        if condition_yield(attr):
            yield prefix + attr_name, attr
        elif isinstance(attr, Module):
            # Covers both plain sub-modules and ModuleList (a Module subclass).
            yield from _attribute_finder(
                attr,
                prefix + attr_name + ".",
                condition_yield,
            )


def _from_dlpack(tensor) -> NDArray:
    """Convert a DLPack-compliant tensor to a tvm.runtime.NDArray, preferring zero-copy
    import and falling back to a numpy round-trip when that fails."""
    try:
        return ndarray.from_dlpack(tensor)
    except RuntimeError:
        pass
    # special logic for PyTorch
    device_type = tensor.device.type
    device_id = tensor.device.index or 0
    return ndarray.array(
        tensor.numpy(),
        device=Device(
            Device.STR2MASK[device_type],
            device_id,
        ),
    )


def _param_to_ndarray(params: List[Tuple[str, Parameter]], device: Device) -> List[NDArray]:
    """Copy every bound parameter to `device`, raising if any parameter is unbound."""
    arrays = []
    unbound = []
    for name, param in params:
        if param.data is None:
            unbound.append(name)
        else:
            arrays.append(param.data.copyto(target=device))
    if unbound:
        raise ValueError(f"Parameters are not set to any concrete values: {', '.join(unbound)}")
    return arrays