Segment COO
- `torch_scatter.segment_coo(src: Tensor, index: Tensor, out: Tensor | None = None, dim_size: int | None = None, reduce: str = 'sum') -> Tensor`
Reduces all values from the `src` tensor into `out` at the indices specified in the `index` tensor along the last dimension of `index`. For each value in `src`, its output index is specified by its index in `src` for dimensions outside of `index.dim() - 1` and by the corresponding value in `index` for dimension `index.dim() - 1`. The applied reduction is defined via the `reduce` argument.

Formally, if `src` and `index` are \(n\)-dimensional and \((m+1)\)-dimensional tensors with size \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-1}, x_m)\), respectively, then `out` must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})\). Moreover, the values of `index` must lie between \(0\) and \(y - 1\) and be sorted in ascending order. The `index` tensor supports broadcasting when its dimensions do not match those of `src`.

For one-dimensional tensors with `reduce="sum"`, the operation computes

\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]

where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
In contrast to `scatter()`, this method expects the values in `index` to be sorted along dimension `index.dim() - 1`. Due to the use of sorted indices, `segment_coo()` is usually faster than the more general `scatter()` operation.

Note
This operation is implemented via atomic operations on the GPU and is therefore non-deterministic, since the order of parallel updates to the same output element is undetermined. For floating-point inputs, this introduces run-to-run variance in the result.
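The variance stems from floating-point addition not being associative: when the updates to one output element are applied in a different order, the rounded sum can differ. A small pure-Python illustration of that effect (plain CPU arithmetic, not the GPU kernel itself):

```python
from typing import Iterable


def sum_in_order(values: Iterable[float]) -> float:
    """Accumulate left to right, like a sequence of updates to one output slot."""
    total = 0.0
    for v in values:
        total += v
    return total


values = [0.1, 0.2, 0.3]
forward = sum_in_order(values)             # (0.1 + 0.2) + 0.3
backward = sum_in_order(reversed(values))  # (0.3 + 0.2) + 0.1
print(forward, backward, forward == backward)  # 0.6000000000000001 0.6 False
```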
- Parameters:
  - `src` – The source tensor.
  - `index` – The sorted indices of elements to segment. The number of dimensions of `index` needs to be less than or equal to that of `src`.
  - `out` – The destination tensor.
  - `dim_size` – If `out` is not given, automatically create output with size `dim_size` at dimension `index.dim() - 1`. If `dim_size` is not given, a minimally sized output tensor with size `index.max() + 1` at that dimension is returned.
  - `reduce` – The reduce operation (`"sum"`, `"mean"`, `"min"` or `"max"`). (default: `"sum"`)
- Return type: `Tensor`
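A rough pure-Python sketch of how the `out` and `dim_size` parameters above interact for a 1-D sum, written to also show why a sorted `index` helps: equal indices sit next to each other, so each segment is one contiguous run that a single left-to-right pass can reduce. The names `contiguous_runs` and `segment_sum_1d` are hypothetical; this illustrates the documented semantics only and is not how torch_scatter's kernels are written.

```python
from typing import List, Optional, Sequence, Tuple


def contiguous_runs(index: Sequence[int]) -> List[Tuple[int, int, int]]:
    """Return (segment_id, start, stop) for each run of equal values in a sorted index."""
    runs: List[Tuple[int, int, int]] = []
    start = 0
    for j in range(1, len(index) + 1):
        # A run ends at the end of the input or where the index value changes.
        if j == len(index) or index[j] != index[start]:
            runs.append((index[start], start, j))
            start = j
    return runs


def segment_sum_1d(
    src: Sequence[float],
    index: Sequence[int],
    out: Optional[List[float]] = None,
    dim_size: Optional[int] = None,
) -> List[float]:
    """Hypothetical 1-D sketch of segment_coo(..., reduce="sum")."""
    if out is None:
        if dim_size is None:
            # Minimal size: one slot for every value up to index.max().
            dim_size = max(index) + 1 if len(index) > 0 else 0
        out = [0.0] * dim_size
    for seg, start, stop in contiguous_runs(index):
        out[seg] += sum(src[start:stop])  # accumulate, as in out_i = out_i + sum_j src_j
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
index = [0, 0, 1, 1, 1, 2]
print(segment_sum_1d(src, index))              # [3.0, 12.0, 6.0]
print(segment_sum_1d(src, index, dim_size=5))  # [3.0, 12.0, 6.0, 0.0, 0.0]
```

Passing an existing `out` (for example `[10.0, 0.0, 0.0]`) adds the segment sums to the values already there, matching the \(\mathrm{out}_i + \sum_j \mathrm{src}_j\) form of the formula.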
```python
import torch
from torch_scatter import segment_coo

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1)  # Broadcasting in the first and last dim.

out = segment_coo(src, index, reduce="sum")

print(out.size())
```

```
torch.Size([10, 3, 64])
```
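The output shape in this example follows the formal rule stated earlier: `src`'s shape with dimension `index.dim() - 1` replaced by \(y\), here `index.max() + 1 = 3`. The helpers below (hypothetical, plain Python, operating on shapes and index values only) sketch that rule together with the range and ordering requirements on `index`. Treating size-1 leading dimensions of `index` as broadcastable is an assumption drawn from the example, not a complete statement of torch_scatter's broadcasting behavior.

```python
from typing import Sequence, Tuple


def segment_coo_out_shape(
    src_shape: Sequence[int], index_shape: Sequence[int], y: int
) -> Tuple[int, ...]:
    """Hypothetical shape rule: replace dim index.dim() - 1 of src_shape with y."""
    m = len(index_shape) - 1  # the dimension being reduced
    if m < 0 or m >= len(src_shape):
        raise ValueError("index needs between 1 and src.dim() dimensions")
    for d in range(m):
        # Assumption: a size-1 leading dim of index broadcasts against src.
        if index_shape[d] not in (1, src_shape[d]):
            raise ValueError(f"index dim {d} does not broadcast against src")
    if index_shape[m] != src_shape[m]:
        raise ValueError("index and src must agree at dim index.dim() - 1")
    # Trailing dims that index lacks are copied from src unchanged.
    return tuple(src_shape[:m]) + (y,) + tuple(src_shape[m + 1:])


def is_valid_segment(values: Sequence[int], y: int) -> bool:
    """True if every value is in [0, y - 1] and the values never decrease."""
    in_range = all(0 <= v < y for v in values)
    ascending = all(a <= b for a, b in zip(values, values[1:]))
    return in_range and ascending


# Shapes from the example: src is (10, 6, 64), index is viewed as (1, 6).
print(segment_coo_out_shape((10, 6, 64), (1, 6), y=3))  # (10, 3, 64)
print(is_valid_segment([0, 0, 1, 1, 1, 2], y=3))         # True
print(is_valid_segment([0, 2, 1], y=3))                  # False
```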