Segment CSR
- torch_scatter.segment_csr(src: Tensor, indptr: Tensor, out: Tensor | None = None, reduce: str = 'sum') → Tensor
Reduces all values from the `src` tensor into `out` within the ranges specified in the `indptr` tensor along the last dimension of `indptr`. For each value in `src`, its output index is given by its index in `src` for dimensions other than `indptr.dim() - 1`, and by the corresponding range index in `indptr` for dimension `indptr.dim() - 1`. The applied reduction is selected via the `reduce` argument.

Formally, if `src` and `indptr` are \(n\)-dimensional and \(m\)-dimensional tensors with sizes \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-2}, y)\), respectively, then `out` must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-2}, y - 1, x_m, ..., x_{n-1})\). Moreover, the values of `indptr` must lie between \(0\) and \(x_{m-1}\) (the size of the segmented dimension of `src`) and be sorted in ascending order. The `indptr` tensor supports broadcasting if its dimensions do not match those of `src`.

For one-dimensional tensors with `reduce="sum"`, the operation computes

\[\mathrm{out}_i = \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.\]

Due to its use of index pointers, `segment_csr()` is the fastest of this package's methods for grouped reductions.

Note

In contrast to `scatter()` and `segment_coo()`, this operation is fully deterministic.

- Parameters:
src – The source tensor.
indptr – The index pointers between elements to segment. The number of dimensions of `indptr` needs to be less than or equal to that of `src`.
out – The destination tensor.
reduce – The reduce operation (`"sum"`, `"mean"`, `"min"` or `"max"`). (default: `"sum"`)
- Return type:
Tensor
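To make the one-dimensional formula and the four `reduce` options concrete, the following is a minimal pure-Python sketch of the semantics described above. It is not the library's implementation (which works on tensors and broadcasts over extra dimensions); `segment_csr_reference` is a hypothetical name, and returning 0 for an empty segment is an assumption of this sketch rather than documented behavior.

```python
from typing import List, Sequence


def segment_csr_reference(src: Sequence[float], indptr: Sequence[int],
                          reduce: str = "sum") -> List[float]:
    """Hypothetical 1-D sketch: out[i] reduces src[indptr[i]:indptr[i + 1]]."""
    if reduce not in ("sum", "mean", "min", "max"):
        raise ValueError(f"unsupported reduce: {reduce!r}")
    # indptr values must lie in [0, len(src)] and be sorted in ascending order.
    if any(p < 0 or p > len(src) for p in indptr):
        raise ValueError("indptr values must lie in [0, len(src)]")
    if any(a > b for a, b in zip(indptr, indptr[1:])):
        raise ValueError("indptr must be sorted in ascending order")

    out: List[float] = []
    # Consecutive pointer pairs delimit the half-open range [start, end).
    for start, end in zip(indptr, indptr[1:]):
        segment = list(src[start:end])
        if not segment:
            out.append(0.0)  # Assumption of this sketch: empty segment -> 0.
        elif reduce == "sum":
            out.append(float(sum(segment)))
        elif reduce == "mean":
            out.append(sum(segment) / len(segment))
        elif reduce == "min":
            out.append(min(segment))
        else:  # reduce == "max"
            out.append(max(segment))
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
indptr = [0, 2, 5, 6]
print(segment_csr_reference(src, indptr))                # [3.0, 12.0, 6.0]
print(segment_csr_reference(src, indptr, reduce="max"))  # [2.0, 5.0, 6.0]
```

With `indptr = [0, 2, 5, 6]`, the three segments are `src[0:2]`, `src[2:5]` and `src[5:6]`, so the output has \(y - 1 = 3\) entries, matching the size rule.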
```python
import torch
from torch_scatter import segment_csr

src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

out = segment_csr(src, indptr, reduce="sum")

print(out.size())
```

```
torch.Size([10, 3, 64])
```
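The output size in the example above follows from the formal size rule: the dimension `indptr.dim() - 1` of `src` shrinks to `indptr.size(-1) - 1` segments, while the other dimensions of `src` carry through. The shape-only sketch below reproduces that rule without calling torch_scatter. `segment_csr_out_shape` is a hypothetical helper, and treating a size-1 leading `indptr` dimension as broadcastable (while rejecting other mismatches) is an assumption based on the broadcasting remark, not a statement of the library's exact checks.

```python
from typing import Sequence, Tuple


def segment_csr_out_shape(src_shape: Sequence[int],
                          indptr_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: output size implied by the segment_csr size rule."""
    m = len(indptr_shape)
    if m == 0 or m > len(src_shape):
        raise ValueError("indptr needs 1 to src.dim() dimensions")
    # Leading indptr dims must match src or be 1 (assumed to broadcast).
    for i in range(m - 1):
        if indptr_shape[i] not in (1, src_shape[i]):
            raise ValueError(f"indptr dim {i} cannot broadcast to src")
    # Dimension m - 1 of src is segmented into indptr_shape[-1] - 1 groups;
    # dimensions after it (e.g. the trailing 64 below) are kept unchanged.
    num_segments = indptr_shape[-1] - 1
    return (*src_shape[:m - 1], num_segments, *src_shape[m:])


# Mirrors the example above: src (10, 6, 64), indptr viewed as (1, 4).
print(segment_csr_out_shape((10, 6, 64), (1, 4)))  # (10, 3, 64)
print(segment_csr_out_shape((6,), (4,)))            # (3,)
```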