Scatter
- torch_scatter.scatter(src: Tensor, index: Tensor, dim: int = -1, out: Tensor | None = None, dim_size: int | None = None, reduce: str = 'sum') → Tensor
Reduces all values from the src tensor into out at the indices specified in the index tensor along a given axis dim. For each value in src, its output index is specified by its index in src for dimensions outside of dim and by the corresponding value in index for dimension dim. The applied reduction is defined via the reduce argument.

Formally, if src and index are \(n\)-dimensional tensors with size \((x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})\) and dim = i, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})\). Moreover, the values of index must be between \(0\) and \(y - 1\), although no specific ordering of indices is required. The index tensor supports broadcasting in case its dimensions do not match with src.

For one-dimensional tensors with reduce="sum", the operation computes

\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]

where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
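The one-dimensional sum formula can be read as a simple accumulation loop. The following is a minimal plain-Python sketch of that semantics only; the helper name scatter_sum_1d is hypothetical and this is not how the library implements the operation.

```python
def scatter_sum_1d(src, index, dim_size=None):
    """Plain-Python reference for 1-D scatter with reduce="sum" (illustrative only)."""
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    # Without an explicit size, use the smallest output that fits every index.
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0
    out = [0.0] * dim_size
    # out[i] accumulates every src[j] whose index[j] equals i.
    for value, i in zip(src, index):
        out[i] += value
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
index = [0, 1, 0, 1, 2, 1]
print(scatter_sum_1d(src, index))  # [4.0, 12.0, 5.0]
```

Positions 0 and 2 of src both map to output slot 0, positions 1, 3 and 5 map to slot 1, and position 4 maps to slot 2.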
Note
This operation is implemented via atomic operations on the GPU and is therefore non-deterministic, since the order of parallel operations on the same value is undetermined. For floating-point values, this means results can vary slightly between runs.
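The reason accumulation order matters is that floating-point addition is not associative. The snippet below is a small plain-Python illustration of that effect; it does not use the GPU or the library, it only shows that summing the same three numbers in a different order can change the last bits.

```python
# Floating-point addition is not associative, so the order in which
# parallel updates land on the same output slot can change the result slightly.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

print(left_to_right == right_to_left)  # False
print(abs(left_to_right - right_to_left))  # a tiny, non-zero difference
```

When comparing scatter results across runs, a tolerance-based comparison is therefore usually more appropriate than exact equality.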
- Parameters:
src – The source tensor.
index – The indices of elements to scatter.
dim – The axis along which to index. (default: -1)
out – The destination tensor.
dim_size – If out is not given, automatically create output with size dim_size at dimension dim. If dim_size is not given, a minimal sized output tensor according to index.max() + 1 is returned.
reduce – The reduce operation ("sum", "mul", "mean", "min" or "max"). (default: "sum")
- Return type:
Tensor
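To make the dim_size and reduce parameters concrete, here is a plain-Python sketch of the one-dimensional case for each listed reduce mode. The helper scatter_1d and its fill argument are hypothetical: fill is this sketch's own choice for output slots that receive no element, and the library's behaviour for such slots is not specified here.

```python
import math


def scatter_1d(src, index, dim_size=None, reduce="sum", fill=0.0):
    """Plain-Python sketch of 1-D scatter for several reduce modes.

    `fill` is an assumption of this sketch: the value written to output
    slots that receive no element. It is not a parameter of the library.
    """
    # Mirror the documented default: the minimal size is max(index) + 1.
    if dim_size is None:
        dim_size = max(index) + 1 if index else 0
    # Collect the values routed to each output slot.
    groups = [[] for _ in range(dim_size)]
    for value, i in zip(src, index):
        groups[i].append(value)

    reducers = {
        "sum": sum,
        "mul": math.prod,
        "mean": lambda g: sum(g) / len(g),
        "min": min,
        "max": max,
    }
    if reduce not in reducers:
        raise ValueError(f"unsupported reduce in this sketch: {reduce!r}")
    fn = reducers[reduce]
    return [fn(g) if g else fill for g in groups]


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
index = [0, 1, 0, 1, 2, 1]
print(scatter_1d(src, index, reduce="mean"))  # [2.0, 4.0, 5.0]
print(scatter_1d(src, index, reduce="max"))   # [3.0, 6.0, 5.0]
print(scatter_1d(src, index, dim_size=4))     # [4.0, 12.0, 5.0, 0.0]
```

Passing dim_size=4 adds a fourth slot that no index refers to, which is why the last call ends with the sketch's fill value.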
import torch
from torch_scatter import scatter

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])

# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")

print(out.size())
torch.Size([10, 3, 64])
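The printed size follows from the size rule in the description: every dimension of src is kept except dim, which becomes dim_size, or index.max() + 1 when neither out nor dim_size is given. A small plain-Python sketch of that rule, using a hypothetical helper scatter_output_shape and a one-dimensional index given as a list:

```python
def scatter_output_shape(src_shape, index, dim=-1, dim_size=None):
    """Compute the expected output shape of a scatter call (illustrative helper)."""
    # Normalise a negative dim such as -1 to a non-negative axis.
    axis = dim % len(src_shape)
    # Documented default: the minimal size that fits every index value.
    if dim_size is None:
        dim_size = max(index) + 1
    # Keep all dimensions of src except the scattered one.
    return src_shape[:axis] + (dim_size,) + src_shape[axis + 1:]


print(scatter_output_shape((10, 6, 64), [0, 1, 0, 1, 2, 1], dim=1))  # (10, 3, 64)
```

Here the largest index value is 2, so the size along dim=1 shrinks from 6 to 3 while the first and last dimensions are unchanged.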