torch2onnx:Scatter

Reference: Scatter

Warning

Starting from version 11, the Scatter operator is deprecated; please use ScatterElements, which provides the same functionality.

Scatter takes three inputs, data, updates, and indices, all of rank r >= 1, together with an optional attribute axis that identifies an axis of data (the default is the outermost axis, i.e. axis 0). The output of the operator is produced by creating a copy of the input data and then updating its values at the index positions specified by indices. The output shape is the same as the shape of data.

For each entry in updates, the target index in data is obtained by combining the corresponding entry in indices with the index of the entry itself: the index value along dimension = axis is taken from the value of the corresponding entry in indices, while the index values along dimensions != axis are taken from the index of the entry itself.

In the case of a 2-D tensor, the update corresponding to entry [i][j] is performed as follows:

output[indices[i][j]][j] = updates[i][j] if axis = 0,
output[i][indices[i][j]] = updates[i][j] if axis = 1,
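
The axis = 0 rule above can be written out as a short Python sketch (illustrative only; scatter_2d_axis0 is a hypothetical helper, not part of ONNX or Torch):

# Sketch of 2-D Scatter with axis = 0: copy data, then overwrite the
# positions selected by indices with the corresponding updates.
def scatter_2d_axis0(data, indices, updates):
    output = [row[:] for row in data]        # start from a copy of data
    for i in range(len(updates)):
        for j in range(len(updates[i])):
            output[indices[i][j]][j] = updates[i][j]
    return output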

This operator is the inverse of GatherElements. It is similar to Torch's Scatter operation.
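
As a quick sketch of that inverse relationship in Torch (made-up tensors; the round trip only restores x because idx hits every position of each row exactly once):

import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
idx = torch.tensor([[0, 1], [1, 0]])
g = x.gather(1, idx)                          # GatherElements analogue
y = torch.zeros_like(x).scatter_(1, idx, g)   # scatter the gathered values back
# y equals x, since every (row, column) position receives exactly one value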

Example 1:

data = [
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
]
indices = [
    [1, 0, 2],
    [0, 2, 1],
]
updates = [
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
]
output = [
    [2.0, 1.1, 0.0],
    [1.0, 0.0, 2.2],
    [0.0, 2.1, 1.2],
]

Example 2:

data = [[1.0, 2.0, 3.0, 4.0, 5.0]]
indices = [[1, 3]]
updates = [[1.1, 2.1]]
axis = 1
output = [[1.0, 1.1, 3.0, 2.1, 5.0]]
  • Optional attribute axis: the axis along which to scatter. A negative value means counting dimensions from the back. The accepted range is \([-r, r-1]\), where r = rank(data).
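
Both ONNX examples above can be reproduced with torch.Tensor.scatter_, whose docstring is quoted below (a sketch; the tensor values are copied from the examples):

import torch

# Example 1 (axis = 0)
data = torch.zeros(3, 3)
indices = torch.tensor([[1, 0, 2], [0, 2, 1]])
updates = torch.tensor([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]])
data.scatter_(0, indices, updates)
# -> [[2.0, 1.1, 0.0], [1.0, 0.0, 2.2], [0.0, 2.1, 1.2]]

# Example 2 (axis = 1)
data = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
data.scatter_(1, torch.tensor([[1, 3]]), torch.tensor([[1.1, 2.1]]))
# -> [[1.0, 1.1, 3.0, 2.1, 5.0]]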

import torch

torch.Tensor.scatter_?
Docstring:
scatter_(dim, index, src, reduce=None) -> Tensor

Writes all values from the tensor :attr:`src` into :attr:`self` at the indices
specified in the :attr:`index` tensor. For each value in :attr:`src`, its output
index is specified by its index in :attr:`src` for ``dimension != dim`` and by
the corresponding value in :attr:`index` for ``dimension = dim``.

For a 3-D tensor, :attr:`self` is updated as::

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

This is the reverse operation of the manner described in :meth:`~Tensor.gather`.

:attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have
the same number of dimensions. It is also required that
``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that
``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``.
Note that ``index`` and ``src`` do not broadcast.

Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be
between ``0`` and ``self.size(dim) - 1`` inclusive.

.. warning::

    When indices are not unique, the behavior is non-deterministic (one of the
    values from ``src`` will be picked arbitrarily) and the gradient will be
    incorrect (it will be propagated to all locations in the source that
    correspond to the same index)!

.. note::

    The backward pass is implemented only for ``src.shape == index.shape``.

Additionally accepts an optional :attr:`reduce` argument that allows
specification of an optional reduction operation, which is applied to all
values in the tensor :attr:`src` into :attr:`self` at the indices
specified in the :attr:`index`. For each value in :attr:`src`, the reduction
operation is applied to an index in :attr:`self` which is specified by
its index in :attr:`src` for ``dimension != dim`` and by the corresponding
value in :attr:`index` for ``dimension = dim``.

Given a 3-D tensor and reduction using the multiplication operation, :attr:`self`
is updated as::

    self[index[i][j][k]][j][k] *= src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] *= src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] *= src[i][j][k]  # if dim == 2

Reducing with the addition operation is the same as using
:meth:`~torch.Tensor.scatter_add_`.

.. warning::
    The reduce argument with Tensor ``src`` is deprecated and will be removed in
    a future PyTorch release. Please use :meth:`~torch.Tensor.scatter_reduce_`
    instead for more reduction options.

Args:
    dim (int): the axis along which to index
    index (LongTensor): the indices of elements to scatter, can be either empty
        or of the same dimensionality as ``src``. When empty, the operation
        returns ``self`` unchanged.
    src (Tensor or float): the source element(s) to scatter.
    reduce (str, optional): reduction operation to apply, can be either
        ``'add'`` or ``'multiply'``.

Example::

    >>> src = torch.arange(1, 11).reshape((2, 5))
    >>> src
    tensor([[ 1,  2,  3,  4,  5],
            [ 6,  7,  8,  9, 10]])
    >>> index = torch.tensor([[0, 1, 2, 0]])
    >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
    tensor([[1, 0, 0, 4, 0],
            [0, 2, 0, 0, 0],
            [0, 0, 3, 0, 0]])
    >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
    >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
    tensor([[1, 2, 3, 0, 0],
            [6, 7, 0, 0, 8],
            [0, 0, 0, 0, 0]])

    >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
    ...            1.23, reduce='multiply')
    tensor([[2.0000, 2.0000, 2.4600, 2.0000],
            [2.0000, 2.0000, 2.0000, 2.4600]])
    >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
    ...            1.23, reduce='add')
    tensor([[2.0000, 2.0000, 3.2300, 2.0000],
            [2.0000, 2.0000, 2.0000, 3.2300]])
Type:      method_descriptor
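
For the torch-to-ONNX side, the sketch below exports a small module built around the out-of-place scatter (the module name and output file are hypothetical); with opset_version >= 11 the exporter is expected to emit ScatterElements rather than the deprecated Scatter:

import torch

class ScatterModel(torch.nn.Module):   # hypothetical wrapper module
    def forward(self, data, index, src):
        # out-of-place variant of scatter_, same indexing semantics
        return data.scatter(1, index, src)

data = torch.zeros(2, 5)
index = torch.tensor([[0, 1, 2], [0, 1, 4]])
src = torch.arange(1.0, 7.0).reshape(2, 3)

torch.onnx.export(ScatterModel(), (data, index, src), "scatter.onnx",
                  opset_version=13)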