Segment COO

torch_scatter.segment_coo(src: torch.Tensor, index: torch.Tensor, out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, reduce: str = 'sum') → torch.Tensor

Reduces all values from the src tensor into out at the indices specified in the index tensor along the last dimension of index. For each value in src, its output index is specified by its index in src for dimensions outside of index.dim() - 1 and by the corresponding value in index for dimension index.dim() - 1. The applied reduction is defined via the reduce argument.

Formally, if src and index are \(n\)-dimensional and \((m+1)\)-dimensional tensors with size \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-1}, x_m)\), respectively, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})\). Moreover, the values of index must lie between \(0\) and \(y - 1\) and be sorted in ascending order. The index tensor supports broadcasting in case its dimensions do not match those of src.

For one-dimensional tensors with reduce="sum", the operation computes

\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]

where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
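The one-dimensional sum above can be mirrored in plain Python. The sketch below uses a hypothetical function, segment_sum_1d, that is not part of torch_scatter; it assumes index is already sorted in ascending order, which is what lets it make a single left-to-right pass and add each segment's total to out once, when the index value changes.

```python
from typing import List, Optional


def segment_sum_1d(src: List[float], index: List[int],
                   dim_size: Optional[int] = None) -> List[float]:
    """Hypothetical pure-Python reference for 1-D segment_coo(..., reduce="sum").

    Assumes `index` is sorted in ascending order and has the same length as `src`.
    """
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("index must be sorted in ascending order")
    # Mirror the documented default: the minimal output size is index.max() + 1.
    size = dim_size if dim_size is not None else (max(index) + 1 if index else 0)
    out = [0.0] * size

    # Single pass: keep a running total while the index value stays the same.
    j = 0
    while j < len(src):
        i = index[j]
        total = 0.0
        while j < len(src) and index[j] == i:
            total += src[j]
            j += 1
        out[i] += total  # out_i = out_i + sum_j src_j over all j with index_j == i
    return out


print(segment_sum_1d([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [0, 0, 1, 1, 1, 2]))
# [3.0, 12.0, 6.0]
```

Positions that no index value points to simply keep their initial value of 0.0 in this sketch.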
In contrast to scatter(), this method expects the values in index to be sorted along dimension index.dim() - 1. Because the indices are sorted, segment_coo() is usually faster than the more general scatter() operation.

Note
This operation is implemented via atomic operations on the GPU and is therefore non-deterministic, since the order of parallel operations to the same value is undetermined. For floating-point variables, this results in a source of variance in the result.
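The variance mentioned in the note comes from the fact that floating-point addition is not associative: accumulating the same values in a different order can change the last bits of the result. The following plain-Python illustration involves no GPU or torch_scatter code; it only shows the rounding effect that an undetermined accumulation order can expose.

```python
# Floating-point addition is not associative: the grouping changes the rounding.
a, b, c = 0.1, 0.2, 0.3

left_to_right = (a + b) + c  # 0.1 + 0.2 rounds to 0.30000000000000004 first
right_to_left = a + (b + c)  # 0.2 + 0.3 is exactly 0.5, so no early rounding

print(left_to_right, right_to_left, left_to_right == right_to_left)
# 0.6000000000000001 0.6 False
```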
Parameters
src – The source tensor.
index – The sorted indices of elements to segment. The number of dimensions of index needs to be less than or equal to that of src.
out – The destination tensor.
dim_size – If out is not given, automatically create an output with size dim_size at dimension index.dim() - 1. If dim_size is not given, a minimally sized output tensor according to index.max() + 1 is returned.
reduce – The reduce operation ("sum", "mean", "min" or "max"). (default: "sum")
Return type
Tensor
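To make the reduce and dim_size parameters concrete, here is a minimal pure-Python sketch of the one-dimensional semantics. The function name segment_coo_1d is hypothetical and not part of torch_scatter. It works on plain lists, assumes a sorted index, and fills empty segments with 0; that fill value is an assumption of this sketch rather than something stated above.

```python
from typing import List, Optional


def segment_coo_1d(src: List[float], index: List[int],
                   dim_size: Optional[int] = None,
                   reduce: str = "sum") -> List[float]:
    """Hypothetical pure-Python reference for 1-D segment_coo semantics.

    Assumptions (not taken from torch_scatter): inputs are Python lists,
    `index` is sorted in ascending order, and empty segments are filled with 0.
    """
    if reduce not in ("sum", "mean", "min", "max"):
        raise ValueError(f"unsupported reduce: {reduce!r}")
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("index must be sorted in ascending order")

    # Without dim_size, the output is as small as possible: index.max() + 1.
    size = dim_size if dim_size is not None else (max(index) + 1 if index else 0)
    # Since index is sorted, its first and last entries are its min and max.
    if index and not (0 <= index[0] and index[-1] < size):
        raise ValueError("index values must lie in [0, size - 1]")

    # Group the values of each segment by their output position.
    buckets: List[List[float]] = [[] for _ in range(size)]
    for value, i in zip(src, index):
        buckets[i].append(value)

    reducers = {
        "sum": sum,
        "mean": lambda vs: sum(vs) / len(vs),
        "min": min,
        "max": max,
    }
    # Empty segments get the assumed fill value 0.0.
    return [float(reducers[reduce](vs)) if vs else 0.0 for vs in buckets]


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
index = [0, 0, 1, 1, 1, 2]
print(segment_coo_1d(src, index, reduce="mean"))             # [1.5, 4.0, 6.0]
print(segment_coo_1d(src, index, reduce="max"))              # [2.0, 5.0, 6.0]
print(segment_coo_1d(src, index, dim_size=4, reduce="sum"))  # [3.0, 12.0, 6.0, 0.0]
```

Grouping into buckets keeps the sketch short and easy to check; it does not model the GPU kernel or its atomic updates.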
```python
import torch
from torch_scatter import segment_coo

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1)  # Broadcasting in the first and last dim.

out = segment_coo(src, index, reduce="sum")

print(out.size())
# torch.Size([10, 3, 64])
```
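The shape rule from the formal description can be checked without running the kernel. The helper segment_coo_out_shape below is hypothetical (not a torch_scatter function). It assumes size-1 leading dimensions of index broadcast against src, requires index and src to agree along the reduced dimension index.dim() - 1, and replaces that dimension with y, which defaults to index.max() + 1. Applied to the example's shapes, it gives the same (10, 3, 64) result.

```python
from typing import Optional, Sequence, Tuple


def segment_coo_out_shape(src_shape: Sequence[int],
                          index_shape: Sequence[int],
                          index_max: int,
                          dim_size: Optional[int] = None) -> Tuple[int, ...]:
    """Hypothetical helper: expected output shape of segment_coo.

    Assumes size-1 leading dimensions of `index` broadcast against `src`,
    and that dimensions of `src` after the reduced one are left untouched.
    """
    m = len(index_shape) - 1  # the reduced dimension is index.dim() - 1
    if not 0 <= m < len(src_shape):
        raise ValueError("index must have between 1 and src.dim() dimensions")
    for d in range(m):
        if index_shape[d] not in (1, src_shape[d]):
            raise ValueError(f"index dim {d} cannot broadcast to src")
    if index_shape[m] != src_shape[m]:
        raise ValueError("index and src must match along the reduced dimension")

    # y defaults to the minimal size index.max() + 1.
    y = dim_size if dim_size is not None else index_max + 1
    return tuple(src_shape[:m]) + (y,) + tuple(src_shape[m + 1:])


# The example above: src of size (10, 6, 64), index viewed as (1, 6), max value 2.
print(segment_coo_out_shape((10, 6, 64), (1, 6), index_max=2))
# (10, 3, 64)
```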