Scatter

torch_scatter.scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1, out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, reduce: str = 'sum') → torch.Tensor

https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/add.svg?sanitize=true

Reduces all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in src for dimensions outside of dim and by the corresponding value in index for dimension dim. The applied reduction is defined via the reduce argument.

Formally, if src and index are \(n\)-dimensional tensors with size \((x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})\) and dim = i, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})\). Moreover, the values of index must lie between \(0\) and \(y - 1\) inclusive, although no specific ordering of indices is required. The index tensor is broadcast against src if its dimensions do not match those of src.
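The shape rule can be illustrated with a small pure-Python sketch for a 2-dimensional src and dim = 1. The helper scatter_sum_dim1 is hypothetical and not part of torch_scatter; it only mirrors the description above, with a 1-dimensional index shared by every row to stand in for broadcasting.

```python
# Pure-Python illustration of the shape rule for a 2-D src and dim = 1.
# `scatter_sum_dim1` is a hypothetical helper, not part of torch_scatter.

def scatter_sum_dim1(src, index, y):
    """Sum src[r][c] into out[r][index[c]]; index is shared by every row."""
    out = [[0.0] * y for _ in range(len(src))]
    for r, row in enumerate(src):
        for c, value in enumerate(row):
            # Position outside `dim` (the row r) is kept,
            # position along `dim` is taken from index.
            out[r][index[c]] += value
    return out

src = [[1.0, 2.0, 3.0, 4.0],
       [5.0, 6.0, 7.0, 8.0]]   # size (2, 4)
index = [0, 1, 0, 1]            # size (4,), reused for each row

out = scatter_sum_dim1(src, index, y=2)
print(out)                       # [[4.0, 6.0], [12.0, 14.0]]
print((len(out), len(out[0])))   # (2, 2)
```

Only the size along dim changes, from 4 to y = 2; the size of the other dimension is preserved.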

For one-dimensional tensors with reduce="sum", the operation computes

\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]

where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
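As a rough reference for the formula above, the one-dimensional case can be written as a plain loop. This is a sketch on Python lists, not the library's implementation, and scatter_sum_1d is a hypothetical name.

```python
# Reference loop for the 1-D "sum" case: out_i = out_i + sum of src_j with index_j == i.
# `scatter_sum_1d` is a hypothetical helper, not part of torch_scatter.

def scatter_sum_1d(src, index, out):
    """Add each src[j] to out[index[j]] in place and return out."""
    for j, value in enumerate(src):
        out[index[j]] += value
    return out

src = [1.0, 2.0, 3.0, 4.0, 5.0]
index = [0, 2, 0, 1, 2]
out = [10.0, 0.0, 0.0]   # values already in out are kept and added to

scatter_sum_1d(src, index, out)
print(out)  # [14.0, 4.0, 7.0]
```

The first entry shows the \(\mathrm{out}_i\) term on the right-hand side: the initial 10.0 is retained and the contributions 1.0 and 3.0 are added to it.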

Note

This operation is implemented via atomic operations on the GPU and is therefore non-deterministic, since the order in which parallel operations update the same value is undetermined. For floating-point values, the result can therefore vary slightly between runs.
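The reason ordering matters is that floating-point addition is not associative. The snippet below is a generic illustration in plain Python and does not reproduce the GPU code path; it only shows that summing the same three values in a different order can change the last bits of the result.

```python
# Floating-point addition is not associative, so the accumulation order matters.
a, b, c = 0.1, 0.2, 0.3

left_to_right = (a + b) + c
right_to_left = a + (b + c)

print(left_to_right == right_to_left)      # False
print(abs(left_to_right - right_to_left))  # a tiny, non-zero difference
```

When results have to be compared across runs, a tolerance-based comparison is usually more appropriate than exact equality.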

Parameters
  • src – The source tensor.

  • index – The indices of elements to scatter.

  • dim – The axis along which to index. (default: -1)

  • out – The destination tensor.

  • dim_size – If out is not given, the output is created automatically with size dim_size at dimension dim. If dim_size is not given either, the output has the minimal size index.max() + 1 at dimension dim.

  • reduce – The reduce operation ("sum", "mul", "mean", "min" or "max"). (default: "sum")

Return type

Tensor
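The interplay of dim_size and reduce described in the parameter list can be sketched for the one-dimensional case in pure Python. The helper scatter_1d and the REDUCERS table are hypothetical names used only for illustration. Slots that receive no value are returned as None here, which is an assumption of this sketch and not a statement about what the library writes into them.

```python
from functools import reduce as fold
import operator

# Hypothetical reference reducers for the five documented reduce options.
REDUCERS = {
    "sum": sum,
    "mul": lambda values: fold(operator.mul, values, 1),
    "mean": lambda values: sum(values) / len(values),
    "min": min,
    "max": max,
}

def scatter_1d(src, index, dim_size=None, reduce="sum"):
    """Group src values by index, then reduce each group (hypothetical helper)."""
    # Without dim_size, the output is as small as possible: max(index) + 1.
    size = max(index) + 1 if dim_size is None else dim_size
    groups = [[] for _ in range(size)]
    for value, i in zip(src, index):
        groups[i].append(value)
    # Assumption of this sketch: empty slots are reported as None.
    return [REDUCERS[reduce](g) if g else None for g in groups]

src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
index = [0, 1, 0, 1, 2, 1]

print(scatter_1d(src, index, reduce="sum"))    # [4.0, 12.0, 5.0]
print(scatter_1d(src, index, reduce="mul"))    # [3.0, 48.0, 5.0]
print(scatter_1d(src, index, reduce="mean"))   # [2.0, 4.0, 5.0]
print(scatter_1d(src, index, reduce="min"))    # [1.0, 2.0, 5.0]
print(scatter_1d(src, index, reduce="max"))    # [3.0, 6.0, 5.0]
print(len(scatter_1d(src, index)))              # 3, i.e. max(index) + 1
print(len(scatter_1d(src, index, dim_size=5)))  # 5
```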

import torch
from torch_scatter import scatter

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")

print(out.size())
# torch.Size([10, 3, 64])