Scatter
- torch_scatter.scatter(src: Tensor, index: Tensor, dim: int = -1, out: Tensor | None = None, dim_size: int | None = None, reduce: str = 'sum') → Tensor
Reduces all values from the `src` tensor into `out` at the indices specified in the `index` tensor along a given axis `dim`. For each value in `src`, its output index is specified by its index in `src` for dimensions outside of `dim` and by the corresponding value in `index` for dimension `dim`. The applied reduction is defined via the `reduce` argument.

Formally, if `src` and `index` are \(n\)-dimensional tensors with size \((x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})\) and `dim` = \(i\), then `out` must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})\). Moreover, the values of `index` must be between \(0\) and \(y - 1\) (inclusive), although no specific ordering of indices is required. The `index` tensor supports broadcasting in case its dimensions do not match those of `src`.

For one-dimensional tensors with `reduce="sum"`, the operation computes

\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]

where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
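To make the one-dimensional formula concrete, here is a minimal pure-Python sketch of the `reduce="sum"` case. It is an illustration of the semantics only, not how torch_scatter is implemented; the function name `scatter_sum_1d` is hypothetical, and when no `out` is passed it starts from zeros with length `max(index) + 1`, mirroring the default output size of `index.max() + 1`.

```python
def scatter_sum_1d(src, index, out=None):
    """Hypothetical 1-D reference: out[i] += sum of src[j] over all j with index[j] == i."""
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    if out is None:
        # Default size: max(index) + 1, with every slot starting at zero.
        out = [0] * (max(index) + 1)
    else:
        out = list(out)  # copy so the caller's list is not modified
    for j, i in enumerate(index):
        if not 0 <= i < len(out):
            raise IndexError(f"index[{j}] = {i} is out of range for size {len(out)}")
        out[i] += src[j]
    return out


src = [1, 2, 3, 4, 5]
index = [0, 1, 0, 1, 2]
print(scatter_sum_1d(src, index))                # [4, 6, 5]
print(scatter_sum_1d(src, index, [10, 10, 10]))  # [14, 16, 15]
```

The second call shows the \(\mathrm{out}_i + \dots\) term: existing values in `out` are accumulated into, not overwritten.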
Note
This operation is implemented via atomic operations on the GPU and is therefore non-deterministic, since the order in which parallel operations update the same output value is undetermined. For floating-point tensors, this means the result can vary between runs.
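The order matters because floating-point addition is not associative. The following pure-Python sketch uses plain Python floats as a stand-in for GPU atomics (an assumption for illustration only): it accumulates the same four values, all destined for one output slot, in every possible order and collects the distinct totals.

```python
import itertools

# Four values that all scatter to the same output slot.
values = [1e16, 1.0, -1e16, 1.0]


def accumulate(order):
    """Add values one at a time, like sequential updates to one output element."""
    total = 0.0
    for v in order:
        total += v
    return total


# Each permutation models one possible schedule of the parallel updates.
results = sorted({accumulate(p) for p in itertools.permutations(values)})
print(results)
```

Only some orderings recover the exact total of 2.0; the others lose one or both of the small values to rounding against `1e16`.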
- Parameters:
src – The source tensor.
index – The indices of elements to scatter.
dim – The axis along which to index. (default: `-1`)
out – The destination tensor.
dim_size – If `out` is not given, automatically create an output with size `dim_size` at dimension `dim`. If `dim_size` is not given, a minimally sized output tensor according to `index.max() + 1` is returned.
reduce – The reduce operation (`"sum"`, `"mul"`, `"mean"`, `"min"` or `"max"`). (default: `"sum"`)
- Return type:
Tensor
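The following pure-Python sketch extends the one-dimensional case to every `reduce` option and to `dim_size`. It is a hypothetical reference (`scatter_1d` is not part of torch_scatter) meant only to illustrate the semantics. In particular, slots that receive no values are filled with `0` here; that is an assumption of this sketch and may not match the library's fill value for every reduction.

```python
import math


def scatter_1d(src, index, dim_size=None, reduce="sum"):
    """Hypothetical 1-D reference for scatter with reduce in {sum, mul, mean, min, max}."""
    if reduce not in ("sum", "mul", "mean", "min", "max"):
        raise ValueError(f"unsupported reduce: {reduce!r}")
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    # Default output size mirrors index.max() + 1.
    size = dim_size if dim_size is not None else max(index) + 1

    # Group the source values by their target slot.
    groups = [[] for _ in range(size)]
    for value, i in zip(src, index):
        if not 0 <= i < size:
            raise IndexError(f"index value {i} is out of range for size {size}")
        groups[i].append(value)

    out = []
    for vals in groups:
        if not vals:
            out.append(0)  # assumption of this sketch: empty slots become 0
        elif reduce == "sum":
            out.append(sum(vals))
        elif reduce == "mul":
            out.append(math.prod(vals))
        elif reduce == "mean":
            out.append(sum(vals) / len(vals))
        elif reduce == "min":
            out.append(min(vals))
        else:  # "max"
            out.append(max(vals))
    return out


src = [1, 2, 3, 4, 5]
index = [0, 1, 0, 1, 2]
for mode in ("sum", "mul", "mean", "min", "max"):
    print(mode, scatter_1d(src, index, reduce=mode))
print("dim_size=4", scatter_1d(src, index, dim_size=4))
```

Grouping values before reducing keeps each reduction a plain Python builtin; it trades memory for clarity and is not meant to reflect how the library computes results.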
```python
import torch
from torch_scatter import scatter

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")

print(out.size())
```

```
torch.Size([10, 3, 64])
```
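In this example the one-dimensional `index` is broadcast across the first and last dimensions of `src`. The pure-Python sketch below (a hypothetical helper, `scatter_sum_rows`, restricted to 2-D lists with `dim=1`) spells out the output-index rule from the opening paragraph: each element keeps its row position, while its column position is replaced by the matching entry of `index`.

```python
def scatter_sum_rows(src, index):
    """Hypothetical 2-D scatter-sum along dim=1 with a 1-D index shared by every row."""
    size = max(index) + 1
    out = [[0] * size for _ in src]
    for r, row in enumerate(src):
        if len(row) != len(index):
            raise ValueError("each row of src must match the length of index")
        for c, value in enumerate(row):
            # Row position r is kept; column c is redirected to index[c].
            out[r][index[c]] += value
    return out


src = [
    [1, 2, 3, 4, 5, 6],
    [10, 20, 30, 40, 50, 60],
]
index = [0, 1, 0, 1, 2, 1]  # same index pattern as the example above
print(scatter_sum_rows(src, index))  # [[4, 12, 5], [40, 120, 50]]
```

The output keeps the 2 rows and shrinks the scattered dimension from 6 to 3, just as the tensor example goes from size 6 to 3 along `dim=1`.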