# Licensed to the Apache Software Foundation (ASF) under one# or more contributor license agreements. See the NOTICE file# distributed with this work for additional information# regarding copyright ownership. The ASF licenses this file# to you under the Apache License, Version 2.0 (the# "License"); you may not use this file except in compliance# with the License. You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY# KIND, either express or implied. See the License for the# specific language governing permissions and limitations# under the License.# pylint: disable=too-many-lines,invalid-name,protected-access,redefined-outer-name# pylint: disable=redefined-builtin"""nn.Tensor operators."""importinspectimportmathfromtypingimportAny,Callable,Dict,List,Optional,Sequence,Tuple,TypeVar,Unionimportnumpyasnpfromtvmimporttefromtvmimporttiras_tirfromtvm.scriptimporttirasTfrom...importexprasrxfrom...importopas_opfrom...block_builderimportBlockBuilderfrom.coreimportTensor,get_default_dtype,wrap_nestedIntExpr=Union[int,_tir.PrimExpr]
def unsqueeze(x: Tensor, dim: int, name: str = "unsqueeze") -> Tensor:
    """Insert a new axis of size one into a tensor.

    Parameters
    ----------
    x : Tensor
        Input tensor to expand.

    dim : int
        Position at which the new axis is inserted.

    name : str
        Name hint for this operator.

    Returns
    -------
    result : Tensor
        Tensor with one additional dimension.
    """
    expanded = _op.expand_dims(x._expr, dim)
    return wrap_nested(expanded, name)
def concat(x: List[Tensor], dim: int, name: str = "concat") -> Tensor:
    """Concatenate a list of tensors along an axis.

    Parameters
    ----------
    x : List[Tensor]
        List of tensors to concatenate.

    dim : int
        Dimension to concatenate upon.

    name : str
        Name hint for this operator.

    Returns
    -------
    result : Tensor
        The concatenated result.
    """
    # Unwrap the Tensor objects into their underlying relax expressions.
    exprs = [tensor._expr for tensor in x]
    return wrap_nested(_op.concat(exprs, dim), name)
def add(a: Tensor, b: Tensor, name: str = "add") -> Tensor:
    """Element-wise addition with numpy-style broadcasting.

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = add(a, b)
    """
    result = _op.add(a._expr, b._expr)
    return wrap_nested(result, name)
def subtract(a: Tensor, b: Tensor, name: str = "subtract") -> Tensor:
    """Element-wise subtraction with numpy-style broadcasting.

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = subtract(a, b)
    """
    result = _op.subtract(a._expr, b._expr)
    return wrap_nested(result, name)
def multiply(a: Tensor, b: Tensor, name: str = "mul") -> Tensor:
    """Element-wise multiplication with numpy-style broadcasting.

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = multiply(a, b)
    """
    result = _op.multiply(a._expr, b._expr)
    return wrap_nested(result, name)
def divide(a: Tensor, b: Tensor, name: str = "divide") -> Tensor:
    """Element-wise division with numpy-style broadcasting.

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = divide(a, b)
    """
    result = _op.divide(a._expr, b._expr)
    return wrap_nested(result, name)
def chunk(x: Tensor, chunks: int, dim: int = 0, name: str = "chunk") -> Tuple[Tensor, ...]:
    """Split a tensor along dim into the specified number of chunks.

    Parameters
    ----------
    x : Tensor
        Input tensor to be split.

    chunks : int
        Number of pieces to slice x into.

    dim : int
        Which dimension to split x.

    name : str
        Name hint for this operation.

    Returns
    -------
    result : Tuple[Tensor, ...]
        A tuple with chunks elements containing slices of x.
    """
    # `_op.split` yields a tuple of sub-tensors, not a single Tensor; the
    # return annotation reflects that (the docstring already did).
    return wrap_nested(_op.split(x._expr, chunks, dim), name)
def sum(
    x: Tensor,
    axis: Optional[Union[int, List[int]]] = None,
    keepdims: bool = False,
    name: str = "sum",
) -> Tensor:
    """Computes the sum of tensor elements over given axes.

    Parameters
    ----------
    x : Tensor
        The input data tensor.

    axis : Optional[Union[int, List[int]]]
        Axis or axes along which a sum is performed. The default, axis=None,
        sums all elements of the input tensor. Negative indexing is supported.

    keepdims : bool
        If True, the reduced axes are kept in the result as dimensions of size
        one, so the result broadcasts correctly against the input tensor.

    name : str
        Name hint for this operation.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    reduced = _op.sum(x._expr, axis, keepdims)
    return wrap_nested(reduced, name)
def max(
    x: Tensor,
    axis: Optional[Union[int, List[int]]] = None,
    keepdims: bool = False,
    name: str = "max",
) -> Tensor:
    """Computes the maximum of tensor elements over given axes.

    Parameters
    ----------
    x : Tensor
        The input data tensor.

    axis : Optional[Union[int, List[int]]]
        Axis or axes along which the max is computed. The default, axis=None,
        reduces over all elements. Negative indexing is supported.

    keepdims : bool
        If True, the reduced axes are kept in the result as dimensions of size
        one, so the result broadcasts correctly against the input tensor.

    name : str
        Name hint for this operation.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    reduced = _op.max(x._expr, axis, keepdims)
    return wrap_nested(reduced, name)
def min(
    x: Tensor,
    axis: Optional[Union[int, List[int]]] = None,
    keepdims: bool = False,
    name: str = "min",
) -> Tensor:
    """Computes the minimum of tensor elements over given axes.

    Parameters
    ----------
    x : Tensor
        The input data tensor.

    axis : Optional[Union[int, List[int]]]
        Axis or axes along which the min is computed. The default, axis=None,
        reduces over all elements. Negative indexing is supported.

    keepdims : bool
        If True, the reduced axes are kept in the result as dimensions of size
        one, so the result broadcasts correctly against the input tensor.

    name : str
        Name hint for this operation.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    reduced = _op.min(x._expr, axis, keepdims)
    return wrap_nested(reduced, name)
def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str] = None, name: str = "matmul") -> Tensor:
    """General matrix multiplication of two tensors, with broadcasting on
    batched dimensions.

    The semantics and output shape deduction rule follow
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html.

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    out_dtype: Optional[Union[str, DataType]]
        The data type of the matmul result.
        When it is not specified, the output dtype will be the same as input dtype.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = matmul(a, b)
    """
    product = _op.matmul(a._expr, b._expr, out_dtype=out_dtype)
    return wrap_nested(product, name)
def conv1d(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Optional[Union[int, Tuple]] = 1,
    padding: Optional[Union[int, Tuple, str]] = 0,
    dilation: Optional[Union[int, Tuple]] = 1,
    groups: Optional[int] = 1,
    name: str = "conv1d",
) -> Tensor:
    r"""1D convolution.

    This operator takes the weight as the 1D convolution kernel
    and convolves it with data to produce an output.

    In the default case, where the data_layout is `NCW`
    and kernel_layout is `OIW`, conv1d takes in
    a data Tensor with shape `(batch_size, in_channels, width)`,
    and a weight Tensor with shape `(channels, in_channels, kernel_w)`,
    where `kernel_w` is the length of the `W` kernel dimension,
    to produce an output Tensor with the following rule:

    .. math::

        \mbox{out}[b, c, x] = \sum_{dx, k}
           \mbox{data}[b, k, \mbox{strides} * x + dx] * \mbox{weight}[c, k, dx]

    Padding and dilation are applied to data and weight respectively before
    the computation.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    weight : Tensor
        The weight expressions.

    bias : Optional[Tensor]
        Optional bias tensor of shape [O].

    stride : Optional[Union[int, Tuple]]
        The strides of convolution. It is required to have length 1.

    padding : Optional[Union[int, Tuple, str]]
        The padding of convolution on both sides of inputs before convolution.
        It is required to have length either 1 or 2.

    dilation : Optional[Union[int, Tuple]]
        Specifies the dilation rate to be used for dilated convolution.
        It is required to have length 1.

    groups : Optional[int]
        Number of groups to split the input into for grouped convolution.
        The number of input and output channels should be divisible by the
        number of groups.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    conv_out = _op.nn.conv1d(
        data=x._expr,
        weight=weight._expr,
        strides=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
    if bias is not None:
        # Reshape bias to [1, O, 1] so it broadcasts over batch and width.
        conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1]))

    return wrap_nested(conv_out, name)
def conv2d(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Optional[Union[int, Tuple]] = 1,
    padding: Optional[Union[int, Tuple, str]] = 0,
    dilation: Optional[Union[int, Tuple]] = 1,
    groups: Optional[int] = 1,
    data_layout: Optional[str] = "NCHW",
    name: str = "conv2d",
) -> Tensor:
    """Applies a 2D convolution over an input image composed of several input planes.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape [B, N, H, W]

    weight : Tensor
        Filters of shape [O, N/groups, kH, kW]

    bias : Optional[Tensor]
        Optional bias tensor of shape [O].

    stride : Optional[Union[int, Tuple]]
        The stride of the convolving kernel. Can be a single number
        or tuple of (sH, sW).

    padding : Optional[Union[int, Tuple]]
        Implicit paddings on both sides of the input.

    dilation : Optional[Union[int, Tuple]]
        The spacing between kernel elements. Can be a single number or
        tuple of (dH, dW).

    groups : Optional[int]
        Split input into a number of groups.

    data_layout : Optional[str]
        Layout of input and output data.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result with shape [B, O, oH, oW].
    """
    conv_out = _op.nn.conv2d(
        data=x._expr,
        weight=weight._expr,
        strides=stride,
        padding=padding,
        dilation=dilation,
        data_layout=data_layout,
        groups=groups,
    )
    if bias is not None:
        # Reshape bias so it broadcasts along the channel axis of the layout.
        if data_layout == "NCHW":
            conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1]))
        elif data_layout == "NHWC":
            conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, -1]))
        else:
            raise NotImplementedError(f"Dont know how to handle layout {data_layout}.")

    return wrap_nested(conv_out, name)
def conv3d(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Optional[Union[int, Tuple]] = 1,
    padding: Optional[Union[int, Tuple, str]] = 0,
    dilation: Optional[Union[int, Tuple]] = 1,
    groups: Optional[int] = 1,
    data_layout: Optional[str] = "NCDHW",
    name: str = "conv3d",
) -> Tensor:
    """Applies a 3D convolution over an input image composed of several input planes.

    Parameters
    ----------
    x : Tensor
        Input tensor of shape [B, N, D, H, W]

    weight : Tensor
        Filters of shape [O, N/groups, kD, kH, kW]

    bias : Optional[Tensor]
        Optional bias tensor of shape [O].

    stride : Optional[Union[int, Tuple]]
        The stride of the convolving kernel. Can be a single number
        or tuple of (sD, sH, sW).

    padding : Optional[Union[int, Tuple]]
        Implicit paddings on both sides of the input.

    dilation : Optional[Union[int, Tuple]]
        The spacing between kernel elements. Can be a single number or
        tuple of (dD, dH, dW).

    groups : Optional[int]
        Split input into a number of groups.

    data_layout : Optional[str]
        Optional layout of the input and output data.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result with shape [B, O, oD, oH, oW].
    """
    conv_out = _op.nn.conv3d(
        data=x._expr,
        weight=weight._expr,
        strides=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        data_layout=data_layout,
    )
    if bias is not None:
        # Reshape bias so it broadcasts along the channel axis of the layout.
        if data_layout == "NCDHW":
            conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1, 1]))
        elif data_layout == "NDHWC":
            conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, 1, -1]))
        else:
            raise NotImplementedError(f"Dont know how to handle layout {data_layout}.")

    return wrap_nested(conv_out, name)
def conv1d_transpose(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Optional[Union[int, Tuple[int]]] = 1,
    padding: Optional[Union[int, Tuple[int, ...]]] = 0,
    output_padding: Optional[Union[int, Tuple[int]]] = 0,
    dilation: Optional[Union[int, Tuple]] = 1,
    groups: Optional[int] = 1,
    name: str = "conv1d_transpose",
) -> Tensor:
    """1D transposed convolution operator.

    This operator can be seen as the gradient operator of conv1d.

    The output shape can be explained in the simple case when
    `data_layout == "NCW"` and `kernel_layout == "IOW"`. Suppose `data` has
    shape `(N, in_channel, in_w)`, `weight` has shape
    `(in_channel, out_channel, weight_w)`, we need to assure that
    `in_channel % groups == 0`. The shape of the output will be
    `(N, out_channel * groups, out_w)`, where

    - `out_w = ((in_w - 1) * strides[0] + weight_w - 2 * padding[0] + output_padding[0])`

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    weight : Tensor
        The weight tensor.

    bias : Optional[Tensor]
        Optional bias tensor of shape [O].

    stride : Optional[Union[int, Tuple[int]]]
        The strides of convolution. It is required to have length 1.

    padding : Optional[Union[int, Tuple[int, ...]]]
        The padding of convolution on both sides of inputs before convolution.
        It is required to have length either 1 or 2.

    output_padding : Optional[Union[int, Tuple[int]]]
        Used to disambiguate the output shape.

    dilation : Optional[Union[int, Tuple]]
        Specifies the dilation rate to be used for dilated convolution.
        It is required to have length 1.

    groups : Optional[int]
        Number of groups to split the input into for grouped convolution.
        The number of input and output channels should be divisible by the
        number of groups.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    conv_out = _op.nn.conv1d_transpose(
        data=x._expr,
        weight=weight._expr,
        strides=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )
    if bias is not None:
        # Reshape bias to [1, O, 1] so it broadcasts over batch and width.
        conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1]))

    return wrap_nested(conv_out, name)
def maximum(x1: Tensor, x2: Tensor, name: str = "maximum") -> Tensor:
    """Element-wise maximum

    Parameters
    ----------
    x1 : Tensor
        The first input tensor.

    x2 : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = maximum(a, b)
    """
    # Return annotation added for consistency with the other operators here.
    return wrap_nested(_op.maximum(x1._expr, x2._expr), name)
def minimum(x1: Tensor, x2: Tensor, name: str = "minimum") -> Tensor:
    """Element-wise minimum

    Parameters
    ----------
    x1 : Tensor
        The first input tensor.

    x2 : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = minimum(a, b)
    """
    # Return annotation added for consistency with the other operators here.
    return wrap_nested(_op.minimum(x1._expr, x2._expr), name)
def broadcast_to(x: Tensor, shape: Sequence[IntExpr], name: str = "broadcast_to") -> Tensor:
    """Broadcasts a tensor to a specified shape.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    shape : Sequence[IntExpr]
        The target shape.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The broadcasted tensor.
    """
    broadcasted = _op.broadcast_to(x._expr, shape)
    return wrap_nested(broadcasted, name)
def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: Optional[str] = None) -> Tensor:
    """Permutes the dimensions of an array.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    axes : Optional[List[int]]
        The target axes order, reverse order if not specified.

    name : Optional[str]
        Name hint. If not given, it is derived from the input's name hint
        ("linear" becomes "matmul"), falling back to "permute_dims".

    Returns
    -------
    result : Tensor
        The transposed result.
    """
    # `name` defaults to None, so annotate it Optional[str] (was `str = None`).
    if name is None:
        # Heuristic: a transposed "linear" weight is typically fed to matmul,
        # so rename accordingly to keep the IR readable.
        x_name = getattr(getattr(x, "_expr", None), "name_hint", None)
        if x_name is not None and "linear" in x_name:
            name = x_name.replace("linear", "matmul")
        else:
            name = "permute_dims"
    return wrap_nested(_op.permute_dims(x._expr, axes=axes), name)
def reshape(x: Tensor, shape: Sequence[IntExpr], name="reshape") -> Tensor:
    """Reshape the input array.

    ``-1`` infers the dimension of the output shape by using the remainder of
    the input dimensions, keeping the size of the new array the same as that of
    the input array. At most one dimension of shape can be -1.

    .. code-block:: python

        x.shape = (2, 3, 4), shape = (6, 1, -1), result.shape = (6, 1, 4)
        x.shape = (2, 3, 4), shape = (3, -1, 8), result.shape = (3, 1, 8)
        x.shape = (2, 3, 4), shape = (-1,), result.shape = (24,)

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    shape : Sequence[IntExpr]
        The new shape. Should be compatible with the original shape.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The reshaped result.

    Note
    ----
    The ``-1`` inference is only performed at compile-time. That is to say,
    in any case the dimension length of ``-1`` cannot be inferred in
    compile-time, an error will be thrown.
    """
    reshaped = _op.reshape(x._expr, shape)
    return wrap_nested(reshaped, name)
def repeat(x: Tensor, repeats: int, axis: Optional[int] = None, name="repeat") -> Tensor:
    """Repeats elements of an array.

    Parameters
    ----------
    x : Tensor
        The input tensor.

    repeats : int
        The number of repetitions.

    axis: Optional[int]
        The axis along which to repeat values. Negative numbers count from the
        back. By default, the flattened input array is used and a flat output
        array is returned.

    name : str
        Name hint.

    Returns
    -------
    ret : Tensor
        The computed result.

    Examples
    --------
    .. code-block:: python

        np_x = numpy.array([[1, 2], [3, 4]])
        x = Tensor.from_const(np_x)
        lv1 = repeat(x, repeats=2)  # lv1 == [1, 1, 2, 2, 3, 3, 4, 4]
        lv2 = repeat(x, repeats=2, axis=1)   # lv2 == [[1., 1., 2., 2.],
                                             #         [3., 3., 4., 4.]]
    """
    repeated = _op.repeat(x._expr, repeats, axis)
    return wrap_nested(repeated, name)
def squeeze(x: Tensor, axis: Optional[Union[int, List[int]]] = -1, name: str = "squeeze") -> Tensor:
    """Squeeze axes in the array.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    axis : Optional[Union[int, List[int]]]
        The set of axes to remove.
        If axis = None, remove all axis of dimensions 1.
        If any specified axis has dimension that does not equal 1, it is an error.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The squeezed result.
    """
    # Annotation widened to match the documented contract (int, list of ints,
    # or None); the default of -1 is unchanged.
    return wrap_nested(_op.squeeze(x._expr, axis), name)
def take(x: Tensor, indices: Tensor, axis: Optional[int] = None, name="take") -> Tensor:
    """Take elements from a tensor along an axis.

    Its semantic is mostly similar to `numpy.take`
    (https://numpy.org/doc/stable/reference/generated/numpy.take.html),
    which can cover `torch.take`
    (https://pytorch.org/docs/stable/generated/torch.take.html) and
    `onnx.gather`
    (https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13).

    Parameters
    ----------
    x : Tensor
        The source tensor.

    indices : Tensor
        The indices of the values to extract.

    axis : Optional[int]
        The axis over which to select values.
        If it is none, the input tensor is required to be one-dimensional.

    name : str
        Name hint.

    Returns
    -------
    ret : Tensor
        The taken result.
    """
    gathered = _op.take(x._expr, indices._expr, axis)
    return wrap_nested(gathered, name)
def astype(x: Tensor, dtype: str, name: str = "astype") -> Tensor:
    """Cast input tensor to the given data type.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    dtype: str
        The target data type

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The casted result.
    """
    # Casting to the tensor's own dtype is a no-op; return the input unchanged.
    if x.dtype == dtype:
        return x
    return wrap_nested(_op.astype(x._expr, dtype), name)
def relu(x: Tensor, name: str = "relu") -> Tensor:
    r"""Rectified Linear Unit (ReLU) activation function.

    .. math::
        \text{ReLU}(x) = \text{max}(x, 0)

    Parameters
    ----------
    x : Tensor
        The input data.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    # Docstring made raw: in the non-raw original, "\text" began with the
    # tab escape "\t", corrupting the rendered math (siblings use r""").
    return wrap_nested(_op.nn.relu(x._expr), name)
def relu6(x: Tensor, name: str = "relu6") -> Tensor:
    r"""ReLU6 activation function.

    .. math::
        \text{ReLU6}(x) = \min(\max(x, 0), 6)

    Parameters
    ----------
    x : Tensor
        The input data.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    activated = _op.nn.relu6(x._expr)
    return wrap_nested(activated, name)
def silu(x: Tensor, name: str = "silu") -> Tensor:
    r"""Sigmoid Linear Unit function

    .. math::
        \text{SiLU}(x) = x * \text{sigmoid}(x)

    Parameters
    ----------
    x : Tensor
        The input data

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Note
    ----
    The input tensor is required to have float dtype
    """
    activated = _op.nn.silu(x._expr)
    return wrap_nested(activated, name)
def gelu(x: Tensor, approximate: Optional[str] = None, name: str = "gelu") -> Tensor:
    r"""Applies the Gaussian Error Linear Units function

    .. math::
        \text{GeLU}(x) = 0.5 * x * (1 + \text{erf}(x * 0.5**0.5))

    where :math:`erf` is the Gauss Error function.

    Parameters
    ----------
    x : Tensor
        The input data

    approximate : Optional[str]
        If set to tanh, use an approximation when calculating CDF.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Note
    ----
    The input tensor is required to have float dtype
    """
    # Select the tanh approximation only when explicitly requested.
    op_fn = _op.nn.gelu_tanh if approximate == "tanh" else _op.nn.gelu
    return wrap_nested(op_fn(x._expr), name)
def sigmoid(x: Tensor, name: str = "sigmoid") -> Tensor:
    r"""Computes sigmoid.

    .. math::
        \text{sigmoid}(x) = \frac{1}{1 + \exp(-x)}

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Note
    ----
    The input tensor is required to have float dtype
    """
    activated = _op.sigmoid(x._expr)
    return wrap_nested(activated, name)
def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor:
    r"""Computes softmax.

    .. math::
        \text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    axis: int
        The axis to sum over when computing softmax.
        By default the last axis of the input tensor is used.
        Supports negative indexing.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Note
    ----
    The input tensor is required to have float dtype
    """
    normalized = _op.nn.softmax(x._expr, axis)
    return wrap_nested(normalized, name)
def softplus(x: Tensor, beta: float = 1.0, threshold: float = 20.0, name: str = "softplus") -> Tensor:
    r"""Softplus activation function.

    .. math::
        \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x})

    Parameters
    ----------
    x : Tensor
        The input data.

    beta : float, optional
        Controls the smoothness of the transition. Default is 1.0.

    threshold : float, optional
        The value beyond which the function is approximated as linear
        to avoid numerical instability. Default is 20.0.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    # Docstring types fixed to Tensor (was "relax.Expr"), `name` documented,
    # and the return annotation added, matching the rest of this module.
    return wrap_nested(_op.nn.softplus(x._expr, beta=beta, threshold=threshold), name)
def prelu(x: Tensor, alpha: Tensor, name: str = "prelu"):
    r"""Parametric ReLU activation function.

    .. math::
        \text{PReLU}(x) = \begin{cases} x & \text{if } x \geq 0 \\
        \alpha \cdot x & \text{if } x < 0 \end{cases}

    Parameters
    ----------
    x : Tensor
        The input data.

    alpha : Tensor
        Slope coefficient for the negative part of the input.

    name : str, optional
        Optional name for the operation. Default is "prelu".

    Returns
    -------
    result : Tensor
        The computed result.
    """
    activated = _op.nn.prelu(x._expr, alpha._expr)
    return wrap_nested(activated, name)
def tanh(x: Tensor, name: str = "tanh") -> Tensor:
    r"""Applies the hyperbolic tangent function.

    .. math::
        \text{Tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Note
    ----
    The input tensor is required to have float dtype
    """
    activated = _op.tanh(x._expr)
    return wrap_nested(activated, name)
def exp(x: Tensor, name: str = "exp") -> Tensor:
    r"""Applies the exponential function.

    .. math::
        \text{Exp}(x) = e^x

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Note
    ----
    The input tensor is required to have float dtype
    """
    exponentiated = _op.exp(x._expr)
    return wrap_nested(exponentiated, name)
def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor:
    """Permutes the dimensions of the input tensor.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    axes : Optional[List[int]]
        The target axes order.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The transposed result.
    """
    transposed = _op.permute_dims(x._expr, axes=axes)
    return wrap_nested(transposed, name)
def negative(x: Tensor, name: str = "neg") -> Tensor:
    """Numerical negative of the input tensor.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    negated = _op.negative(x._expr)
    return wrap_nested(negated, name)
def layer_norm(
    x: Tensor,
    normalized_shape: Union[int, List[int]],
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-5,
    name: str = "layer_norm",
) -> Tensor:
    r"""
    Layer normalization (Lei Ba and et al., 2016).
    Applies layer normalization to the n-dimensional input array.
    This operator takes an n-dimensional input array and normalizes
    the input using the given axis:

    .. math::

        out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
            * gamma + beta

    Unlike batch normalization, the mean and var are computed along the channel
    dimension. Assume the input has size k on axis 1, then both gamma and beta
    have shape (k,).

    .. note::

        This operator can be optimized away for inference.

    Parameters
    ----------
    x : Tensor
        Input to which layer_norm will be applied.

    normalized_shape: Union[int, List[int]]
        The shape of axes to normalize. A single integer is treated as a
        singleton list, normalizing over the last dimension.

    weight: Optional[Tensor]
        The gamma scale factor. Defaults to all-ones.

    bias: Optional[Tensor]
        The beta offset factor. Defaults to all-zeros.

    eps: float
        Small float added to variance to avoid dividing by zero.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    if isinstance(normalized_shape, int):
        normalized_shape = [normalized_shape]
    # Normalize over the trailing len(normalized_shape) axes.
    axes = list(range(-len(normalized_shape), 0))
    dtype = x._expr.struct_info.dtype
    # Fall back to identity scale/offset when gamma/beta are not provided.
    if weight is not None:
        gamma = weight._expr
    else:
        gamma = rx.const(np.ones(normalized_shape), dtype=dtype)
    if bias is not None:
        beta = bias._expr
    else:
        beta = rx.const(np.zeros(normalized_shape), dtype=dtype)
    return wrap_nested(
        _op.nn.layer_norm(
            x._expr,
            gamma=gamma,
            beta=beta,
            axes=axes,
            epsilon=eps,
        ),
        name=name,
    )
def rms_norm(
    x: Tensor,
    weight: Tensor,
    axes: Union[int, List[int]],
    epsilon: float = 1e-5,
    name: str = "rms_norm",
) -> Tensor:
    r"""
    Root mean square normalization (Biao Zhang and et al., 2019).
    Applies root mean square normalization to the n-dimensional input array.
    This operator takes an n-dimensional input array and normalizes
    the input using the given axis:

    .. math::

        out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight

    Parameters
    ----------
    x : Tensor
        Input to which rms_norm will be applied.

    weight : Tensor
        The scale factor.

    axes : Union[int, List[int]]
        The axes that along which the normalization is applied.

    epsilon : float
        Small float added to square mean to avoid dividing by zero.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    normalized = _op.nn.rms_norm(x._expr, weight._expr, axes, epsilon)
    return wrap_nested(normalized, name)
def group_norm(
    x: Tensor,
    num_groups: int,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float = 1e-5,
    channel_axis: int = 1,
    axes: Optional[List[int]] = None,
    name: str = "group_norm",
) -> Tensor:
    r"""
    Applies Group Normalization over a mini-batch of inputs as described in
    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Parameters
    ----------
    x : Tensor
        Input to which group_norm will be applied.

    num_groups : int
        Number of groups to separate the channels into.

    weight : Optional[Tensor]
        The gamma scale factor.

    bias : Optional[Tensor]
        The beta offset factor.

    eps : float
        Small float added to square mean to avoid dividing by zero.

    channel_axis: int
        The channel axis of the data.

    axes : Optional[List[int]]
        Which axes to compute the groupnorm over. If None, assumes first
        two channels should be ignored.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    gamma = weight._expr if weight is not None else None
    beta = bias._expr if bias is not None else None
    if axes is None:
        # Default: normalize over every axis past batch and channel.
        rank = len(x._expr.struct_info.shape)
        axes = list(range(2, rank))
    normalized = _op.nn.group_norm(
        x._expr, gamma, beta, num_groups, channel_axis=channel_axis, axes=axes, epsilon=eps
    )
    return wrap_nested(normalized, name)
def triu(x: Tensor, diagonal: int = 0, name: str = "triu") -> Tensor:
    """Return the upper triangular part of a matrix or a batch of matrices.

    Parameters
    ----------
    x : Tensor
        The tensor that triu will be applied to.
        It is required to have at least two dimensions.

    diagonal : int
        The index indicating the diagonal below which to zero elements.
        If diagonal = 0, the diagonal is the main diagonal.
        If diagonal < 0, the diagonal is below the main diagonal.
        If diagonal > 0, the diagonal is above the main diagonal.

    name : str
        Name hint.

    Returns
    -------
    ret : Tensor
        The result tensor.
    """
    # Docstring fixed: the parameter is named `diagonal`, not `k`.
    return wrap_nested(_op.triu(x._expr, diagonal), name)
def full(
    shape: Sequence[IntExpr],
    fill_value: Tensor,
    dtype: str = "float32",
    name: str = "full",
) -> Tensor:
    """Fill array with scalar value.

    Parameters
    ----------
    shape : Sequence[IntExpr]
        The shape of the created tensor.

    fill_value : Tensor
        The value to fill. Must be a scalar tensor.

    dtype : str
        The data type of the created tensor. Defaults to "float32".

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The result tensor.
    """
    # Docstring fixed: dtype always defaults to "float32"; it is never
    # inferred from fill_value (the old text claimed otherwise).
    # Normalize scalar fill values into relax constants.
    if isinstance(fill_value, (_tir.FloatImm, _tir.IntImm)):
        fill_value = rx.const(fill_value.value, dtype=dtype)
    elif isinstance(fill_value, (int, float)):
        fill_value = rx.const(fill_value, dtype=dtype)
    else:
        fill_value = fill_value._expr
    return wrap_nested(_op.full(shape, fill_value, dtype), name)
def zeros(
    shape: Sequence[IntExpr],
    dtype: str = "float32",
    name: str = "zeros",
) -> Tensor:
    """Construct a tensor of all zeros, with the input shape and dtype.

    Parameters
    ----------
    shape : Sequence[IntExpr]
        The shape of the created tensor.

    dtype : str
        The data type of the created tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The result tensor.
    """
    filled = _op.zeros(shape, dtype)
    return wrap_nested(filled, name)
def ones(
    shape: Sequence[IntExpr],
    dtype: str = "float32",
    name: str = "ones",
) -> Tensor:
    """Construct a tensor of all ones, with the input shape and dtype.

    Parameters
    ----------
    shape : Sequence[IntExpr]
        The shape of the created tensor.

    dtype : str
        The data type of the created tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The result tensor.
    """
    # Docstring fixed: this creates ones, not zeros (copy-paste error).
    return wrap_nested(_op.ones(shape, dtype), name)
def empty(
    shape: Sequence[IntExpr],
    dtype: str = "float32",
    name: str = "empty",
) -> Tensor:
    """Construct an uninitialized tensor, with the input shape and dtype.

    Parameters
    ----------
    shape : Sequence[IntExpr]
        The shape of the created tensor.

    dtype : str
        The data type of the created tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The result tensor.
    """
    # Allocate storage without initializing its contents.
    alloc = _op.builtin.alloc_tensor(
        rx.ShapeExpr(shape),  # type: ignore
        dtype,
        runtime_device_index=0,
    )
    return wrap_nested(alloc, name)  # type: ignore
def split(
    ary: Tensor,
    indices_or_sections: Union[int, Sequence[int]],
    axis: int = 0,
    name: str = "split",
) -> Tuple[Tensor, ...]:
    """Split an array into multiple sub-arrays.

    Parameters
    ----------
    ary : Tensor
        Input tensor to be split.

    indices_or_sections : Union[int, Sequence[int]]
        Indices or sections to split into.

    axis : int = 0
        The axis along which to split, default is 0.

    name : str
        Name hint.

    Returns
    -------
    result : Tuple[Tensor, ...]
        A list of sub-arrays as the outcome of splitting.
    """
    pieces = _op.split(ary._expr, indices_or_sections, axis)
    return wrap_nested(pieces, name)
def pad(
    x: Tensor,
    pad: List[int],
    mode: str = "constant",
    value: float = 0.0,
    name: str = "pad",
) -> Tensor:
    """
    Apply spatial padding to the input tensor.

    Parameters
    ----------
    x : Tensor
        Input tensor to be padded.

    pad : List[int]
        List in the format of [before_0, after_0, before_1, after_1, ...]
        indicating how much to pad each axis of x.

    mode : str
        Padding mode to use; "constant" pads with the value argument.

    value : float
        What to pad with in constant mode.

    name : str
        Name hint for this operator.

    Returns
    -------
    result : Tensor
        Padded output tensor.
    """
    # Docstring fixed: the parameter is `mode` (was documented as "mod").
    return wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_mode=mode, pad_value=value), name)
def square(x: Tensor, name: str = "square") -> Tensor:
    """Computes the element-wise square of the input tensor.

    Parameters
    ----------
    x : Tensor
        The input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    squared = _op.square(x._expr)
    return wrap_nested(squared, name)
def sqrt(x: Tensor, name: str = "sqrt") -> Tensor:
    """Computes the element-wise sqrt of the input tensor.

    Parameters
    ----------
    x : Tensor
        The input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Note
    ----
    The input tensor is required to have float dtype
    """
    rooted = _op.sqrt(x._expr)
    return wrap_nested(rooted, name)
def get_timestep_embedding(
    x: Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
    name: str = "get_timestep_embedding",
) -> Tensor:
    """
    Timestep calculation as described in Denoising Diffusion Probabilistic Models.

    Parameters
    ----------
    x : Tensor
        A 1-D Tensor of N indices.
    embedding_dim : int
        The dimension of the output.
    flip_sin_to_cos : bool
        If True, change the order of sine and cosine embeddings.
    downscale_freq_shift : float
        Adjusts the frequency of the sinusoidal sampling.
    scale : float
        Weight adjustment for embedding magnitude.
    max_period : int
        Controls the minimum frequency of the embeddings.
    name : str
        The name to label this operator with.

    Returns
    -------
    result : Tensor
        [N x dim] Tensor of positional embeddings.
    """
    dtype = get_default_dtype()
    # Arithmetic should be done in float for precision.
    timesteps = _op.astype(x._expr, "float32")
    half_dim = embedding_dim // 2
    # Geometric frequency ladder: exp(-log(max_period) * k / (half_dim - shift))
    # for k in [0, half_dim).
    exponent = rx.const(-math.log(max_period), "float32") * _op.arange(
        start=0, end=half_dim, dtype="float32"
    )
    exponent = exponent / (rx.const(half_dim - downscale_freq_shift, "float32"))
    emb = _op.exp(exponent)
    # Outer product: (N, 1) timesteps x (1, half_dim) frequencies -> (N, half_dim).
    emb = _op.expand_dims(timesteps, 1) * _op.expand_dims(emb, 0)
    # Scale embeddings
    if scale != 1:
        emb = rx.const(scale, "float32") * emb
    # Concat sine and cosine embeddings.
    if flip_sin_to_cos:
        emb = _op.concat([_op.cos(emb), _op.sin(emb)], axis=-1)
    else:
        emb = _op.concat([_op.sin(emb), _op.cos(emb)], axis=-1)
    # Zero pad the last column when embedding_dim is odd, since sin/cos halves
    # only cover 2 * (embedding_dim // 2) columns.
    if embedding_dim % 2 == 1:
        emb = _op.nn.pad(emb, (0, 1, 0, 0))
    # Cast to proper output type
    emb = _op.astype(emb, dtype)
    return wrap_nested(emb, name)
def scaled_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_mask: Optional[Tensor] = None,
    is_causal: Optional[bool] = False,
    scale: Optional[float] = None,
    name: str = "scaled_dot_product_attention",
):
    """
    Computes a scaled dot product attention on provided attention
    query, key, and values. Compliant with the functional torch implementation.

    Parameters
    ----------
    query : Tensor
        Tensor representing current attention lookup of shape
        [batch, seq_len, num_heads, head_size].
    key : Tensor
        Tensor representing cross attention mapping of shape
        [batch, seq_len_kv, num_heads_kv, head_size].
    value : Tensor
        Tensor representing embedded attention values of shape
        [batch, seq_len_kv, num_heads_kv, head_size_value].
    attn_mask : Optional[Tensor]
        Optional mask for attention, not yet supported.
    is_causal : Optional[bool]
        If set, uses a causal attention mask.
    scale : Optional[float]
        Optional extra scaling argument applied to attention.
    name : str
        Name hint for this function.
    """
    assert attn_mask is None, "attn_mask not yet supported."
    # "TopLeft" matches torch's causal masking convention (mask anchored at the
    # top-left corner of the attention matrix).
    causal_mask = "TopLeft" if is_causal else None
    attn = _op.nn.attention(query._expr, key._expr, value._expr, causal_mask=causal_mask, scale=scale)
    return wrap_nested(attn, name)
def interpolate(
    x: Tensor,
    size: Optional[Union[int, Tuple[int]]] = None,
    scale_factor: Optional[Union[float, Tuple[float]]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: Optional[bool] = None,
    data_layout: Optional[str] = "NCHW",
    name: str = "interpolate",
):
    """Resize a tensor using the specified mode.

    Parameters
    ----------
    x : Tensor
        Input tensor to be resized.
    size : Optional[Union[int, Tuple[int]]]
        Requested output size, only one of size and scale_factor may
        be specified.
    scale_factor : Optional[Union[float, Tuple[float]]]
        Multiplier for spatial size.
    mode : str
        Algorithm used for sampling.
    align_corners : Optional[bool]
        How to map pixels before and after sampling.
    recompute_scale_factor : Optional[bool]
        Recompute the scale_factor for use in interpolation.
    antialias : Optional[bool]
        Apply antialiasing to output.
    data_layout : Optional[str]
        Layout of the input and output data.
    name : str
        Name hint for this operation.

    Returns
    -------
    result : Tensor
        Output tensor with requested shape.
    """
    assert recompute_scale_factor is None, "recompute_scale_factor is not supported."
    assert antialias is None, "antialias is not supported."
    if size is None:
        # Derive the output size from scale_factor; assumes scale_factor is
        # given whenever size is None (mirrors torch.nn.functional.interpolate).
        size = []
        for i, dim in enumerate(data_layout):
            # Only upscale spatial dimensions.
            if dim not in ["N", "C"]:
                if isinstance(scale_factor, (list, tuple)):
                    # Per-dimension scale: index by how many spatial dims seen so far.
                    size.append(int(x.shape[i] * scale_factor[len(size)]))
                else:
                    size.append(int(x.shape[i] * scale_factor))
    # Translate torch-style mode names ("nearest", "bilinear", "bicubic")
    # into relax resize2d method names ("nearest_neighbor", "linear", "cubic").
    if mode.startswith("nearest"):
        mode = "nearest_neighbor"
    elif mode[0:2] == "bi":
        mode = mode[2:]
    if mode == "nearest_neighbor":
        coord_trans = "asymmetric"
    elif align_corners:
        coord_trans = "align_corners"
    else:
        coord_trans = "half_pixel"
    return wrap_nested(
        _op.image.resize2d(
            x._expr,
            size,
            layout=data_layout,
            method=mode,
            coordinate_transformation_mode=coord_trans,
        ),
        name,
    )
def ccl_allreduce(x: Tensor, op_type: str = "sum", in_group: bool = True, name="ccl_allreduce"):
    """CCL Allreduce operator

    Parameters
    ----------
    x : relax.Expr
        The input tensor.
    op_type : str
        The type of reduction operation to be applied to the input data.
        Now "sum", "prod", "min", "max" and "avg" are supported.
    in_group : bool
        Whether the reduction operation performs globally or in group as default.
    name : str
        Name hint for this operation.

    Returns
    -------
    result : Tensor
        The result tensor of allreduce.
    """
    reduced = _op.ccl.allreduce(x._expr, op_type, in_group)
    return wrap_nested(reduced, name)
def ccl_allgather(x: Tensor, num_workers: int, name="ccl_allgather"):
    """CCL Allgather operator

    Parameters
    ----------
    x : relax.Expr
        The input tensor.
    num_workers : int
        Number of workers.
    name : str
        Name hint for this operation.

    Returns
    -------
    result : Tensor
        The result tensor of allgather.
    """
    gathered = _op.ccl.allgather(x._expr, num_workers)
    return wrap_nested(gathered, name)
def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"):
    """Broadcast data from worker-0 to all other workers.

    Parameters
    ----------
    x : Tensor
        The tensor to be broadcast.
    name : str
        Name hint for this operation.

    Returns
    -------
    result : Tensor
        The same tensor, which has been broadcast to all other workers.
    """
    broadcast = _op.ccl.broadcast_from_worker0(x._expr)
    return wrap_nested(broadcast, name)
def tensor_expr_op(
    tensor_expr_func: Callable,
    name_hint: str,
    args: List[Union[Tensor, _tir.Var, int]],
    *,
    attrs: Optional[Dict[str, Any]] = None,
):
    """Build the given tensor_expr_func with te.

    Parameters
    ----------
    tensor_expr_func : Callable
        A function that returns a te tensor or a list of tensors.

    name_hint : str
        Name hint.

    args: List[Union[Tensor, _tir.Var]]
        Arguments passed to the function.

    attrs: Optional[Dict[str, Any]]
        A dict of attributes to apply to the function.

    Returns
    -------
    result : Tensor
        The result tensor.
    """

    def _unwrap(value):
        # Tensors are lowered to their underlying relax expression;
        # everything else (tir.Var, int) passes through unchanged.
        if isinstance(value, Tensor):
            return value._expr  # pylint: disable=protected-access
        return value

    emitted = BlockBuilder.current().emit_te(
        tensor_expr_func,
        *(_unwrap(arg) for arg in args),
        primfunc_name_hint=name_hint,
        primfunc_attrs=attrs,
    )
    return wrap_nested(emitted, name=name_hint)
def tensor_ir_op(
    func: _tir.PrimFunc,
    name_hint: str,
    args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
    out: OutType,
) -> OutType:
    """Create a `call_tir` binding with given PrimFunc

    Parameters
    ----------
    func : _tir.PrimFunc
        The PrimFunc to call.

    name_hint : str
        Name hint.

    args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
        The arguments to pass to the PrimFunc.

    out : Union[Tensor, List[Tensor]]
        The output tensors.

    Returns
    -------
    result : Tensor
        The result tensor
    """
    from tvm import relax as rx  # pylint: disable=import-outside-toplevel

    # Tensor arguments become call_tir positional args; ShapeExpr/PrimExpr
    # arguments are collected separately and passed as trailing tir_vars.
    call_tir_args, tir_vars = [], []
    if not isinstance(args, (tuple, list)):
        args = [args]
    for arg in args:
        if isinstance(arg, Tensor):
            call_tir_args.append(arg._expr)
        elif isinstance(arg, (rx.ShapeExpr, _tir.PrimExpr)):
            tir_vars.append(arg)
        else:
            raise TypeError(
                "Unsupported type: tensor_ir_op args expect Tensor or ShapeExpr or PrimExpr,"
                f" but got {type(arg)}"
            )
    # Output struct info drives the allocation of the call_tir result(s).
    if isinstance(out, Tensor):
        out_sinfo = [out._expr.struct_info]
    else:
        out_sinfo = [x._expr.struct_info for x in out]

    bb = BlockBuilder.current()
    global_var = bb.add_func(func, name_hint)

    # call_tir treats an empty tir_vars list differently from None; normalize.
    if len(tir_vars) == 0:
        tir_vars = None

    return wrap_nested(
        bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo, tir_vars=tir_vars)),
        name=name_hint,
    )
def tensor_ir_inplace_op(
    func: _tir.PrimFunc,
    name_hint: str,
    args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
    inplace_indices: Union[int, List[int]],
    out: OutType,
) -> OutType:
    """Create a `call_tir_inplace` binding with given PrimFunc

    Parameters
    ----------
    func : _tir.PrimFunc
        The PrimFunc to call.

    name_hint : str
        Name hint.

    args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
        The arguments to pass to the PrimFunc.

    inplace_indices : Union[int, List[int]]
        Specify which arguments should be used for in-place computations.
        If `inplace_indices` is a single integer, it will be made into a singleton list.
        Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
        will be an alias of `args[j]`. If `inplace_indices[i] = -1`, then the `i`th output
        will be a freshly allocated tensor. At least one member of `inplace_indices`
        must not be -1.

    out : Union[Tensor, List[Tensor]]
        The output tensors.

    Returns
    -------
    result : Tensor
        The result tensor
    """
    from tvm import relax as rx  # pylint: disable=import-outside-toplevel

    # Same argument split as tensor_ir_op: Tensors become positional args,
    # shape/prim expressions become tir_vars.
    call_tir_args, tir_vars = [], []
    if not isinstance(args, (tuple, list)):
        args = [args]
    for arg in args:
        if isinstance(arg, Tensor):
            call_tir_args.append(arg._expr)
        elif isinstance(arg, (rx.ShapeExpr, _tir.PrimExpr)):
            tir_vars.append(arg)
        else:
            raise TypeError(
                "Unsupported type: tensor_ir_inplace_op args expect Tensor or ShapeExpr or"
                f" PrimExpr, but got {type(arg)}"
            )
    if isinstance(out, Tensor):
        out_sinfo = [out._expr.struct_info]
    else:
        out_sinfo = [x._expr.struct_info for x in out]

    bb = BlockBuilder.current()
    global_var = bb.add_func(func, name_hint)

    return wrap_nested(
        bb.emit(rx.call_tir_inplace(global_var, call_tir_args, inplace_indices, out_sinfo, tir_vars)),
        name=name_hint,
    )
def extern(
    name: str,
    args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
    out: OutType,
) -> OutType:
    """Invoke an extern function during runtime. The extern function must be
    registered with the TVM runtime using `reflection::GlobalDef().def` (C++),
    or `tvm.register_func` (Python).

    Parameters
    ----------
    name : str
        The name of the extern function to call.

    args : Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]]
        The arguments to pass to the extern function.

    out : Union[Tensor, List[Tensor]]
        The output tensors, used only to derive the output struct info of
        the call.

    Returns
    -------
    result : Tensor
        The result
    """
    from tvm import relax as rx  # pylint: disable=import-outside-toplevel

    def _convert(arg, name: str):
        # Recursively lower python values to the matching relax expression.
        if isinstance(arg, Tensor):
            return arg._expr  # pylint: disable=protected-access
        if isinstance(arg, int):
            return rx.PrimValue(_tir.IntImm("int64", arg))
        if isinstance(arg, float):
            return rx.PrimValue(_tir.FloatImm("float64", arg))
        if isinstance(arg, str):
            return rx.StringImm(arg)
        if isinstance(arg, _tir.PrimExpr):
            return rx.PrimValue(arg)
        if isinstance(arg, (tuple, list)):
            return rx.Tuple([_convert(e, f"{name}_{i}") for i, e in enumerate(arg)])
        raise TypeError(f"Unsupported input type: {type(arg)}")

    rx_inputs = _convert(args, "input")
    # `out` is converted only to read off its struct info; the converted
    # expression itself is discarded.
    rx_outputs_sinfo = _convert(out, "dummy").struct_info
    return wrap_nested(_op.call_dps_packed(name, args=rx_inputs, out_sinfo=rx_outputs_sinfo), name)  # type: ignore
def debug_func(
    name: str,
    *args: Union[Tensor, _tir.PrimExpr, int, float, str],
    _line_info: Optional[str] = None,
):
    """Call a debug function during runtime. The debug function must be registered with
    the following type signature:

    .. code-block:: python

        @tvm.register_func(name_of_debug_func)
        def debug_func(lineno: str, arg_0, arg_1, ...) -> None:
            ...

    Parameters
    ----------
    name : str
        The name of the debug function to call.

    *args : Union[Tensor, _tir.PrimExpr, int, float, str]
        The arguments to pass to the debug function.

    _line_info : Optional[str]
        Caller location string ("file:line"); inferred from the call stack
        when not provided.
    """
    # pylint: disable=import-outside-toplevel
    from tvm import relax as rx

    from .exporter import Exporter
    from .modules import IOEffect

    # pylint: enable=import-outside-toplevel

    if Exporter.current().io_effect is None:
        raise RuntimeError("Debugging is only supported when debug mode is on.")
    io: IOEffect = Exporter.current().io_effect  # type: ignore

    if _line_info is None:
        # Fix: use the caller's filename instead of a hard-coded "(unknown)"
        # placeholder (the variable was previously captured but never used).
        filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
        _line_info = f"{filename}:{line_number}"

    converted_args = []
    for arg in args:
        if isinstance(arg, Tensor):
            converted_args.append(arg._expr)  # pylint: disable=protected-access
        elif isinstance(arg, int):
            converted_args.append(rx.PrimValue(_tir.IntImm("int64", arg)))
        elif isinstance(arg, float):
            converted_args.append(rx.PrimValue(_tir.FloatImm("float32", arg)))
        elif isinstance(arg, _tir.PrimExpr):
            converted_args.append(rx.PrimValue(arg))
        elif isinstance(arg, str):
            converted_args.append(rx.StringImm(arg))
        else:
            raise TypeError(f"Unsupported type {type(arg)}")

    # Thread the IO effect through the call so debug prints stay ordered.
    io.effect = BlockBuilder.current().emit(
        rx.call_pure_packed(
            "vm.builtin.invoke_debug_func",
            io.effect,
            rx.StringImm(name),
            rx.StringImm(_line_info),
            *converted_args,
            sinfo_args=[rx.ObjectStructInfo()],
        ),
        name_hint=io.effect.name_hint,
    )
def print_(tensor: Tensor):
    """Debug printing a Tensor during runtime."""
    # Fix: report the caller's actual filename rather than the hard-coded
    # "(unknown)" placeholder (the captured `filename` was previously unused).
    filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
    line_info = f"{filename}:{line_number}"
    debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
def less(a: Tensor, b: Tensor, name: str = "less") -> Tensor:
    """Broadcasted element-wise comparison for (lhs < rhs).

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    cmp = _op.less(a._expr, b._expr)
    return wrap_nested(cmp, name)
def less_equal(a: Tensor, b: Tensor, name: str = "less_equal") -> Tensor:
    """Broadcasted element-wise comparison for (lhs <= rhs).

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    cmp = _op.less_equal(a._expr, b._expr)
    return wrap_nested(cmp, name)
def greater(a: Tensor, b: Tensor, name: str = "greater") -> Tensor:
    """Broadcasted element-wise comparison for (lhs > rhs).

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    cmp = _op.greater(a._expr, b._expr)
    return wrap_nested(cmp, name)
def greater_equal(a: Tensor, b: Tensor, name: str = "greater_equal") -> Tensor:
    """Broadcasted element-wise comparison for (lhs >= rhs).

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    cmp = _op.greater_equal(a._expr, b._expr)
    return wrap_nested(cmp, name)
def equal(a: Tensor, b: Tensor, name: str = "equal") -> Tensor:
    """Broadcasted element-wise comparison for (lhs == rhs).

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    cmp = _op.equal(a._expr, b._expr)
    return wrap_nested(cmp, name)
def not_equal(a: Tensor, b: Tensor, name: str = "not_equal") -> Tensor:
    """Broadcasted element-wise comparison for (lhs != rhs).

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    cmp = _op.not_equal(a._expr, b._expr)
    return wrap_nested(cmp, name)
def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Tensor:
    """Selecting elements from either the input tensors depending on the value of the
    condition.

    For a given position, return the corresponding value in `x1` if `condition` is True,
    and return the corresponding value in `x2` otherwise.

    Parameters
    ----------
    condition : Tensor
        When True, yield `x1`; otherwise, yield `x2`.
        Must be broadcasting compatible with `x1` and `x2`.
        Must have boolean dtype.

    x1 : Tensor
        The first input tensor.
        Must be broadcasting compatible with `condition` and `x2`.

    x2 : Tensor
        The second input tensor.
        Must be broadcasting compatible with `condition` and `x1`.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The result tensor.
    """
    # The select op requires a boolean predicate, so cast up front.
    bool_condition = astype(condition, "bool")
    selected = _op.where(bool_condition._expr, x1._expr, x2._expr)
    return wrap_nested(selected, name)
def cumsum(
    data: Tensor,
    axis: Optional[int] = None,
    dtype: Optional[str] = None,
    exclusive: Optional[bool] = None,
    name: str = "cumsum",
) -> Tensor:
    """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
    a given axis.

    Parameters
    ----------
    data : Tensor
        The input data to the operator.

    axis : Optional[int]
        Axis along which the cumulative sum is computed. The default (None) is to compute
        the cumsum over the flattened array.

    dtype : Optional[str]
        Type of the returned array and of the accumulator in which the elements are summed.
        If dtype is not specified, it defaults to the dtype of data.

    exclusive : Optional[bool]
        If true will return exclusive sum in which the first element is not
        included.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The result has the same size as data, and the same shape as data if axis is not None.
        If axis is None, the result is a 1-d array.

    Examples
    --------
    .. code-block:: python

        a = [[1, 2, 3], [4, 5, 6]]

        cumsum(a)  # if axis is not provided, cumsum is done over the flattened input.
        -> [ 1,  3,  6, 10, 15, 21]

        cumsum(a, dtype="float32")
        -> [ 1.,  3.,  6., 10., 15., 21.]

        cumsum(a, axis=0)  # sum over rows for each of the 3 columns
        -> [[1, 2, 3],
            [5, 7, 9]]

        cumsum(a, axis=1)
        -> [[ 1,  3,  6],
            [ 4,  9, 15]]

        a = [1, 0, 1, 0, 1, 1, 0]  # a is a boolean array
        cumsum(a, dtype=int32)  # dtype should be provided to get the expected results
        -> [1, 1, 2, 2, 3, 4, 4]
    """
    summed = _op.cumsum(data._expr, axis, dtype, exclusive)
    return wrap_nested(summed, name)
def sort(x: Tensor, axis: int = -1, descending: bool = False, name="sort"):
    """Performs sorting along the given axis and returns an array
    in sorted order.

    Parameters
    ----------
    x : Tensor
        The input tensor.

    axis : int
        Axis along which to sort the input tensor.
        By default the last axis of the input is used.

    descending : bool
        Whether to sort in descending order, the default is False

    name : str
        Name hint.

    Returns
    -------
    out : Tensor
        The sorted tensor.
    """
    ordered = _op.sort(x._expr, axis, descending)
    return wrap_nested(ordered, name=name)
def argsort(data: Tensor, axis: int = -1, descending: bool = False, dtype: str = "int32", name="argsort"):
    """Performs sorting along the given axis and returns an array of indices
    having same shape as an input array that index data in sorted order.

    Parameters
    ----------
    data : Tensor
        The input data tensor.

    axis : int
        Axis along which to sort the input tensor.

    descending : bool
        Whether to sort in descending order, the default is False

    dtype : str
        The data type of the output indices.

    name : str
        Name hint.

    Returns
    -------
    out : Tensor
        The indices of the sorted tensor.
    """
    index_order = _op.argsort(data._expr, axis, descending, dtype)
    return wrap_nested(index_order, name=name)
def topk(
    data: Tensor,
    k: int = 1,
    axis: int = -1,
    ret_type: str = "both",
    largest: bool = True,
    dtype: str = "int32",
    name: str = "topk",
):
    """Get the top k elements in an input tensor along the given axis.

    ret_type specifies the return type, can be one of ("both", "values", "indices").

    Parameters
    ----------
    data : Tensor
        The input data tensor.

    k : int
        Number of top elements to select. Return all elements if k < 1.

    axis : int
        Axis along which to sort the input tensor.

    ret_type: str
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    largest : bool
        Whether to return largest or smallest elements.
        The k smallest elements are returned if largest is False.

    dtype : str
        The data type of the indices output.

    name : str
        Name hint.

    Returns
    -------
    out : Tensor or Tuple[Tensor, Tensor]
        The computed result.
    """
    selected = _op.topk(data._expr, k, axis, ret_type, largest, dtype)
    return wrap_nested(selected, name=name)
def multinomial_from_uniform(
    prob: Tensor,
    uniform_sample: Tensor,
    sample_indices: Optional[Tensor] = None,
    dtype: str = "int64",
    name: str = "multinomial_from_uniform",
):
    """Returns a tensor where each row contains the index sampled from the multinomial
    probability distribution located in the corresponding row of tensor prob.

    Notes
    -----
    For better cpu performance, use 'vm.builtin.multinomial_from_uniform'.
    For accurate results, ensure probabilities are between 0 and 1 and sum to 1.

    Parameters
    ----------
    prob : Tensor
        A 2-D tensor of shape (batch, vocab_size) representing probability distributions.
        Each row is a distribution across vocabulary for a batch, where:
        Values range from [0, 1], indicating the probability of each vocabulary item.
        The sum of values in each row is 1, forming a valid distribution.

    uniform_sample : Tensor
        The uniformly sampled 2-D tensor with the shape (n, 1).
        Values range from 0 to 1, indicating probabilities sampled uniformly.

    sample_indices : Optional[Tensor]
        The 2-D tensor with the shape [n, 1], which indicates the specific
        probability distribution to sample from. The value of sample_indices[i]
        determines that the ith token should be sampled from the sample_indices[i]th
        probability distribution. For instance, if there are 3 distinct probability
        distributions and the requirement is to sample 2, 3, and 4 tokens from each,
        then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2].

    dtype : str
        The data type of output tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed tensor with shape (n, 1).

    Examples
    --------
    .. code-block:: python

        prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
        usample = [[0.4], [0.9]]
        sample_indices = [[0], [1]]

        multinomial_from_uniform(prob, usample)
        -> [[1], [2]]
        multinomial_from_uniform(prob, usample, sample_indices)
        -> [[1], [2]]
    """
    out_batch = uniform_sample.shape[0]

    if sample_indices is not None:
        assert (
            sample_indices.shape == uniform_sample.shape
        ), "The shape of sample_indices must match the shape of uniform_sample."
    else:
        assert (
            prob.shape[0] == uniform_sample.shape[0]
        ), "Number of samples must match the number of probability distributions."
        # Default: sample the i-th row of prob for the i-th output row.
        sample_indices = Tensor.from_const(np.arange(out_batch).reshape(out_batch, 1))

    return wrap_nested(
        _op.multinomial_from_uniform(prob._expr, uniform_sample._expr, sample_indices._expr, dtype),
        name,
    )
def sample_top_p_top_k_from_sorted_prob(
    sorted_prob: Tensor,
    sorted_index: Tensor,
    top_p: Tensor,
    top_k: Tensor,
    uniform_sample: Tensor,
    sample_indices: Optional[Tensor] = None,
):
    """Samples indices from a sorted probability tensor based on top_p and top_k criteria.

    Notes
    -----
    For accurate results, ensure probabilities are between 0 and 1 and sum to 1.

    Parameters
    ----------
    sorted_prob : Tensor
        A 2-D tensor, with shape (batch, vocab_size), contains probabilities
        sorted in descending order.

    sorted_index: Tensor
        The indices tensor with shape (batch, vocab_size), corresponding to the
        sorted_prob. Potentially from applying argsort on the original probability
        tensor in descending order.

    top_p : Tensor
        The cumulative probability threshold with shape (batch, 1) for nucleus sampling.

    top_k : Tensor
        A tensor with shape (batch, 1), representing the number of top probabilities
        to consider for top-k sampling.

    uniform_sample : Tensor
        Uniformly sampled values with shape (n, 1) are used to select the output indices.

    sample_indices : Optional[Tensor]
        The 2-D tensor with the shape [n, 1], which indicates the specific
        probability distribution to sample from. The value of sample_indices[i]
        determines that the ith token should be sampled from the sample_indices[i]th
        probability distribution. For instance, if there are 3 distinct probability
        distributions and the requirement is to sample 2, 3, and 4 tokens from each,
        then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2].

    Returns
    -------
    result : Tensor
        The selected indices with shape (n, 1).

    Examples
    --------
    .. code-block:: python

        prob = [[0.1, 0.4, 0.5], [0.3, 0.3, 0.4]]
        sorted_prob = [[0.5, 0.4, 0.1], [0.4, 0.3, 0.3]]
        sorted_index = [[2, 1, 0], [2, 0, 1]]
        top_p = [[0.6], [0.9]]
        top_k = [[3], [2]]
        uniform_sample = [[0.5], [0.6]]
        sample_indices = [[0], [1]]

        sample_top_p_top_k_from_sorted_prob(
            sorted_prob, sorted_index, top_p, top_k, uniform_sample, sample_indices)
        -> [2, 0]
    """
    prob_dtype = sorted_prob.dtype
    index_dtype = sorted_index.dtype
    prob_batch = sorted_prob.shape[0]
    out_batch = uniform_sample.shape[0]

    if sample_indices is not None:
        assert (
            sample_indices.shape == uniform_sample.shape
        ), "The shape of sample_indices must match the shape of uniform_sample."
    else:
        assert (
            sorted_prob.shape[0] == uniform_sample.shape[0]
        ), "Number of samples must match the number of probability distributions."
        # Default: the i-th sample draws from the i-th distribution.
        sample_indices = Tensor.from_const(np.arange(out_batch).reshape(out_batch, 1).astype(np.int64))
    # NOTE: a stray debug `print("sample_indices: ", sample_indices)` was removed here;
    # library code must not print to stdout on every call.
    sample_indices_dtype = sample_indices.dtype

    def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
        # True while position j is still inside both the nucleus (top_p) and top_k budget.
        return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])

    @T.prim_func(private=True)
    def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
        top_p = T.match_buffer(B, (batch, 1), prob_dtype)
        top_k = T.match_buffer(C, (batch, 1), index_dtype)
        renorm_prob = T.match_buffer(D, (batch, 1), prob_dtype)
        for ax0, ax1 in T.grid(batch, vocab_size):
            with T.block("T_get_renorm_prob"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                # renorm_prob[i] = cumulative mass at the last in-nucleus position.
                if not _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0):
                    renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
                elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1):
                    if v_ax1 + 1 == vocab_size:
                        renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1]
                    elif not _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1):
                        renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1]

    @T.prim_func(private=True)
    def _get_index_from_sorted(
        A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle
    ):
        batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
        out_batch = T.int64(is_size_var=True)
        cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
        indices = T.match_buffer(B, (batch, vocab_size), index_dtype)
        renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
        usample = T.match_buffer(D, (out_batch, 1), prob_dtype)
        sample_indices = T.match_buffer(E, (out_batch, 1), sample_indices_dtype)
        output_index = T.match_buffer(F, (out_batch, 1), index_dtype)
        for ax0, ax1 in T.grid(out_batch, vocab_size):
            with T.block("T_get_index_from_sorted"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.writes(output_index[v_ax0, 0])
                # Inverse-CDF lookup over the renormalized (truncated) distribution.
                if (
                    usample[v_ax0, T.int64(0)]
                    < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1]
                    / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
                    or v_ax1 + 1 == vocab_size
                ):
                    if v_ax1 == 0:
                        output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
                    elif (
                        usample[v_ax0, T.int64(0)]
                        >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
                        / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
                    ):
                        output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]

    cumsum_sorted = cumsum(sorted_prob, axis=1)

    renorm_prob = tensor_ir_op(
        _get_renorm_prob,
        "get_renorm_prob",
        args=[cumsum_sorted, top_p, top_k],
        out=Tensor.placeholder([prob_batch, 1], prob_dtype),
    )

    out_index_in_sorted = tensor_ir_op(
        _get_index_from_sorted,
        "get_index_from_sorted",
        args=[cumsum_sorted, sorted_index, renorm_prob, uniform_sample, sample_indices],
        out=Tensor.placeholder([out_batch, 1], index_dtype),
    )
    return out_index_in_sorted
def renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, top_k):
    """Renormalizes probabilities after filtering with top_p and top_k, ensuring
    they sum up to 1.

    Notes
    -----
    For accurate results, ensure probabilities are between 0 and 1 and sum to 1.

    Parameters
    ----------
    prob : Tensor
        A 2-D tensor of shape (batch, vocab_size) representing probability distributions.

    sorted_prob : Tensor
        Probabilities sorted in descending order.

    top_p : Tensor
        The cumulative probability threshold with shape (batch, 1) for nucleus sampling.

    top_k : Tensor
        A tensor with shape (batch, 1), representing the number of top probabilities
        to consider for top-k sampling.

    Returns
    -------
    result : Tensor
        The filtered and normalized tensor with the same shape as input prob.
    """
    prob_dtype = prob.dtype
    top_k_dtype = top_k.dtype
    batch = sorted_prob.shape[0]

    def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
        # True while position j is still inside both the nucleus (top_p) and top_k budget.
        return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])

    @T.prim_func(private=True)
    def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
        batch, vocab_size = T.int64(), T.int64()
        sorted_prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
        cumsum_sorted = T.match_buffer(B, (batch, vocab_size), prob_dtype)
        top_p = T.match_buffer(C, (batch, 1), prob_dtype)
        top_k = T.match_buffer(D, (batch, 1), top_k_dtype)
        cutoff = T.match_buffer(E, (batch, 1), prob_dtype)
        for ax0, ax1 in T.grid(batch, vocab_size):
            with T.block("T_get_renorm_cutoff"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                # cutoff[i] = smallest probability still inside the top_p/top_k window;
                # anything below it is zeroed out by the filter below.
                if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
                    cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
                elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1) == 1:
                    if v_ax1 + 1 == vocab_size:
                        cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1]
                    elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1) == 0:
                        cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1 + 1]

    cumsum_sorted = cumsum(sorted_prob, axis=1)

    renorm_cutoff = tensor_ir_op(
        _get_renorm_cutoff,
        "get_renorm_cutoff",
        args=[sorted_prob, cumsum_sorted, top_p, top_k],
        out=Tensor.placeholder([batch, 1], prob_dtype),
    )

    # Zero out every probability strictly below the per-row cutoff.
    filtered_prob = tensor_expr_op(
        lambda prob, renorm_cutoff: te.compute(
            prob.shape,
            lambda i, j: _tir.Select(prob[i, j] >= renorm_cutoff[i, 0], prob[i, j], 0.0),
            name="filter_with_top_p_top_k",
        ),
        "filter_with_top_p_top_k",
        args=[prob, renorm_cutoff],
    )
    # Renormalize so each row sums to 1 again.
    renorm_prob = filtered_prob / sum(filtered_prob, axis=1, keepdims=True)
    return renorm_prob