Segment COO

torch_scatter.segment_coo(src: torch.Tensor, index: torch.Tensor, out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None, reduce: str = 'sum') → torch.Tensor

Reduces all values from the src tensor into out at the indices specified in the index tensor along the last dimension of index. For each value in src, its output index is specified by its index in src for dimensions outside of index.dim() - 1 and by the corresponding value in index for dimension index.dim() - 1. The applied reduction is defined via the reduce argument.

Formally, if src and index are \(n\)-dimensional and \(m\)-dimensional tensors with size \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-1}, x_m)\), respectively, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})\). Moreover, the values of index must be between \(0\) and \(y - 1\) in ascending order. The index tensor supports broadcasting in case its dimensions do not match with src.
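As a hedged illustration of the shape rule above, the following pure-Python sketch computes the expected output size. The helper name expected_out_shape is hypothetical and not part of torch_scatter; it only restates the rule that dimension index.dim() - 1 of src is replaced by \(y\).

```python
def expected_out_shape(src_shape, index_ndim, y):
    """Return the output size implied by the shape rule.

    Hypothetical helper for illustration, not a torch_scatter function.
    """
    if not 1 <= index_ndim <= len(src_shape):
        raise ValueError("index must have between 1 and src.dim() dimensions")
    dim = index_ndim - 1  # the dimension that gets reduced
    return tuple(src_shape[:dim]) + (y,) + tuple(src_shape[dim + 1:])


# src of size (10, 6, 64), a 2-dimensional index, and y = 3 segments.
shape = expected_out_shape((10, 6, 64), index_ndim=2, y=3)
print(shape)  # (10, 3, 64)
```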

For one-dimensional tensors with reduce="sum", the operation computes

\[\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j\]

where \(\sum_j\) is over \(j\) such that \(\mathrm{index}_j = i\).
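The one-dimensional formula can be written out as a small reference loop. This is a minimal sketch in plain Python for checking results by hand, not the library's implementation; the function name segment_sum_1d is made up for this example.

```python
def segment_sum_1d(src, index, out):
    """Add every src[j] to out[index[j]], modifying out in place."""
    if len(src) != len(index):
        raise ValueError("src and index must have the same length")
    for value, i in zip(src, index):
        out[i] = out[i] + value  # out_i = out_i + src_j where index_j = i
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
index = [0, 0, 1, 1, 1, 2]
result = segment_sum_1d(src, index, out=[0.0, 0.0, 0.0])
print(result)  # [3.0, 12.0, 6.0]
```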

In contrast to scatter(), this method expects values in index to be sorted along dimension index.dim() - 1. Due to the use of sorted indices, segment_coo() is usually faster than the more general scatter() operation.
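One way to see why sorted indices help: when index is ascending, each output position corresponds to one contiguous slice of src. The sketch below derives those slice boundaries in plain Python. It is an illustration of the idea only and makes no claim about how torch_scatter is implemented internally; segment_bounds is a hypothetical helper.

```python
def segment_bounds(index, dim_size):
    """Return ptr such that segment i covers positions ptr[i]:ptr[i + 1]."""
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("index must be sorted in ascending order")
    ptr = [0] * (dim_size + 1)
    for i in index:
        ptr[i + 1] += 1  # count the elements of each segment
    for i in range(dim_size):
        ptr[i + 1] += ptr[i]  # prefix sum turns counts into offsets
    return ptr


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
index = [0, 0, 1, 1, 1, 2]
ptr = segment_bounds(index, dim_size=3)
sums = [sum(src[ptr[i]:ptr[i + 1]]) for i in range(3)]
print(ptr)   # [0, 2, 5, 6]
print(sums)  # [3.0, 12.0, 6.0]
```

An unsorted index such as [1, 0, 1] has no such contiguous slices, which is why the sketch rejects it.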


This operation is implemented via atomic operations on the GPU and is therefore non-deterministic, since the order of parallel operations on the same output value is undetermined. For floating-point values, results can therefore differ slightly between runs.
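The underlying reason is that floating-point addition is not associative, so the order of accumulation can change the last bits of a sum. The following plain-Python snippet shows the effect with three values; it does not involve the GPU or torch_scatter at all and is only meant to illustrate the arithmetic.

```python
values = [0.1, 0.2, 0.3]

# Same three numbers, two different accumulation orders.
left_to_right = (values[0] + values[1]) + values[2]
right_to_left = values[0] + (values[1] + values[2])

print(left_to_right == right_to_left)  # False
print(abs(left_to_right - right_to_left) < 1e-15)  # True: the difference is tiny
```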

Parameters

  • src – The source tensor.

  • index – The sorted indices of elements to segment. The number of dimensions of index must be less than or equal to that of src.

  • out – The destination tensor.

  • dim_size – If out is not given, the output is created automatically with size dim_size at dimension index.dim() - 1. If dim_size is not given either, a minimally sized output tensor with size index.max() + 1 at that dimension is returned.

  • reduce – The reduce operation ("sum", "mean", "min" or "max"). (default: "sum")

Return type

torch.Tensor

import torch
from torch_scatter import segment_coo

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1)  # Broadcasting in the first and last dim.

out = segment_coo(src, index, reduce="sum")

print(out.size())

torch.Size([10, 3, 64])
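
To make the reduce and dim_size parameters concrete, here is a hedged one-dimensional reference in plain Python. The name segment_reduce_1d is hypothetical, and returning None for a segment that receives no values is a choice made for this sketch only; what the library writes into empty segments is not described in this section.

```python
def segment_reduce_1d(src, index, dim_size=None, reduce="sum"):
    """Group src by index and apply the chosen reduction to each group."""
    if dim_size is None:
        dim_size = max(index) + 1  # minimal size, as described for dim_size
    ops = {
        "sum": sum,
        "mean": lambda group: sum(group) / len(group),
        "min": min,
        "max": max,
    }
    if reduce not in ops:
        raise ValueError(f"unsupported reduce: {reduce!r}")
    groups = [[] for _ in range(dim_size)]
    for value, i in zip(src, index):
        groups[i].append(value)
    # Sketch-only convention: empty segments are reported as None.
    return [ops[reduce](group) if group else None for group in groups]


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
index = [0, 0, 1, 1, 1, 2]
results = {name: segment_reduce_1d(src, index, reduce=name)
           for name in ("sum", "mean", "min", "max")}
padded = segment_reduce_1d(src, index, dim_size=4)
print(results["mean"])  # [1.5, 4.0, 6.0]
print(padded)           # [3.0, 12.0, 6.0, None]
```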