Segment CSR

torch_scatter.segment_csr(src: torch.Tensor, indptr: torch.Tensor, out: Optional[torch.Tensor] = None, reduce: str = 'sum') → torch.Tensor

Reduces all values from the src tensor into out within the ranges specified by the indptr tensor, along dimension indptr.dim() - 1 (the last dimension of indptr). For each value in src, its output index is given by its index in src for every dimension other than indptr.dim() - 1, and, for dimension indptr.dim() - 1, by the index of the indptr range that contains it. The applied reduction is defined via the reduce argument.

Formally, if src and indptr are \(n\)-dimensional and \((m+1)\)-dimensional tensors with size \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-1}, y)\), respectively, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})\). Moreover, the values of indptr must lie between \(0\) and \(x_m\) and be sorted in ascending (non-decreasing) order. The indptr tensor supports broadcasting when its dimensions do not match those of src.
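The shape and value rules can be restated as a small plain-Python check. The helper names csr_output_shape and valid_indptr are hypothetical (not part of torch_scatter), and the sketch only mirrors the rules stated in this section; it does not reproduce the library's own validation or error messages.

```python
# Hypothetical helpers (not part of torch_scatter) restating the shape and
# value rules of segment_csr() in plain Python.

def csr_output_shape(src_shape, indptr_shape):
    """Expected size of `out` for src of size src_shape and indptr of size indptr_shape."""
    m = len(indptr_shape) - 1  # reduction happens along dimension indptr.dim() - 1
    if m >= len(src_shape):
        raise ValueError("indptr must not have more dimensions than src")
    for d in range(m):
        # Leading dimensions of indptr must match src or be 1 (broadcast).
        if indptr_shape[d] not in (1, src_shape[d]):
            raise ValueError(f"indptr size {indptr_shape[d]} cannot broadcast to {src_shape[d]} in dim {d}")
    y = indptr_shape[-1]
    # The reduced dimension x_m is replaced by the number of segments, y - 1.
    return tuple(src_shape[:m]) + (y - 1,) + tuple(src_shape[m + 1:])


def valid_indptr(values, x_m):
    """True if every value lies in [0, x_m] and the sequence never decreases."""
    in_range = all(0 <= v <= x_m for v in values)
    ascending = all(a <= b for a, b in zip(values, values[1:]))
    return in_range and ascending


print(csr_output_shape((10, 6, 64), (1, 4)))  # (10, 3, 64)
print(valid_indptr([0, 2, 5, 6], 6))           # True
print(valid_indptr([0, 5, 2, 6], 6))           # False
```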

For one-dimensional tensors with reduce="sum", the operation computes

\[\mathrm{out}_i = \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.\]
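A plain-Python sketch of the one-dimensional sum, assuming list inputs rather than tensors: segment i covers src[indptr[i]] through src[indptr[i+1] - 1]. It also shows one common way to obtain indptr, namely prefix sums of the segment lengths with a leading 0. The function name segment_sum_1d is hypothetical.

```python
from itertools import accumulate


def segment_sum_1d(src, indptr):
    """Hypothetical helper: out[i] = sum of src[indptr[i]] .. src[indptr[i + 1] - 1]."""
    return [sum(src[indptr[i]:indptr[i + 1]]) for i in range(len(indptr) - 1)]


src = [1, 2, 3, 4, 5, 6]

# indptr = [0] followed by the running totals of the segment lengths [2, 3, 1].
lengths = [2, 3, 1]
indptr = [0] + list(accumulate(lengths))  # [0, 2, 5, 6]

print(segment_sum_1d(src, indptr))        # [3, 12, 6]
# An empty range (indptr[1] == indptr[2]) is an empty sum, i.e. 0, under the formula.
print(segment_sum_1d(src, [0, 2, 2, 6]))  # [3, 0, 18]
```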

Because segment boundaries are read directly from the index pointers, segment_csr() is the fastest method for grouped reductions.
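segment_csr() needs segment boundaries rather than a per-element group index. When groups are described by a sorted group index (the form segment_coo() takes), a minimal way to derive indptr, assuming every index lies in [0, num_segments), is to count elements per segment with torch.bincount and take a cumulative sum. index_to_indptr is a hypothetical helper, not a torch_scatter function.

```python
import torch


def index_to_indptr(index: torch.Tensor, num_segments: int) -> torch.Tensor:
    """Hypothetical helper: turn a sorted 1-D group index into CSR index pointers.

    Assumes index is sorted and every value lies in [0, num_segments).
    """
    # Number of elements in each segment; segments with no elements count 0.
    counts = torch.bincount(index, minlength=num_segments)
    indptr = torch.zeros(num_segments + 1, dtype=torch.long)
    # Running totals, so that indptr[i + 1] - indptr[i] == counts[i].
    indptr[1:] = torch.cumsum(counts, dim=0)
    return indptr


index = torch.tensor([0, 0, 1, 1, 1, 2])  # must already be sorted
print(index_to_indptr(index, num_segments=3))  # tensor([0, 2, 5, 6])
```

bincount ignores element order, so sorting matters only for the result to describe contiguous ranges of src.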

Note

In contrast to scatter() and segment_coo(), this operation is fully deterministic.

Parameters
  • src – The source tensor.

  • indptr – The index pointers marking the boundaries of each segment. The number of dimensions of indptr must be less than or equal to that of src.

  • out – The destination tensor.

  • reduce – The reduce operation ("sum", "mean", "min" or "max"). (default: "sum")
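To make the four reduce modes concrete, here is a loop-based sketch for one-dimensional src and indptr in plain PyTorch. It is not how torch_scatter implements segment_csr(); segment_csr_reference is a hypothetical name, and filling empty segments with 0 is an assumption of this sketch rather than documented behavior.

```python
import torch

# Reduction applied to one non-empty segment, keyed by the `reduce` string.
_REDUCERS = {
    "sum": torch.sum,
    "mean": torch.mean,
    "min": torch.min,
    "max": torch.max,
}


def segment_csr_reference(src: torch.Tensor, indptr: torch.Tensor, reduce: str = "sum") -> torch.Tensor:
    """Hypothetical loop-based sketch of segment_csr() for 1-D src and indptr."""
    if reduce not in _REDUCERS:
        raise ValueError(f"unknown reduce: {reduce!r}")
    num_segments = indptr.numel() - 1
    # Assumption of this sketch: empty segments keep this 0 fill value.
    out = src.new_zeros(num_segments)
    bounds = zip(indptr[:-1].tolist(), indptr[1:].tolist())
    for i, (start, end) in enumerate(bounds):
        if end > start:
            out[i] = _REDUCERS[reduce](src[start:end])
    return out


src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
indptr = torch.tensor([0, 2, 5, 6])
for mode in ("sum", "mean", "min", "max"):
    print(mode, segment_csr_reference(src, indptr, mode).tolist())
# sum [3.0, 12.0, 6.0]
# mean [1.5, 4.0, 6.0]
# min [1.0, 3.0, 6.0]
# max [2.0, 5.0, 6.0]
```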

Return type

Tensor

import torch
from torch_scatter import segment_csr

src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

out = segment_csr(src, indptr, reduce="sum")

print(out.size())  # torch.Size([10, 3, 64])
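The broadcasting example can be cross-checked with plain slicing in PyTorch, without torch_scatter: each output slice out[:, i, :] is the sum of src[:, indptr[i]:indptr[i+1], :] over dimension 1, applied to all 10 entries of dimension 0, with the trailing dimension of size 64 carried through. This sketch only covers the "sum" case.

```python
import torch

src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])

# (start, end) pairs for each segment: [(0, 2), (2, 5), (5, 6)].
bounds = list(zip(indptr[:-1].tolist(), indptr[1:].tolist()))

# Sum each range of dimension 1, then stack the per-segment results back
# along dimension 1, giving one slice of size (10, 64) per segment.
expected = torch.stack([src[:, s:e].sum(dim=1) for s, e in bounds], dim=1)

print(expected.size())  # torch.Size([10, 3, 64])
```

With out computed from the same src by the example above, torch.allclose(out, expected) is a reasonable sanity check.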