Scatter

torch_scatter.scatter(src, index, dim=-1, out=None, dim_size=None, reduce='sum')

Reduces all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in src for dimensions outside of dim and by the corresponding value in index for dimension dim. The applied reduction is defined via the reduce argument.

Formally, if
src and index are \(n\)-dimensional tensors with size \((x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})\) and dim = i, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})\). Moreover, the values of index must be between \(0\) and \(y - 1\); no specific ordering of the indices is required. The index tensor supports broadcasting in case its dimensions do not match with src.

For one-dimensional tensors with
reduce="sum", the operation computes

\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]

where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
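The one-dimensional sum can be traced by hand with a few lines of plain Python. This is only an illustrative sketch of the formula above, not the library's implementation; the helper name scatter_sum_1d is hypothetical, and lists stand in for tensors.

```python
def scatter_sum_1d(src, index, out):
    """Illustrative reference for out_i = out_i + sum of src_j with index_j == i."""
    for j, i in enumerate(index):
        out[i] += src[j]  # accumulate src[j] into the slot named by index[j]
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0]
index = [0, 1, 0, 1, 2]

# Starting from zeros gives the plain per-index sums.
fresh = scatter_sum_1d(src, index, [0.0, 0.0, 0.0])
# Starting from an existing out adds to the values already stored there.
accumulated = scatter_sum_1d(src, index, [10.0, 0.0, 0.0])

print(fresh)        # [4.0, 6.0, 5.0]
print(accumulated)  # [14.0, 6.0, 5.0]
```

The second call shows why the formula has out_i on both sides: values already present in out are kept and added to.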
Note
This operation is implemented via atomic operations on the GPU and is therefore non-deterministic, since the order of parallel operations on the same value is undetermined. For floating-point values, the result can therefore vary slightly between runs.
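The reason ordering matters for floating-point values is that floating-point addition is not associative. The snippet below is a small CPU-only illustration in plain Python; it does not involve the GPU or torch_scatter and only shows that summing the same numbers in a different order can change the last bits of the result.

```python
values = [0.1, 0.2, 0.3]

# Same three numbers, two different accumulation orders.
left_to_right = (values[0] + values[1]) + values[2]
right_to_left = values[0] + (values[1] + values[2])

print(left_to_right == right_to_left)  # False
print(abs(left_to_right - right_to_left))  # a tiny, non-zero difference
```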
Parameters:
- src (Tensor) – The source tensor.
- index (Tensor) – The indices of elements to scatter.
- dim (int) – The axis along which to index. (default: -1)
- out (Optional[Tensor]) – The destination tensor.
- dim_size (Optional[int]) – If out is not given, automatically create output with size dim_size at dimension dim. If dim_size is not given, a minimal sized output tensor according to index.max() + 1 is returned.
- reduce (str) – The reduce operation ("sum", "mean", "min" or "max"). (default: "sum")
Return type: Tensor
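How dim_size and reduce interact can be sketched in one dimension with plain Python. The helper scatter_1d below is hypothetical and only mimics the documented parameters on lists; in particular, leaving slots that receive no value at 0.0 is an assumption made for this sketch, not a statement about the library.

```python
def scatter_1d(src, index, dim_size=None, reduce="sum"):
    """Hypothetical 1-D reference for the dim_size and reduce parameters."""
    # Without dim_size, the output is as small as the indices allow.
    size = dim_size if dim_size is not None else max(index) + 1
    groups = [[] for _ in range(size)]
    for value, i in zip(src, index):
        groups[i].append(value)  # collect the values routed to slot i

    reducers = {
        "sum": sum,
        "mean": lambda g: sum(g) / len(g),
        "min": min,
        "max": max,
    }
    if reduce not in reducers:
        raise ValueError(f"unsupported reduce: {reduce!r}")
    fn = reducers[reduce]
    # Assumption for this sketch: empty slots are left at 0.0.
    return [fn(g) if g else 0.0 for g in groups]


src = [1.0, 2.0, 3.0, 4.0, 5.0]
index = [0, 1, 0, 1, 2]

sums = scatter_1d(src, index)                  # [4.0, 6.0, 5.0]
means = scatter_1d(src, index, reduce="mean")  # [2.0, 3.0, 5.0]
mins = scatter_1d(src, index, reduce="min")    # [1.0, 2.0, 5.0]
maxes = scatter_1d(src, index, reduce="max")   # [3.0, 4.0, 5.0]
padded = scatter_1d(src, index, dim_size=5)    # [4.0, 6.0, 5.0, 0.0, 0.0]
```

With dim_size=5 the output has five slots even though the largest index is 2, which is useful when the number of groups is known in advance.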
    import torch
    from torch_scatter import scatter

    src = torch.randn(10, 6, 64)
    index = torch.tensor([0, 1, 0, 1, 2, 1])

    # Broadcasting in the first and last dim.
    out = scatter(src, index, dim=1, reduce="sum")

    print(out.size())  # torch.Size([10, 3, 64])
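To see what broadcasting the index means in the example above, the same pattern can be reduced to two dimensions and written with nested lists. This is a rough sketch under simplifying assumptions (two dimensions, reduce="sum", scattering along the second axis); the helper scatter_sum_dim1 is hypothetical and not part of torch_scatter.

```python
def scatter_sum_dim1(src, index, dim_size=None):
    """Sum-scatter along axis 1 of a 2-D nested list.

    index has one entry per column and is reused for every row,
    which is the list analogue of broadcasting a 1-D index.
    """
    size = dim_size if dim_size is not None else max(index) + 1
    out = [[0.0] * size for _ in src]
    for r, row in enumerate(src):
        for c, value in enumerate(row):
            out[r][index[c]] += value  # the row position is kept, the column is remapped
    return out


src = [
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    [10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
]
index = [0, 1, 0, 1, 2, 1]

out = scatter_sum_dim1(src, index)
print(out)  # [[4.0, 12.0, 5.0], [40.0, 120.0, 50.0]]
```

The input has shape (2, 6) and the output has shape (2, 3), mirroring how the (10, 6, 64) tensor in the example becomes (10, 3, 64): only the scattered axis changes size.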