# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-lines,invalid-name,protected-access,redefined-outer-name
# pylint: disable=redefined-builtin
"""nn.Tensor operators."""
import inspect
import math
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
import numpy as np
from tvm import te
from tvm import tir as _tir
from tvm.script import tir as T
from ... import expr as rx
from ... import op as _op
from ...block_builder import BlockBuilder
from .core import Tensor, get_default_dtype, wrap_nested
IntExpr = Union[int, _tir.PrimExpr]
def unsqueeze(x: Tensor, dim: int, name: str = "unsqueeze") -> Tensor:
"""Add a new axis to a tensor
Parameters
----------
x : Tensor
Input tensor to expand.
dim : int
Dimension to expand.
name : str
Name hint for this operator.
Returns
-------
result : Tensor
Expanded result.
"""
return wrap_nested(_op.expand_dims(x._expr, dim), name)
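# A minimal usage sketch (hypothetical shapes; assumes it is called inside a
# tracing context such as an nn.Module forward during export, so Tensor
# placeholders and the underlying BlockBuilder are available):
#
#     x = Tensor.placeholder([2, 3], "float32")
#     y = unsqueeze(x, dim=0)   # y.shape == [1, 2, 3]
#     z = unsqueeze(x, dim=1)   # z.shape == [2, 1, 3]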
def concat(x: List[Tensor], dim: int, name: str = "concat") -> Tensor:
"""Concatenate a list of tensors along an axis.
Parameters
----------
x : List[Tensor]
List of tensors to concatenate.
dim : int
Dimension to concatenate upon.
name : str
Name hint for this operator.
Returns
-------
result : Tensor
Concatenated result.
"""
# Convert tensors to expressions.
x = [t._expr for t in x]
return wrap_nested(_op.concat(x, dim), name)
def add(a: Tensor, b: Tensor, name: str = "add") -> Tensor:
"""Addition with numpy-style broadcasting.
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Examples
--------
.. code:: python
c = add(a, b)
"""
return wrap_nested(_op.add(a._expr, b._expr), name)
def subtract(a: Tensor, b: Tensor, name: str = "subtract") -> Tensor:
"""Subtraction with numpy-style broadcasting.
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Examples
--------
.. code:: python
c = subtract(a, b)
"""
return wrap_nested(_op.subtract(a._expr, b._expr), name)
def multiply(a: Tensor, b: Tensor, name: str = "mul") -> Tensor:
"""Multiplication with numpy-style broadcasting.
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Examples
--------
.. code:: python
c = multiply(a, b)
"""
return wrap_nested(_op.multiply(a._expr, b._expr), name)
def divide(a: Tensor, b: Tensor, name: str = "divide") -> Tensor:
"""Division with numpy-style broadcasting.
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Examples
--------
.. code:: python
c = divide(a, b)
"""
return wrap_nested(_op.divide(a._expr, b._expr), name)
def chunk(x: Tensor, chunks: int, dim: int = 0, name: str = "chunk") -> Tensor:
"""Split a tensor along dim into the specified number of chunks.
Parameters
----------
x : Tensor
Input tensor to be split.
chunks : int
Number of pieces to slice x into.
dim : int
Which dimension to split x.
name : str
Name hint for this operation.
Returns
-------
result : Tuple[Tensor]
A tuple of `chunks` tensors, each containing a slice of x.
"""
return wrap_nested(_op.split(x._expr, chunks, dim), name)
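# Usage sketch (hypothetical shapes): an evenly divisible dimension is split
# into equal pieces, e.g.
#
#     x = Tensor.placeholder([6, 4], "float32")
#     a, b, c = chunk(x, chunks=3, dim=0)   # each of a, b, c has shape [2, 4]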
def sum(
x: Tensor,
axis: Optional[Union[int, List[int]]] = None,
keepdims: bool = False,
name: str = "sum",
) -> Tensor:
"""Computes the sum of tensor elements over given axes.
Parameters
----------
x : Tensor
The input data tensor
axis : Optional[Union[int, List[int]]]
Axis or axes along which a sum is performed.
The default, axis=None, will sum all of the elements of the input tensor.
Negative indexing is supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one.
With this option, the result will broadcast correctly against the input tensor.
name : str
Name hint for this operation.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.sum(x._expr, axis, keepdims), name)
def max(
x: Tensor,
axis: Optional[Union[int, List[int]]] = None,
keepdims: bool = False,
name: str = "max",
) -> Tensor:
"""Computes the max of tensor elements over given axes.
Parameters
----------
x : Tensor
The input data tensor
axis : Optional[Union[int, List[int]]]
Axis or axes along which a max is performed.
The default, axis=None, will max all of the elements of the input tensor.
Negative indexing is supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one.
With this option, the result will broadcast correctly against the input tensor.
name : str
Name hint for this operation.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.max(x._expr, axis, keepdims), name)
def min(
x: Tensor,
axis: Optional[Union[int, List[int]]] = None,
keepdims: bool = False,
name: str = "min",
) -> Tensor:
"""Computes the min of tensor elements over given axes.
Parameters
----------
x : Tensor
The input data tensor
axis : Optional[Union[int, List[int]]]
Axis or axes along which a min is performed.
The default, axis=None, will min all of the elements of the input tensor.
Negative indexing is supported.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as
dimensions with size one.
With this option, the result will broadcast correctly against the input tensor.
name : str
Name hint for this operation.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.min(x._expr, axis, keepdims), name)
def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str] = None, name: str = "matmul") -> Tensor:
"""General matrix multiplication of two tensors, with broadcasting on batched dimensions.
The semantics and output shape deduction rule is specified as
https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html.
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
out_dtype: Optional[Union[str, DataType]]
The data type of the matmul result.
When it is not specified, the output dtype will be the same as input dtype.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Examples
--------
.. code:: python
c = matmul(a, b)
"""
return wrap_nested(_op.matmul(a._expr, b._expr, out_dtype=out_dtype), name)
def conv1d(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Optional[Union[int, Tuple]] = 1,
padding: Optional[Union[int, Tuple, str]] = 0,
dilation: Optional[Union[int, Tuple]] = 1,
groups: Optional[int] = 1,
name: str = "conv1d",
) -> Tensor:
r"""1D convolution.
This operator takes the weight as the 1D convolution kernel
and convolves it with data to produce an output.
In the default case, where the data_layout is `NCW`
and kernel_layout is `OIW`, conv1d takes in
a data Tensor with shape `(batch_size, in_channels, width)`,
and a weight Tensor with shape `(channels, in_channels, kernel_w)`,
where `kernel_w` is the length of the `W` kernel dimension,
to produce an output Tensor with the following rule:
.. math::
\mbox{out}[b, c, x] = \sum_{dx, k}
\mbox{data}[b, k, \mbox{strides} * x + dx] *
\mbox{weight}[c, k, dx]
Padding and dilation are applied to data and weight respectively before the computation.
This operator accepts data layout specification.
Semantically, the operator will convert the layout to the canonical layout
(`NCW` for data and `OIW` for weight), perform the computation,
then convert to the out_layout.
Parameters
----------
x : Tensor
The input data to the operator.
weight : Tensor
The weight expressions.
bias : Optional[Tensor]
Optional bias tensor of shape [O].
stride : Optional[Union[int, Tuple]]
The strides of convolution. It is required to have length 1.
padding : Optional[Union[int, Tuple, str]]
The padding of convolution on both sides of inputs before convolution.
It is required to have length either 1 or 2.
dilation : Optional[Union[int, Tuple]]
Specifies the dilation rate to be used for dilated convolution.
It is required to have length 1.
groups : Optional[int]
Number of groups to split the input into for grouped convolution.
The number of input and output channels should be divisible by the number of groups.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
conv_out = _op.nn.conv1d(
data=x._expr,
weight=weight._expr,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
if bias is not None:
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1]))
return wrap_nested(conv_out, name)
def conv2d(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Optional[Union[int, Tuple]] = 1,
padding: Optional[Union[int, Tuple, str]] = 0,
dilation: Optional[Union[int, Tuple]] = 1,
groups: Optional[int] = 1,
data_layout: Optional[str] = "NCHW",
name: str = "conv2d",
) -> Tensor:
"""Applies a 2D convolution over an input image composed of sevaral input planes
Parameters
----------
x : Tensor
Input tensor of shape [B, N, H, W]
weight : Tensor
Filters of shape [O, N/groups, kH, kW]
bias : Optional[Tensor]
Optional bias tensor of shape [O].
stride : Optional[Union[int, Tuple]]
The stride of the convolving kernel. Can be a single number
or tuple of (sH, sW).
padding : Optional[Union[int, Tuple]]
Implicit paddings on both sides of the input.
dilation : Optional[Union[int, Tuple]]
The spacing between kernel elements. Can be a single number or a tuple (dH, dW).
groups : Optional[int]
Split input into a number of groups.
data_layout : Optional[str]
Layout of input and output data.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result with shape [B, O, oH, oW].
"""
conv_out = _op.nn.conv2d(
data=x._expr,
weight=weight._expr,
strides=stride,
padding=padding,
dilation=dilation,
data_layout=data_layout,
groups=groups,
)
if bias is not None:
if data_layout == "NCHW":
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1]))
elif data_layout == "NHWC":
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, -1]))
else:
raise NotImplementedError(f"Dont know how to handle layout {data_layout}.")
return wrap_nested(conv_out, name)
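# Usage sketch (hypothetical shapes, default NCHW layout): a 3x3 kernel with
# padding=1 preserves the spatial size, e.g.
#
#     x = Tensor.placeholder([1, 16, 32, 32], "float32")  # [B, C, H, W]
#     w = Tensor.placeholder([32, 16, 3, 3], "float32")   # [O, C, kH, kW]
#     y = conv2d(x, w, stride=1, padding=1)               # y.shape == [1, 32, 32, 32]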
def conv3d(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Optional[Union[int, Tuple]] = 1,
padding: Optional[Union[int, Tuple, str]] = 0,
dilation: Optional[Union[int, Tuple]] = 1,
groups: Optional[int] = 1,
data_layout: Optional[str] = "NCDHW",
name: str = "conv3d",
) -> Tensor:
"""Applies a 3D convolution over an input image composed of sevaral input planes
Parameters
----------
x : Tensor
Input tensor of shape [B, N, D, H, W]
weight : Tensor
Filters of shape [O, N/groups, kD, kH, kW]
bias : Optional[Tensor]
Optional bias tensor of shape [O].
stride : Optional[Union[int, Tuple]]
The stride of the convolving kernel. Can be a single number
or tuple of (sD, sH, sW).
padding : Optional[Union[int, Tuple]]
Implicit paddings on both sides of the input.
dilation : Optional[Union[int, Tuple]]
The spacing between kernel elements. Can be a single number or a tuple (dD, dH, dW).
groups : Optional[int]
Split input into a number of groups.
data_layout : Optional[str]
Optional layout of the input and output data.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result with shape [B, O, oD, oH, oW].
"""
conv_out = _op.nn.conv3d(
data=x._expr,
weight=weight._expr,
strides=stride,
padding=padding,
dilation=dilation,
groups=groups,
data_layout=data_layout,
)
if bias is not None:
if data_layout == "NCDHW":
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1, 1]))
elif data_layout == "NDHWC":
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, 1, -1]))
else:
raise NotImplementedError(f"Dont know how to handle layout {data_layout}.")
return wrap_nested(conv_out, name)
def conv1d_transpose(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor] = None,
stride: Optional[Union[int, Tuple[int]]] = 1,
padding: Optional[Union[int, Tuple[int, ...]]] = 0,
output_padding: Optional[Union[int, Tuple[int]]] = 0,
dilation: Optional[Union[int, Tuple]] = 1,
groups: Optional[int] = 1,
name: str = "conv1d_transpose",
) -> Tensor:
"""1D transposed convolution operator.
This operator can be seen as the gradient operator of conv1d.
The output shape can be explained in the simple case when `data_layout == "NCW"` and
`kernel_layout == "IOW"`. Suppose `data` has shape `(N, in_channel, in_w)`, `weight` has
shape `(in_channel, out_channel, weight_w)`, we need to assure that `in_channel % groups == 0`.
The shape of the output will be `(N, out_channel * groups, out_w)`, where
- `out_w = ((in_w - 1) * strides[0] + weight_w - 2 * padding[0] + output_padding[0])`
Parameters
----------
x : Tensor
The input data to the operator.
weight : Tensor
The weight tensor.
bias : Optional[Tensor]
Optional bias tensor of shape [O].
stride : Optional[Union[int, Tuple[int]]]
The strides of convolution. It is required to have length 1.
padding : Optional[Union[int, Tuple[int, ...]]]
The padding of convolution on both sides of inputs before convolution.
It is required to have length either 1 or 2.
output_padding : Optional[Union[int, Tuple[int, ...]]]
Used to disambiguate the output shape.
dilation : Optional[Union[int, Tuple]]
Specifies the dilation rate to be used for dilated convolution.
It is required to have length 1.
groups : Optional[int]
Number of groups to split the input into for grouped convolution.
The number of input and output channels should be divisible by the number of groups.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
conv_out = _op.nn.conv1d_transpose(
data=x._expr,
weight=weight._expr,
strides=stride,
padding=padding,
output_padding=output_padding,
dilation=dilation,
groups=groups,
)
if bias is not None:
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1]))
return wrap_nested(conv_out, name)
def maximum(x1: Tensor, x2: Tensor, name: str = "maximum"):
"""Element-wise maximum
Parameters
----------
x1 : Tensor
The first input tensor.
x2 : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Examples
--------
.. code:: python
c = maximum(a, b)
"""
return wrap_nested(_op.maximum(x1._expr, x2._expr), name)
def minimum(x1: Tensor, x2: Tensor, name: str = "minimum"):
"""Element-wise minimum
Parameters
----------
x1 : Tensor
The first input tensor.
x2 : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Examples
--------
.. code:: python
c = minimum(a, b)
"""
return wrap_nested(_op.minimum(x1._expr, x2._expr), name)
def broadcast_to(x: Tensor, shape: Sequence[IntExpr], name: str = "broadcast_to") -> Tensor:
"""Broadcasts a tensor to a specified shape.
Parameters
----------
x : Tensor
The input data to the operator.
shape : Sequence[IntExpr]
The target shape.
name : str
Name hint.
Returns
-------
result : Tensor
The broadcasted tensor.
"""
return wrap_nested(_op.broadcast_to(x._expr, shape), name)
def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: Optional[str] = None) -> Tensor:
"""Permutes the dimensions of an array.
Parameters
----------
x : Tensor
The input data to the operator.
axes : Optional[List[int]]
The target axes order, reverse order if not specified.
name : str
Name hint.
Returns
-------
result : Tensor
The transposed result.
"""
if name is None:
x_name = getattr(getattr(x, "_expr", None), "name_hint", None)
if x_name is not None and "linear" in x_name:
name = x_name.replace("linear", "matmul")
else:
name = "permute_dims"
return wrap_nested(_op.permute_dims(x._expr, axes=axes), name)
def reshape(x: Tensor, shape: Sequence[IntExpr], name="reshape") -> Tensor:
"""Reshape the input array.
``-1`` infers the dimension of the output shape by using the remainder of
the input dimensions, keeping the size of the new array the same as that of the input array.
At most one dimension of shape can be -1.
.. code-block:: python
x.shape = (2, 3, 4), shape = (6, 1, -1), result.shape = (6, 1, 4)
x.shape = (2, 3, 4), shape = (3, -1, 8), result.shape = (3, 1, 8)
x.shape = (2, 3, 4), shape = (-1,), result.shape = (24,)
Parameters
----------
x : Tensor
The input data to the operator.
shape : Sequence[IntExpr]
The new shape. Should be compatible with the original shape.
name : str
Name hint.
Returns
-------
result : Tensor
The reshaped result.
Note
----
The ``-1`` inference is only performed at compile-time.
That is to say, if the length of the ``-1`` dimension cannot be inferred at
compile time, an error will be thrown.
"""
return wrap_nested(_op.reshape(x._expr, shape), name)
def repeat(x: Tensor, repeats: int, axis: Optional[int] = None, name="repeat") -> Tensor:
"""Repeats elements of an array.
Parameters
----------
x : Tensor
The input tensor.
repeats : int
The number of repetitions.
axis: Optional[int]
The axis along which to repeat values. Negative values count from the end.
By default, the flattened input array is used, and a flat output array is
returned.
name : str
Name hint.
Returns
-------
ret : Tensor
The computed result.
Examples
--------
.. code-block:: python
np_x = numpy.array([[1, 2], [3, 4]])
x = Tensor.from_const(np_x)
lv1 = repeat(x, repeats=2) # lv1 == [1, 1, 2, 2, 3, 3, 4, 4]
lv2 = repeat(x, repeats=2, axis=1) # lv2 == [[1., 1., 2., 2.],
# [3., 3., 4., 4.]]
"""
return wrap_nested(_op.repeat(x._expr, repeats, axis), name)
def squeeze(x: Tensor, axis: int = -1, name: str = "squeeze") -> Tensor:
"""Squeeze axes in the array.
Parameters
----------
x : Tensor
The input data to the operator.
axis : Optional[Union[int, List[int]]]
The set of axes to remove.
If axis = None, remove all axes of dimension 1.
If any specified axis has dimension that does not equal 1, it is an error.
name : str
Name hint.
Returns
-------
result : Tensor
The squeezed result.
"""
return wrap_nested(_op.squeeze(x._expr, axis), name)
def take(x: Tensor, indices: Tensor, axis: Optional[int] = None, name="take") -> Tensor:
"""Take elements from a tensor along an axis.
Its semantic is mostly similar to `numpy.take`
(https://numpy.org/doc/stable/reference/generated/numpy.take.html),
which can cover `torch.take` (https://pytorch.org/docs/stable/generated/torch.take.html) and
`onnx.gather` (https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13).
Parameters
----------
x : Tensor
The source tensor.
indices : Tensor
The indices of the values to extract.
axis : Optional[int]
The axis over which to select values.
If it is none, the input tensor is required to be one-dimensional.
name : str
Name hint.
Returns
-------
ret : Tensor
The taken result.
"""
return wrap_nested(_op.take(x._expr, indices._expr, axis), name)
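# Usage sketch (hypothetical shapes): gathering rows of an embedding table, e.g.
#
#     table = Tensor.placeholder([1000, 64], "float32")
#     ids = Tensor.placeholder([8], "int32")
#     rows = take(table, ids, axis=0)   # rows.shape == [8, 64]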
def astype(x: Tensor, dtype: str, name: str = "astype") -> Tensor:
"""Cast input tensor to the given data type.
Parameters
----------
x : Tensor
The input data to the operator.
dtype: str
The target data type
name : str
Name hint.
Returns
-------
result : Tensor
The casted result.
"""
# If trying to cast to same dtype as x, skip casting.
if x.dtype == dtype:
return x
return wrap_nested(_op.astype(x._expr, dtype), name)
def relu(x: Tensor, name: str = "relu") -> Tensor:
"""Rectified Linear Unit (ReLU) activation function.
.. math::
\text{ReLU}(x) = \text{max}(x, 0)
Parameters
----------
x : Tensor
The input data.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.nn.relu(x._expr), name)
def relu6(x: Tensor, name: str = "relu6") -> Tensor:
r"""ReLU6 activation function.
.. math::
\text{ReLU6}(x) = \min(\max(x, 0), 6)
Parameters
----------
x : Tensor
The input data.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.nn.relu6(x._expr), name)
def silu(x: Tensor, name: str = "silu") -> Tensor:
r"""Sigmoid Linear Unit function
.. math::
\text{SiLU}(x) = x * \text{sigmoid}(x)
Parameters
----------
x : Tensor
The input data
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.nn.silu(x._expr), name)
def gelu(x: Tensor, approximate: Optional[str] = None, name: str = "gelu") -> Tensor:
r"""Applies the Gaussian Error Linear Units function
.. math::
\text{GeLU}(x) = 0.5 * x * (1 + \text{erf}(x / \sqrt{2}))
where :math:`erf` is the Gauss Error function.
Parameters
----------
x : Tensor
The input data
approximate : Optional[str]
If set to tanh, use an approximation when calculating CDF.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Note
----
The input tensor is required to have float dtype
"""
if approximate == "tanh":
gelu_out = _op.nn.gelu_tanh(x._expr)
else:
gelu_out = _op.nn.gelu(x._expr)
return wrap_nested(gelu_out, name)
def sigmoid(x: Tensor, name: str = "sigmoid") -> Tensor:
r"""Computes sigmoid.
.. math:: \text{sigmoid}(x) = \frac{1}{1 + \exp(-x)}
Parameters
----------
x : Tensor
The input data to the operator.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.sigmoid(x._expr), name)
def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor:
r"""Computes softmax.
.. math:: \text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
Parameters
----------
x : Tensor
The input data to the operator.
axis: int
The axis to sum over when computing softmax.
If not specified, it is by default the last axis of the input tensor.
Supports negative indexing.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.nn.softmax(x._expr, axis), name)
def softplus(x: Tensor, beta: float = 1.0, threshold: float = 20.0, name: str = "softplus"):
r"""Softplus activation function.
.. math::
\text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x})
Parameters
----------
x : Tensor
The input data.
beta : float, optional
Controls the smoothness of the transition. Default is 1.0.
threshold : float, optional
The value beyond which the function is approximated as linear
to avoid numerical instability. Default is 20.0.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.nn.softplus(x._expr, beta=beta, threshold=threshold), name)
def prelu(x: Tensor, alpha: Tensor, name: str = "prelu"):
r"""Parametric ReLU activation function.
.. math::
\text{PReLU}(x) = \begin{cases}
x & \text{if } x \geq 0 \\
\alpha \cdot x & \text{if } x < 0
\end{cases}
Parameters
----------
x : Tensor
The input data.
alpha : Tensor
Slope coefficient for the negative part of the input.
name : str, optional
Optional name for the operation. Default is "prelu".
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.nn.prelu(x._expr, alpha._expr), name)
def tanh(x: Tensor, name: str = "tanh") -> Tensor:
r"""Applies the hyperbolic tangent function.
.. math::
\text{Tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}
Parameters
----------
x : Tensor
The input data to the operator.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.tanh(x._expr), name)
def exp(x: Tensor, name: str = "exp") -> Tensor:
r"""Applies the exponential function.
.. math::
\text{Exp}(x) = e^x
Parameters
----------
x : Tensor
The input data to the operator.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.exp(x._expr), name)
def log(x: Tensor, name: str = "log") -> Tensor:
r"""Applies the natural logarithm function.
.. math::
\text{Log}(x) = \log(x)
Parameters
----------
x : Tensor
The input data to the operator.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.log(x._expr), name)
def floor(x: Tensor, name: str = "floor") -> Tensor:
r"""Computes the floor of the input tensor.
.. math::
\text{Floor}(x) = \lfloor x \rfloor
Parameters
----------
x : Tensor
The input data to the operator.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.floor(x._expr), name)
def arange(
start: int,
end: Optional[int] = None,
step: int = 1,
dtype: Optional[str] = "float32",
name: str = "arange",
) -> Tensor:
r"""Construct a tensor with evenly spaced elements.
Parameters
----------
start : int
The start of the interval.
end : Optional[int]
The end of the interval. If not given, it will be set to start,
and start will be set to 0.
step : int
The step size.
dtype : Optional[str]
The data type of the created tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.arange(start, end, step, dtype), name)
def permute(x: Tensor, axes: Optional[List[int]], name: str = "permute") -> Tensor:
"""Permutes the dimensions of the input tensor.
Parameters
----------
x : Tensor
The input data to the operator.
axes : Optional[List[int]]
The target axes order.
name : str
Name hint.
Returns
-------
result : Tensor
The transposed result.
"""
return wrap_nested(_op.permute_dims(x._expr, axes=axes), name)
def negative(x: Tensor, name: str = "neg") -> Tensor:
"""Numerical negative of the input tensor.
Parameters
----------
x : Tensor
The input data to the operator.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.negative(x._expr), name)
def layer_norm(
x: Tensor,
normalized_shape: Union[int, List[int]],
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
name: str = "layer_norm",
) -> Tensor:
r"""
Layer normalization (Lei Ba and et al., 2016).
Applies layer normalization to the n-dimensional input array.
This operator takes an n-dimensional input array and normalizes
the input using the given axis:
.. math::
out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
* gamma + beta
Unlike batch normalization, the mean and var are computed along the channel dimension.
Assume the input has size k on axis 1, then both gamma and beta have shape (k,).
.. note::
This operator can be optimized away for inference.
Parameters
----------
x : Tensor
Input to which layer_norm will be applied.
normalized_shape: Union[int, List[int]]
The shape of axes to normalize. If a single integer
is used, it is treated as a singleton list and this
module will normalize over the last dimension.
weight: Tensor
The gamma scale factor.
bias: Tensor
The beta offset factor.
eps: float
Small float added to variance to avoid dividing by zero.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
if isinstance(normalized_shape, int):
normalized_shape = [normalized_shape]
dim_num = len(normalized_shape)
axes = list(range(-dim_num, 0))
dtype = x._expr.struct_info.dtype
if weight is not None:
weight = weight._expr
else:
weight = rx.const(np.ones(normalized_shape), dtype=dtype)
if bias is not None:
bias = bias._expr
else:
bias = rx.const(np.zeros(normalized_shape), dtype=dtype)
return wrap_nested(
_op.nn.layer_norm(
x._expr,
gamma=weight,
beta=bias,
axes=axes,
epsilon=eps,
),
name=name,
)
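# Usage sketch (hypothetical shapes): normalizing over the last (hidden) axis, e.g.
#
#     x = Tensor.placeholder([2, 8, 64], "float32")
#     gamma = Tensor.placeholder([64], "float32")
#     beta = Tensor.placeholder([64], "float32")
#     y = layer_norm(x, normalized_shape=64, weight=gamma, bias=beta)  # y.shape == [2, 8, 64]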
def rms_norm(
x: Tensor,
weight: Tensor,
axes: Union[int, List[int]],
epsilon: float = 1e-5,
name: str = "rms_norm",
) -> Tensor:
r"""
Root mean square normalization (Biao Zhang and et al., 2019).
Applies root mean square normalization to the n-dimensional input array.
This operator takes an n-dimensional input array and normalizes
the input using the given axis:
.. math::
out = \frac{data}{\sqrt{mean(data^2, axis)+\epsilon}} * weight
Parameters
----------
x : Tensor
Input to which rms_norm will be applied.
weight : Tensor
The scale factor.
axes : Union[int, List[int]]
The axes along which the normalization is applied.
epsilon : float
Small float added to square mean to avoid dividing by zero.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.nn.rms_norm(x._expr, weight._expr, axes, epsilon), name)
def group_norm(
x: Tensor,
num_groups: int,
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float = 1e-5,
channel_axis: int = 1,
axes: Optional[List[int]] = None,
name: str = "group_norm",
) -> Tensor:
r"""
Applies Group Normalization over a mini-batch of inputs as described in
the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
Parameters
----------
x : Tensor
Input to which group_norm will be applied.
num_groups : int
Number of groups to separate the channels into.
weight : Tensor
The gamma scale factor.
bias : Tensor
The beta offset factor.
eps : float
Small float added to variance to avoid dividing by zero.
channel_axis: int
The channel axis of the data.
axes : Optional[List[int]]
The axes over which to compute the group norm. If None, all dimensions
except the first two (batch and channel) are used.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
if weight is not None:
weight = weight._expr
if bias is not None:
bias = bias._expr
dim = len(x._expr.struct_info.shape)
if axes is None:
axes = list(range(2, dim))
return wrap_nested(
_op.nn.group_norm(
x._expr, weight, bias, num_groups, channel_axis=channel_axis, axes=axes, epsilon=eps
),
name,
)
def triu(x: Tensor, diagonal: int = 0, name: str = "triu") -> Tensor:
"""Return the upper triangular part of a matrix or a batch of matrices.
Parameters
----------
x : Tensor
The tensor that triu will be applied to.
It is required to have at least two dimensions.
diagonal : int
The index indicating the diagonal below which to zero elements.
If diagonal = 0, the diagonal is the main diagonal.
If diagonal < 0, the diagonal is below the main diagonal.
If diagonal > 0, the diagonal is above the main diagonal.
name : str
Name hint.
Returns
-------
ret : Tensor
The result tensor.
"""
return wrap_nested(_op.triu(x._expr, diagonal), name)
def full(
shape: Sequence[IntExpr],
fill_value: Tensor,
dtype: str = "float32",
name: str = "full",
) -> Tensor:
"""Fill array with scalar value.
Parameters
----------
shape : Sequence[IntExpr]
The shape of the created tensor.
fill_value : Tensor
The value to fill. Must be a scalar tensor.
dtype : str
The data type of the created tensor.
If dtype is not given, it will by default use the dtype of fill_value.
name : str
Name hint.
Returns
-------
result : Tensor
The result tensor.
"""
if isinstance(fill_value, (_tir.FloatImm, _tir.IntImm)):
fill_value = rx.const(fill_value.value, dtype=dtype)
elif isinstance(fill_value, (int, float)):
fill_value = rx.const(fill_value, dtype=dtype)
else:
fill_value = fill_value._expr
return wrap_nested(_op.full(shape, fill_value, dtype), name)
def zeros(
shape: Sequence[IntExpr],
dtype: str = "float32",
name: str = "zeros",
) -> Tensor:
"""Construct a tensor of all zeros, with the input shape and dtype.
Parameters
----------
shape : Sequence[IntExpr]
The shape of the created tensor.
dtype : str
The data type of the created tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The result tensor.
"""
return wrap_nested(_op.zeros(shape, dtype), name)
def ones(
shape: Sequence[IntExpr],
dtype: str = "float32",
name: str = "ones",
) -> Tensor:
"""Construct a tensor of all zeros, with the input shape and dtype.
Parameters
----------
shape : Sequence[IntExpr]
The shape of the created tensor.
dtype : str
The data type of the created tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The result tensor.
"""
return wrap_nested(_op.ones(shape, dtype), name)
def empty(
shape: Sequence[IntExpr],
dtype: str = "float32",
name: str = "empty",
) -> Tensor:
"""Construct an uninitialized tensor, with the input shape and dtype.
Parameters
----------
shape : Sequence[IntExpr]
The shape of the created tensor.
dtype : str
The data type of the created tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The result tensor.
"""
return wrap_nested( # type: ignore
_op.builtin.alloc_tensor(
rx.ShapeExpr(shape), # type: ignore
dtype,
runtime_device_index=0,
),
name,
)
def split(
ary: Tensor,
indices_or_sections: Union[int, Sequence[int]],
axis: int = 0,
name: str = "split",
) -> Tuple[Tensor, ...]:
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : Tensor
Input tensor to be split.
indices_or_sections : Union[int, Sequence[int]]
Indices or sections to split into.
axis : int = 0
The axis along which to split, default is 0.
name : str
Name hint.
Returns
-------
result : Tuple[Tensor, ...]
A list of sub-arrays as the outcome of splitting.
"""
return wrap_nested(_op.split(ary._expr, indices_or_sections, axis), name)
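# Usage sketch (hypothetical shapes): an integer gives equal sections, while a
# sequence of indices gives the cut points, e.g.
#
#     x = Tensor.placeholder([10, 4], "float32")
#     a, b = split(x, 2, axis=0)          # shapes [5, 4] and [5, 4]
#     p, q, r = split(x, [2, 6], axis=0)  # shapes [2, 4], [4, 4], [4, 4]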
def pad(
x: Tensor,
pad: List[int],
mode: str = "constant",
value: float = 0.0,
name: str = "pad",
) -> Tensor:
"""
Apply spatial padding to the input tensor.
Parameters
----------
x : Tensor
Input tensor to be padded.
pad : List[int]
List in the format of [before_0, after_0, before_1, after_1, ...]
indicating how much to pad each axis of x.
mode : str
Padding mode to use. "constant" fills the padded elements with the
value argument.
value : float
What to pad with in constant mode.
name : str
Name hint for this operator.
Returns
-------
result : Tensor
Padded output tensor.
"""
return wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_mode=mode, pad_value=value), name)
def square(x: Tensor, name: str = "square") -> Tensor:
"""Computes the element-wise square of the input tensor.
Parameters
----------
x : Tensor
The input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.square(x._expr), name)
def sqrt(x: Tensor, name: str = "sqrt") -> Tensor:
"""Computes the element-wise sqrt of the input tensor.
Parameters
----------
x : Tensor
The input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
Note
----
The input tensor is required to have float dtype
"""
return wrap_nested(_op.sqrt(x._expr), name)
def get_timestep_embedding(
x: Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
name: str = "get_timestep_embedding",
) -> Tensor:
"""
Timestep calculation as described in Denoising Diffusion Probabilistic Models.
Parameters
----------
x : Tensor
A 1-D Tensor of N indices.
embedding_dim : int
The dimension of the output.
flip_sin_to_cos : bool
If True, change the order of sine and cosine embeddings.
downscale_freq_shift : float
Adjusts the frequency of the sinusoidal sampling.
scale : float
Weight adjustment for embedding magnitude.
max_period : int
Controls the minimum frequency of the embeddings.
name : str
The name to label this operator with.
Returns
-------
result : Tensor
[N x dim] Tensor of positional embeddings.
"""
dtype = get_default_dtype()
# Arithmetic should be done in float for precision.
timesteps = _op.astype(x._expr, "float32")
half_dim = embedding_dim // 2
exponent = rx.const(-math.log(max_period), "float32") * _op.arange(
start=0, end=half_dim, dtype="float32"
)
exponent = exponent / (rx.const(half_dim - downscale_freq_shift, "float32"))
emb = _op.exp(exponent)
emb = _op.expand_dims(timesteps, 1) * _op.expand_dims(emb, 0)
# Scale embeddings
if scale != 1:
emb = rx.const(scale, "float32") * emb
# Concat sine and cosine embeddings.
if flip_sin_to_cos:
emb = _op.concat([_op.cos(emb), _op.sin(emb)], axis=-1)
else:
emb = _op.concat([_op.sin(emb), _op.cos(emb)], axis=-1)
# Zero pad
if embedding_dim % 2 == 1:
emb = _op.nn.pad(emb, (0, 1, 0, 0))
# Cast to proper output type
emb = _op.astype(emb, dtype)
return wrap_nested(emb, name)
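# Usage sketch (hypothetical shapes): for N timesteps and an even embedding_dim,
# the sin and cos halves are concatenated along the last axis, e.g.
#
#     t = Tensor.placeholder([16], "int32")
#     emb = get_timestep_embedding(t, embedding_dim=128)  # emb.shape == [16, 128]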
def scaled_dot_product_attention(
query: Tensor,
key: Tensor,
value: Tensor,
attn_mask: Optional[Tensor] = None,
is_causal: Optional[bool] = False,
scale: Optional[float] = None,
name: str = "scaled_dot_product_attention",
):
"""
Computes a scaled dot product attention on provided attention
query, key, and values. Compliant with the functional torch implementation.
Parameters
----------
query : Tensor
Tensor representing current attention lookup of shape
[batch, seq_len, num_heads, head_size].
key : Tensor
Tensor representing cross attention mapping of shape
[batch, seq_len_kv, num_heads_kv, head_size].
value : Tensor
Tensor representing embedded attention values of shape
[batch, seq_len_kv, num_heads_kv, head_size_value].
attn_mask : Optional[Tensor]
Optional mask for attention, not yet supported.
is_causal : Optional[bool]
If set, uses a causal attention mask.
scale : Optional[float]
Optional extra scaling argument applied to attention.
name : str
Name hint for this function.
"""
assert attn_mask is None, "attn_mask not yet supported."
causal_mask = "TopLeft" if is_causal else None
attn = _op.nn.attention(
query._expr, key._expr, value._expr, causal_mask=causal_mask, scale=scale
)
return wrap_nested(attn, name)
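# Usage sketch (hypothetical shapes, [batch, seq_len, num_heads, head_size] as
# documented above):
#
#     q = Tensor.placeholder([1, 32, 8, 64], "float16")
#     k = Tensor.placeholder([1, 32, 8, 64], "float16")
#     v = Tensor.placeholder([1, 32, 8, 64], "float16")
#     o = scaled_dot_product_attention(q, k, v, is_causal=True)  # o.shape == [1, 32, 8, 64]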
def interpolate(
x: Tensor,
size: Optional[Union[int, Tuple[int]]] = None,
scale_factor: Optional[Union[float, Tuple[float]]] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: Optional[bool] = None,
data_layout: Optional[str] = "NCHW",
name: str = "interpolate",
):
"""Resize a tensor using the specified mode.
Parameters
----------
x : Tensor
Input tensor to be resized.
size : Optional[Union[int, Tuple[int]]]
Requested output size; only one of size and scale_factor may
be specified.
scale_factor : Optional[Union[float, Tuple[float]]]
Multiplier for spatial size.
mode : str
Algorithm used for sampling.
align_corners : Optional[bool]
How to map pixels before and after sampling.
recompute_scale_factor : Optional[bool]
Recompute the scale_factor for use in interpolation.
antialias : Optional[bool]
Apply antialiasing to output.
data_layout : Optional[str]
Layout of the input and output data.
name : str
Name hint for this operation.
Returns
-------
result : Tensor
Output tensor with requested shape.
"""
assert recompute_scale_factor is None, "recompute_scale_factor is not supported."
assert antialias is None, "antialias is not supported."
if size is None:
size = []
for i, dim in enumerate(data_layout):
# Only upscale spatial dimensions.
if dim not in ["N", "C"]:
if isinstance(scale_factor, (list, tuple)):
size.append(int(x.shape[i] * scale_factor[len(size)]))
else:
size.append(int(x.shape[i] * scale_factor))
if mode.startswith("nearest"):
mode = "nearest_neighbor"
elif mode[0:2] == "bi":
mode = mode[2:]
if mode == "nearest_neighbor":
coord_trans = "asymmetric"
elif align_corners:
coord_trans = "align_corners"
else:
coord_trans = "half_pixel"
return wrap_nested(
_op.image.resize2d(
x._expr,
size,
layout=data_layout,
method=mode,
coordinate_transformation_mode=coord_trans,
),
name,
)
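# Usage sketch (hypothetical shapes, NCHW layout): doubling the spatial size
# with bilinear sampling, e.g.
#
#     x = Tensor.placeholder([1, 3, 16, 16], "float32")
#     y = interpolate(x, scale_factor=2.0, mode="bilinear")  # y.shape == [1, 3, 32, 32]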
def ccl_allreduce(x: Tensor, op_type: str = "sum", in_group: bool = True, name="ccl_allreduce"):
"""CCL Allreduce operator
Parameters
----------
x : Tensor
The input tensor.
op_type : str
The type of reduction operation to be applied to the input data.
Now "sum", "prod", "min", "max" and "avg" are supported.
in_group : bool
Whether the reduction operation is performed within the worker group (the default)
or globally across all workers.
name : str
Name hint for this operation.
Returns
-------
result : Tensor
The result tensor of allreduce.
"""
return wrap_nested(_op.ccl.allreduce(x._expr, op_type, in_group), name)
def ccl_allgather(x: Tensor, num_workers: int, name="ccl_allgather"):
"""CCL Allgather operator
Parameters
----------
x : Tensor
The input tensor.
num_workers : int
Number of workers.
name : str
Name hint for this operation.
Returns
-------
result : Tensor
The result tensor of allgather.
"""
return wrap_nested(_op.ccl.allgather(x._expr, num_workers), name)
def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"):
"""Broadcast data from worker-0 to all other workers.
Parameters
----------
x : Tensor
The tensor to be broadcast.
name : str
Name hint for this operation.
Returns
-------
result : Tensor
The same tensor, which has been broadcast to all other workers.
"""
return wrap_nested(_op.ccl.broadcast_from_worker0(x._expr), name)
def tensor_expr_op(
tensor_expr_func: Callable,
name_hint: str,
args: List[Union[Tensor, _tir.Var, int]],
*,
attrs: Optional[Dict[str, Any]] = None,
):
"""Build the given tensor_expr_func with te.
Parameters
----------
tensor_expr_func : Callable
A function that returns a te tensor or a list of tensors.
name_hint : str
Name hint.
args: List[Union[Tensor, _tir.Var, int]]
Arguments passed to the function.
attrs: Optional[Dict[str, Any]]
A dict of attributes to apply to the function.
Returns
-------
result : Tensor
The result tensor.
"""
def _convert(arg):
if isinstance(arg, Tensor):
return arg._expr # pylint: disable=protected-access
return arg
return wrap_nested(
BlockBuilder.current().emit_te(
tensor_expr_func,
*[_convert(arg) for arg in args],
primfunc_name_hint=name_hint,
primfunc_attrs=attrs,
),
name=name_hint,
)
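# A minimal sketch (hypothetical computation; assumes an active BlockBuilder,
# e.g. inside module export): emitting a custom elementwise te computation as a
# PrimFunc and binding its output to a Tensor, e.g.
#
#     def _add_one(a):
#         return te.compute(a.shape, lambda *i: a(*i) + 1, name="add_one")
#
#     y = tensor_expr_op(_add_one, "add_one", args=[x])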
OutType = TypeVar("OutType", bound=Union[Tensor, Sequence[Tensor]])
def tensor_ir_op(
func: _tir.PrimFunc,
name_hint: str,
args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
out: OutType,
) -> OutType:
"""Create a `call_tir` binding with given PrimFunc
Parameters
----------
func : _tir.PrimFunc
The PrimFunc to call.
name_hint : str
Name hint.
args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
The arguments to pass to the PrimFunc.
out : Union[Tensor, List[Tensor]]
The output tensors.
Returns
-------
result : Tensor
The result tensor
"""
from tvm import relax as rx # pylint: disable=import-outside-toplevel
call_tir_args, tir_vars = [], []
if not isinstance(args, (tuple, list)):
args = [args]
for arg in args:
if isinstance(arg, Tensor):
call_tir_args.append(arg._expr)
elif isinstance(arg, (rx.ShapeExpr, _tir.PrimExpr)):
tir_vars.append(arg)
else:
raise TypeError(
"Unsupported type: tensor_ir_op args expect Tensor or ShapeExpr or PrimExpr,"
f"but got {type(arg)}"
)
if isinstance(out, Tensor):
out_sinfo = [out._expr.struct_info]
else:
out_sinfo = [x._expr.struct_info for x in out]
bb = BlockBuilder.current()
global_var = bb.add_func(func, name_hint)
if len(tir_vars) == 0:
tir_vars = None
return wrap_nested(
bb.emit(rx.call_tir(global_var, call_tir_args, out_sinfo, tir_vars=tir_vars)),
name=name_hint,
)
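# A minimal sketch (hypothetical PrimFunc that doubles its input; assumes an
# active BlockBuilder, e.g. inside module export):
#
#     @T.prim_func(private=True)
#     def _double(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
#         for i in T.serial(8):
#             with T.block("double"):
#                 vi = T.axis.spatial(8, i)
#                 B[vi] = A[vi] * T.float32(2)
#
#     y = tensor_ir_op(_double, "double", args=[x],
#                      out=Tensor.placeholder([8], "float32"))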
def tensor_ir_inplace_op(
func: _tir.PrimFunc,
name_hint: str,
args: Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]],
inplace_indices: Union[int, List[int]],
out: OutType,
) -> OutType:
"""Create a `call_tir_inplace` binding with given PrimFunc
Parameters
----------
func : _tir.PrimFunc
The PrimFunc to call.
name_hint : str
Name hint.
args : Union[Tensor, Sequence[Union[Tensor, rx.ShapeExpr, _tir.PrimExpr]]]
The arguments to pass to the PrimFunc.
inplace_indices : Union[int, List[int]]
Specify which arguments should be used for in-place computations.
If `inplace_indices` is a single integer, it will be made into a singleton list.
Suppose `inplace_indices[i] = j`, where `j >= 0`. Then the `i`th output
will be an alias of `args[j]`.
If `inplace_indices[i] = -1`, then the `i`th output will be a freshly allocated tensor.
At least one member of `inplace_indices` must not be -1.
out : Union[Tensor, List[Tensor]]
The output tensors.
Returns
-------
result : Tensor
The result tensor
"""
from tvm import relax as rx # pylint: disable=import-outside-toplevel
call_tir_args, tir_vars = [], []
if not isinstance(args, (tuple, list)):
args = [args]
for arg in args:
if isinstance(arg, Tensor):
call_tir_args.append(arg._expr)
elif isinstance(arg, (rx.ShapeExpr, _tir.PrimExpr)):
tir_vars.append(arg)
else:
raise TypeError(
"Unsupported type: tensor_ir_inplace_op args expect Tensor or ShapeExpr or"
f" PrimExpr, but got {type(arg)}"
)
if isinstance(out, Tensor):
out_sinfo = [out._expr.struct_info]
else:
out_sinfo = [x._expr.struct_info for x in out]
bb = BlockBuilder.current()
global_var = bb.add_func(func, name_hint)
return wrap_nested(
bb.emit(
rx.call_tir_inplace(global_var, call_tir_args, inplace_indices, out_sinfo, tir_vars)
),
name=name_hint,
)
def extern(
name: str,
args: Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]],
out: OutType,
) -> OutType:
"""Invoke an extern function during runtime. The extern function must be registered with the "
TVM runtime using `reflection::GlobalDef().def` (C++), or `tvm.register_global_func` (Python).
Parameters
----------
name : str
The name of the extern function to call.
args : Sequence[Union[Tensor, _tir.PrimExpr, int, float, str]]
The arguments to pass to the extern function.
out : Union[Tensor, List[Tensor]]
The output tensors; only their structural info (shape and dtype) is used.
Returns
-------
result : Tensor
The result
"""
from tvm import relax as rx # pylint: disable=import-outside-toplevel
def _convert(arg, name: str):
if isinstance(arg, Tensor):
return arg._expr # pylint: disable=protected-access
if isinstance(arg, int):
return rx.PrimValue(_tir.IntImm("int64", arg))
if isinstance(arg, float):
return rx.PrimValue(_tir.FloatImm("float64", arg))
if isinstance(arg, str):
return rx.StringImm(arg)
if isinstance(arg, _tir.PrimExpr):
return rx.PrimValue(arg)
if isinstance(arg, (tuple, list)):
return rx.Tuple([_convert(e, f"{name}_{i}") for i, e in enumerate(arg)])
raise TypeError(f"Unsupported input type: {type(arg)}")
rx_inputs = _convert(args, "input")
rx_outputs_sinfo = _convert(out, "dummy").struct_info
return wrap_nested(
_op.call_dps_packed(
name,
args=rx_inputs,
out_sinfo=rx_outputs_sinfo,
),
name,
) # type: ignore
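# A minimal sketch (hypothetical packed function name "my_ext.scale"; assumes it
# has been registered with the TVM runtime and that a BlockBuilder is active):
#
#     y = extern(
#         "my_ext.scale",
#         args=[x, 2.0],
#         out=Tensor.placeholder(x.shape, x.dtype),
#     )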
def debug_func(
name: str,
*args: Union[Tensor, _tir.PrimExpr, int, float, str],
_line_info: Optional[str] = None,
):
"""Call a debug function during runtime. The debug function must be registered with the
following type signature:
.. code-block:: python
@tvm.register_global_func(name_of_debug_func)
def debug_func(lineno: str, arg_0, arg_1, ...) -> None:
...
Parameters
----------
name : str
The name of the debug function to call.
*args : Union[Tensor, _tir.PrimExpr, int, float, str]
The arguments to pass to the debug function.
"""
# pylint: disable=import-outside-toplevel
from tvm import relax as rx
from .exporter import Exporter
from .modules import IOEffect
# pylint: enable=import-outside-toplevel
if Exporter.current().io_effect is None:
raise RuntimeError("Debugging is only supported when debug mode is on.")
io: IOEffect = Exporter.current().io_effect # type: ignore
if _line_info is None:
filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
_line_info = f"{filename}:{line_number}"
converted_args = []
for arg in args:
if isinstance(arg, Tensor):
converted_args.append(arg._expr) # pylint: disable=protected-access
elif isinstance(arg, int):
converted_args.append(rx.PrimValue(_tir.IntImm("int64", arg)))
elif isinstance(arg, float):
converted_args.append(rx.PrimValue(_tir.FloatImm("float32", arg)))
elif isinstance(arg, _tir.PrimExpr):
converted_args.append(rx.PrimValue(arg))
elif isinstance(arg, str):
converted_args.append(rx.StringImm(arg))
else:
raise TypeError(f"Unsupported type {type(arg)}")
io.effect = BlockBuilder.current().emit(
rx.call_pure_packed(
"vm.builtin.invoke_debug_func",
io.effect,
rx.StringImm(name),
rx.StringImm(_line_info),
*converted_args,
sinfo_args=[rx.ObjectStructInfo()],
),
name_hint=io.effect.name_hint,
)
def print_(tensor: Tensor):
"""Debug printing a Tensor during runtime."""
filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2]
line_info = f"{filename}:{line_number}"
debug_func("vm.builtin.debug_print", tensor, _line_info=line_info)
def less(a: Tensor, b: Tensor, name: str = "less") -> Tensor:
"""Broadcasted element-wise comparison for (lhs < rhs).
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.less(a._expr, b._expr), name)
def less_equal(a: Tensor, b: Tensor, name: str = "less_equal") -> Tensor:
"""Broadcasted element-wise comparison for (lhs <= rhs).
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.less_equal(a._expr, b._expr), name)
def greater(a: Tensor, b: Tensor, name: str = "greater") -> Tensor:
"""Broadcasted element-wise comparison for (lhs > rhs).
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.greater(a._expr, b._expr), name)
def greater_equal(a: Tensor, b: Tensor, name: str = "greater_equal") -> Tensor:
"""Broadcasted element-wise comparison for (lhs >= rhs).
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.greater_equal(a._expr, b._expr), name)
def equal(a: Tensor, b: Tensor, name: str = "equal") -> Tensor:
"""Broadcasted element-wise comparison for (lhs == rhs).
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.equal(a._expr, b._expr), name)
def not_equal(a: Tensor, b: Tensor, name: str = "not_equal") -> Tensor:
"""Broadcasted element-wise comparison for (lhs != rhs).
Parameters
----------
a : Tensor
The first input tensor.
b : Tensor
The second input tensor.
name : str
Name hint.
Returns
-------
result : Tensor
The computed result.
"""
return wrap_nested(_op.not_equal(a._expr, b._expr), name)
def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Tensor:
"""Selecting elements from either the input tensors depending on the value of the
condition.
For a given position, return the corresponding value in `x1` if `condition` is True,
and return the corresponding value in `x2` otherwise.
Parameters
----------
condition : Tensor
When True, yield `x1`; otherwise, yield `x2`.
Must be broadcasting compatible with `x1` and `x2`.
Must have boolean dtype.
x1 : Tensor
The first input tensor.
Must be broadcasting compatible with `condition` and `x2`.
x2 : Tensor
The second input tensor.
Must be broadcasting compatible with `condition` and `x1`.
name : str
Name hint.
Returns
-------
result : Tensor
The result tensor.
"""
# Cast condition to boolean.
condition = astype(condition, "bool")
return wrap_nested(_op.where(condition._expr, x1._expr, x2._expr), name)
def cumsum(
data: Tensor,
axis: Optional[int] = None,
dtype: Optional[str] = None,
exclusive: Optional[bool] = None,
name: str = "cumsum",
) -> Tensor:
"""Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
a given axis.
Parameters
----------
data : Tensor
The input data to the operator.
axis : Optional[int]
Axis along which the cumulative sum is computed. The default (None) is to compute
the cumsum over the flattened array.
dtype : Optional[str]
Type of the returned array and of the accumulator in which the elements are summed.
If dtype is not specified, it defaults to the dtype of data.
exclusive : Optional[bool]
If true will return exclusive sum in which the first element is not
included.
name : str
Name hint.
Returns
-------
result : Tensor
The result has the same size as data, and the same shape as data if axis is not None.
If axis is None, the result is a 1-d array.
Examples
--------
.. code-block:: python
a = [[1, 2, 3], [4, 5, 6]]
cumsum(a) # if axis is not provided, cumsum is done over the flattened input.
-> [ 1, 3, 6, 10, 15, 21]
cumsum(a, dtype="float32")
-> [ 1., 3., 6., 10., 15., 21.]
cumsum(a, axis=0) # sum over rows for each of the 3 columns
-> [[1, 2, 3],
[5, 7, 9]]
cumsum(a, axis=1)
-> [[ 1, 3, 6],
[ 4, 9, 15]]
a = [1, 0, 1, 0, 1, 1, 0] # a is a boolean array
cumsum(a, dtype=int32) # dtype should be provided to get the expected results
-> [1, 1, 2, 2, 3, 4, 4]
"""
return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
def sort(x: Tensor, axis: int = -1, descending: bool = False, name="sort"):
"""Performs sorting along the given axis and returns an array
in sorted order.
Parameters
----------
x : Tensor
The input tensor.
axis : int
Axis along which to sort the input tensor.
By default the last axis of the input is used.
descending : bool
Whether to sort in descending order, the default is False
name : str
Name hint.
Returns
-------
out : Tensor
The sorted tensor.
"""
return wrap_nested(_op.sort(x._expr, axis, descending), name=name)
def argsort(
data: Tensor, axis: int = -1, descending: bool = False, dtype: str = "int32", name="argsort"
):
"""Performs sorting along the given axis and returns an array of indices
having same shape as an input array that index data in sorted order.
Parameters
----------
data : Tensor
The input data tensor.
axis : int
Axis along which to sort the input tensor.
descending : bool
Whether to sort in descending order, the default is False
dtype : str
The data type of the output indices.
name : str
Name hint.
Returns
-------
out : Tensor
The indices of the sorted tensor.
"""
return wrap_nested(_op.argsort(data._expr, axis, descending, dtype), name=name)
def topk(
data: Tensor,
k: int = 1,
axis: int = -1,
ret_type: str = "both",
largest: bool = True,
dtype: str = "int32",
name: str = "topk",
):
"""Get the top k elements in an input tensor along the given axis.
ret_type specifies the return type, can be one of ("both", "values", "indices").
Parameters
----------
data : Tensor
The input data tensor.
k : int
Number of top elements to select. Return all elements if k < 1.
axis : int
Axis along which to sort the input tensor.
ret_type: str
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.
largest : bool
Whether to return largest or smallest elements.
The k smallest elements are returned if largest is False.
dtype : str
The data type of the indices output.
name : str
Name hint.
Returns
-------
out : Tensor or Tuple[Tensor, Tensor]
The computed result.
"""
return wrap_nested(_op.topk(data._expr, k, axis, ret_type, largest, dtype), name=name)
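# Usage sketch (hypothetical shapes): top-5 values and indices along the last
# axis, e.g.
#
#     logits = Tensor.placeholder([2, 1000], "float32")
#     values, indices = topk(logits, k=5, axis=-1, ret_type="both")
#     # values.shape == indices.shape == [2, 5]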
def sample_top_p_top_k_from_sorted_prob(
sorted_prob: Tensor,
sorted_index: Tensor,
top_p: Tensor,
top_k: Tensor,
uniform_sample: Tensor,
sample_indices: Optional[Tensor] = None,
):
"""Samples indices from a sorted probability tensor based on top_p and top_k criteria.
Notes
-----
For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
Parameters
----------
sorted_prob : Tensor
A 2-D tensor, with shape (batch, vocab_size), contains probabilities
sorted in descending order.
sorted_index: Tensor
The indices tensor with shape (batch, vocab_size), corresponding to the
sorted_prob. Potentially from applying argsort on the original probability
tensor in descending order.
top_p : Tensor
The cumulative probability threshold with shape (batch, 1) for nucleus sampling.
top_k :Tensor
A tensor with shape (batch, 1), representing the number of top probabilities
to consider for top-k sampling.
uniform_sample : Tensor
Uniformly sampled values with shape (n, 1) are used to select the output indices.
sample_indices : Optional[Tensor]
The 2-D tensor with the shape [n, 1], which indicates the specific
probability distribution to sample from. The value of sample_indices[i]
determines that the ith token should be sampled from the sample_indices[i]th
probability distribution. For instance, if there are 3 distinct probability
distributions and the requirement is to sample 2, 3, and 4 tokens from each,
then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2].
Returns
-------
result : Tensor
The selected indices with shape (n, 1).
Examples
--------
.. code-block:: python
prob = [[0.1 , 0.4, 0.5],
[0.3, 0.3, 0.4]]
sorted_prob = [[0.5, 0.4, 0.1],
[0.4, 0.3, 0.3]]
sorted_index = [[2, 1, 0],
[2, 0, 1]]
top_p = [[0.6],[0.9]]
top_k = [[3],[2]]
uniform_sample = [[0.5], [0.6]]
sample_indices = [[0], [1]]
sample_top_p_top_k_from_sorted_prob(
sorted_prob, sorted_index,top_p, top_k, uniform_sample, sample_indices)
-> [2, 0]
"""
prob_dtype = sorted_prob.dtype
index_dtype = sorted_index.dtype
prob_batch = sorted_prob.shape[0]
out_batch = uniform_sample.shape[0]
if sample_indices is not None:
assert (
sample_indices.shape == uniform_sample.shape
), "The shape of sample_indices must match the shape of uniform_sample."
else:
assert (
sorted_prob.shape[0] == uniform_sample.shape[0]
), "Number of samples must match the number of probability distributions."
sample_indices = Tensor.from_const(
np.arange(out_batch).reshape(out_batch, 1).astype(np.int64)
)
print("sample_indices: ", sample_indices)
sample_indices_dtype = sample_indices.dtype
def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
@T.prim_func(private=True)
def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
top_p = T.match_buffer(B, (batch, 1), prob_dtype)
top_k = T.match_buffer(C, (batch, 1), index_dtype)
renorm_prob = T.match_buffer(D, (batch, 1), prob_dtype)
for ax0, ax1 in T.grid(batch, vocab_size):
with T.block("T_get_renorm_prob"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
if not _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0):
renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1):
if v_ax1 + 1 == vocab_size:
renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1]
elif not _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1):
renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1]
@T.prim_func(private=True)
def _get_index_from_sorted(
A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle
):
batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
out_batch = T.int64(is_size_var=True)
cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
indices = T.match_buffer(B, (batch, vocab_size), index_dtype)
renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
usample = T.match_buffer(D, (out_batch, 1), prob_dtype)
sample_indices = T.match_buffer(E, (out_batch, 1), sample_indices_dtype)
output_index = T.match_buffer(F, (out_batch, 1), index_dtype)
for ax0, ax1 in T.grid(out_batch, vocab_size):
with T.block("T_get_index_from_sorted"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.writes(output_index[v_ax0, 0])
if (
usample[v_ax0, T.int64(0)]
< cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1]
/ renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
or v_ax1 + 1 == vocab_size
):
if v_ax1 == 0:
output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
elif (
usample[v_ax0, T.int64(0)]
>= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
/ renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
):
output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]
cumsum_sorted = cumsum(sorted_prob, axis=1)
renorm_prob = tensor_ir_op(
_get_renorm_prob,
"get_renorm_prob",
args=[cumsum_sorted, top_p, top_k],
out=Tensor.placeholder(
[prob_batch, 1],
prob_dtype,
),
)
out_index_in_sorted = tensor_ir_op(
_get_index_from_sorted,
"get_index_from_sorted",
args=[cumsum_sorted, sorted_index, renorm_prob, uniform_sample, sample_indices],
out=Tensor.placeholder([out_batch, 1], index_dtype),
)
return out_index_in_sorted
def renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, top_k):
"""Renormalizes probabilities after filtering with top_p and top_k, ensuring
they sum up to 1.
Notes
-----
For accurate results, ensure probabilities are between 0 and 1 and sum to 1.
Parameters
----------
prob : Tensor
A 2-D tensor of shape (batch, vocab_size) representing probability distributions.
sorted_prob : Tensor
Probabilities sorted in descending order.
top_p : Tensor
The cumulative probability threshold with shape (batch, 1) for nucleus sampling.
top_k :Tensor
A tensor with shape (batch, 1), representing the number of top probabilities
to consider for top-k sampling.
Returns
-------
result : Tensor
The filtered and normalized tensor with the same shape as the input prob.
"""
prob_dtype = prob.dtype
top_k_dtype = top_k.dtype
batch = sorted_prob.shape[0]
def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
@T.prim_func(private=True)
def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
batch, vocab_size = T.int64(), T.int64()
sorted_prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
cumsum_sorted = T.match_buffer(B, (batch, vocab_size), prob_dtype)
top_p = T.match_buffer(C, (batch, 1), prob_dtype)
top_k = T.match_buffer(D, (batch, 1), top_k_dtype)
cutoff = T.match_buffer(E, (batch, 1), prob_dtype)
for ax0, ax1 in T.grid(batch, vocab_size):
with T.block("T_get_renorm_cutoff"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1) == 1:
if v_ax1 + 1 == vocab_size:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1]
elif _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, v_ax1 + 1) == 0:
cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1 + 1]
cumsum_sorted = cumsum(sorted_prob, axis=1)
renorm_cutoff = tensor_ir_op(
_get_renorm_cutoff,
"get_renorm_cutoff",
args=[sorted_prob, cumsum_sorted, top_p, top_k],
out=Tensor.placeholder(
[batch, 1],
prob_dtype,
),
)
filtered_prob = tensor_expr_op(
lambda prob, renorm_cutoff: te.compute(
prob.shape,
lambda i, j: _tir.Select(prob[i, j] >= renorm_cutoff[i, 0], prob[i, j], 0.0),
name="filter_with_top_p_top_k",
),
"filter_with_top_p_top_k",
args=[prob, renorm_cutoff],
)
renorm_prob = filtered_prob / sum(filtered_prob, axis=1, keepdims=True)
return renorm_prob