torch2onnx:Scatter
Reference: Scatter
Warning
Starting with version 11 (ONNX opset 11), the Scatter operator is deprecated; use ScatterElements instead, which provides the same functionality.
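Scatter takes three inputs, data, updates, and indices, of the same rank r >= 1, and an optional attribute axis that identifies an axis of data (by default the outermost axis, i.e. axis 0). The output of the operator is produced by creating a copy of the input data and then updating its values, at the index positions specified by indices, with the corresponding values from updates. The output has the same shape as data.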
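For each entry in updates, the target index in data is obtained by combining the corresponding entry in indices with the index of the entry itself: the index value for dimension = axis comes from the value of the corresponding entry in indices, and the index value for each dimension != axis comes from the index of the entry itself.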
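The indexing rule above can be written as a short NumPy reference sketch. `scatter_elements_ref` is a hypothetical helper for illustration only (not part of ONNX, onnxruntime, or NumPy); it assumes NumPy is available and implements only the plain overwrite described here, without any reduction.

```python
import numpy as np


def scatter_elements_ref(data, indices, updates, axis=0):
    """Reference sketch of Scatter/ScatterElements semantics (plain overwrite, no reduction).

    Hypothetical helper written for illustration; not an ONNX or NumPy API.
    """
    data = np.asarray(data)
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    r = data.ndim
    if indices.shape != updates.shape:
        raise ValueError("indices and updates must have the same shape")
    if indices.ndim != r:
        raise ValueError("indices and updates must have the same rank as data")
    if not -r <= axis <= r - 1:
        raise ValueError(f"axis must be in [{-r}, {r - 1}], got {axis}")
    axis %= r  # map a negative axis to its non-negative equivalent

    output = data.copy()  # the output starts as a copy of data
    # Visit every entry of updates in row-major order; with duplicate targets,
    # this sketch simply keeps the last write.
    for pos in np.ndindex(*updates.shape):
        target = list(pos)           # dimensions != axis: the entry's own index
        target[axis] = indices[pos]  # dimension == axis: the value stored in indices
        output[tuple(target)] = updates[pos]
    return output


out = scatter_elements_ref(
    np.zeros((3, 3)),
    [[1, 0, 2], [0, 2, 1]],
    [[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]],
    axis=0,
)
print(out)
```

The explicit loop mirrors the prose rule one entry at a time; it is meant for checking small cases by hand, not for speed.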
For example, for a 2-D tensor, the update corresponding to entry [i][j] is performed as follows:
output[indices[i][j]][j] = updates[i][j] if axis = 0,
output[i][indices[i][j]] = updates[i][j] if axis = 1,
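The two 2-D rules can also be applied in one vectorized step with NumPy advanced indexing: for axis = 0 the column coordinate j comes from each entry's own position while the row coordinate comes from indices, and the roles swap for axis = 1. `scatter_2d` is a hypothetical illustration helper, assuming NumPy, and handles only 2-D inputs.

```python
import numpy as np


def scatter_2d(data, indices, updates, axis=0):
    """Vectorized form of the two 2-D rules (hypothetical helper, not an ONNX API)."""
    out = np.array(data, copy=True)  # work on a copy, as the operator does
    indices = np.asarray(indices)
    rows, cols = indices.shape
    i = np.arange(rows)[:, None]  # row index of every entry, shape (rows, 1)
    j = np.arange(cols)[None, :]  # column index of every entry, shape (1, cols)
    if axis == 0:
        out[indices, j] = updates  # output[indices[i][j]][j] = updates[i][j]
    elif axis == 1:
        out[i, indices] = updates  # output[i][indices[i][j]] = updates[i][j]
    else:
        raise ValueError("axis must be 0 or 1 for a 2-D tensor")
    return out


print(scatter_2d([[1.0, 2.0, 3.0, 4.0, 5.0]], [[1, 3]], [[1.1, 2.1]], axis=1))
```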
This operator is the inverse of GatherElements. It is similar to Torch's scatter operation.
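A minimal NumPy sketch of the inverse relationship, assuming that `np.put_along_axis` and `np.take_along_axis` stand in for Scatter and GatherElements (they line up when indices has the same size as data in every dimension other than axis). The round trip only recovers updates when no two entries target the same position.

```python
import numpy as np

data = np.zeros((2, 4))
indices = np.array([[3, 0], [1, 2]])
updates = np.array([[10.0, 20.0], [30.0, 40.0]])

# Scatter along axis 1 into a copy of data:
# row 0 becomes [20, 0, 0, 10], row 1 becomes [0, 30, 40, 0].
scattered = data.copy()
np.put_along_axis(scattered, indices, updates, axis=1)

# Gathering along the same axis with the same indices reads the written values back.
recovered = np.take_along_axis(scattered, indices, axis=1)
print(np.array_equal(recovered, updates))  # True
```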
Example 1:
data = [
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 0.0, 0.0],
]
indices = [
[1, 0, 2],
[0, 2, 1],
]
updates = [
[1.0, 1.1, 1.2],
[2.0, 2.1, 2.2],
]
output = [
[2.0, 1.1, 0.0],
[1.0, 0.0, 2.2],
[0.0, 2.1, 1.2],
]
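Example 1 can be reproduced in PyTorch, where ONNX's axis corresponds to the dim argument. A sketch assuming PyTorch is installed, using the out-of-place Tensor.scatter:

```python
import torch

data = torch.zeros(3, 3)
indices = torch.tensor([[1, 0, 2], [0, 2, 1]])  # int64, as scatter requires
updates = torch.tensor([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]])

# axis = 0 in ONNX corresponds to dim = 0 in Torch.
output = data.scatter(0, indices, updates)
print(output)
```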
Example 2:
data = [[1.0, 2.0, 3.0, 4.0, 5.0]]
indices = [[1, 3]]
updates = [[1.1, 2.1]]
axis = 1
output = [[1.0, 1.1, 3.0, 2.1, 5.0]]
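Example 2 in PyTorch, with axis = 1 mapped to dim = 1. The sketch (assuming PyTorch is installed) also contrasts the out-of-place scatter, which, like the ONNX operator, works on a copy of data, with the in-place scatter_.

```python
import torch

data = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
indices = torch.tensor([[1, 3]])
updates = torch.tensor([[1.1, 2.1]])

# Out-of-place: returns a new tensor and leaves `data` untouched.
output = data.scatter(1, indices, updates)

# In-place variant: modifies `data` itself (and returns it).
data.scatter_(1, indices, updates)
print(torch.equal(output, data))  # True
```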
Optional attribute
axis: The axis along which to scatter. A negative value means counting dimensions from the back. The accepted range is \([-r, r-1]\), where r = rank(data).
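PyTorch's dim argument accepts negative values in the same range, \([-r, r-1]\). A small sketch, assuming PyTorch is installed, showing that dim = -1 and dim = r - 1 address the same axis (the index values are chosen so that no two entries in a row share a target):

```python
import torch

data = torch.zeros(2, 3)
index = torch.tensor([[2, 0], [1, 0]])
src = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

r = data.dim()  # rank of data, here 2
# A negative dim counts from the back: dim = -1 is the same axis as dim = r - 1.
a = data.scatter(-1, index, src)
b = data.scatter(r - 1, index, src)
print(torch.equal(a, b))  # True
```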
import torch
torch.Tensor.scatter_?  # IPython/Jupyter help syntax; in a plain Python script use help(torch.Tensor.scatter_)
Docstring:
scatter_(dim, index, src, reduce=None) -> Tensor
Writes all values from the tensor :attr:`src` into :attr:`self` at the indices
specified in the :attr:`index` tensor. For each value in :attr:`src`, its output
index is specified by its index in :attr:`src` for ``dimension != dim`` and by
the corresponding value in :attr:`index` for ``dimension = dim``.
For a 3-D tensor, :attr:`self` is updated as::
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
This is the reverse operation of the manner described in :meth:`~Tensor.gather`.
:attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have
the same number of dimensions. It is also required that
``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that
``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``.
Note that ``index`` and ``src`` do not broadcast.
Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be
between ``0`` and ``self.size(dim) - 1`` inclusive.
.. warning::
When indices are not unique, the behavior is non-deterministic (one of the
values from ``src`` will be picked arbitrarily) and the gradient will be
incorrect (it will be propagated to all locations in the source that
correspond to the same index)!
.. note::
The backward pass is implemented only for ``src.shape == index.shape``.
Additionally accepts an optional :attr:`reduce` argument that allows
specification of an optional reduction operation, which is applied to all
values in the tensor :attr:`src` into :attr:`self` at the indices
specified in the :attr:`index`. For each value in :attr:`src`, the reduction
operation is applied to an index in :attr:`self` which is specified by
its index in :attr:`src` for ``dimension != dim`` and by the corresponding
value in :attr:`index` for ``dimension = dim``.
Given a 3-D tensor and reduction using the multiplication operation, :attr:`self`
is updated as::
self[index[i][j][k]][j][k] *= src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] *= src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] *= src[i][j][k] # if dim == 2
Reducing with the addition operation is the same as using
:meth:`~torch.Tensor.scatter_add_`.
.. warning::
The reduce argument with Tensor ``src`` is deprecated and will be removed in
a future PyTorch release. Please use :meth:`~torch.Tensor.scatter_reduce_`
instead for more reduction options.
Args:
dim (int): the axis along which to index
index (LongTensor): the indices of elements to scatter, can be either empty
or of the same dimensionality as ``src``. When empty, the operation
returns ``self`` unchanged.
src (Tensor or float): the source element(s) to scatter.
reduce (str, optional): reduction operation to apply, can be either
``'add'`` or ``'multiply'``.
Example::
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
... 1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
[2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
... 1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
[2.0000, 2.0000, 2.0000, 3.2300]])
Type: method_descriptor
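The warnings in the docstring above point toward the reduction APIs. A sketch, assuming a recent PyTorch release that provides Tensor.scatter_reduce, showing (1) that scatter_add sums values written to the same position instead of keeping an arbitrary one, and (2) the docstring's reduce='multiply' example rewritten with scatter_reduce and reduce="prod". scatter_reduce needs a Tensor src, so the scalar 1.23 becomes a tensor here.

```python
import torch

# 1) Duplicate indices with scatter_add: both values land in slot 0 and are summed.
index = torch.tensor([[0, 0, 1]])
src = torch.tensor([[1.0, 2.0, 3.0]])
summed = torch.zeros(1, 3).scatter_add(1, index, src)
print(summed)  # tensor([[3., 3., 0.]])

# 2) The reduce='multiply' docstring example, written with scatter_reduce.
#    include_self=True (the default) keeps the original 2.0 in the product.
base = torch.full((2, 4), 2.0)
idx = torch.tensor([[2], [3]])
factor = torch.full((2, 1), 1.23)
prod = base.scatter_reduce(1, idx, factor, reduce="prod", include_self=True)

# 3) scatter_reduce also offers reductions that reduce= never had, e.g. "amax".
amax = base.scatter_reduce(1, idx, torch.tensor([[5.0], [1.0]]), reduce="amax")
```

Because scatter_reduce takes the reduction as an explicit argument, the same call shape covers "sum", "prod", "mean", "amax", and "amin".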