Segment CSR

torch_scatter.segment_csr(src: torch.Tensor, indptr: torch.Tensor, out: Optional[torch.Tensor] = None, reduce: str = 'sum') → torch.Tensor

Reduces all values from the src tensor into out within the ranges specified in the indptr tensor along the last dimension of indptr. For each value in src, its output index is specified by its index in src for dimensions outside of indptr.dim() - 1 and by the corresponding range index in indptr for dimension indptr.dim() - 1. The applied reduction is defined via the reduce argument.

Formally, if
src and indptr are \(n\)-dimensional and \(m\)-dimensional tensors with size \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-1}, y)\), respectively, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})\). Moreover, the values of indptr must be between \(0\) and \(x_m\) in ascending order. The indptr tensor supports broadcasting in case its dimensions do not match with src.

For one-dimensional tensors with
reduce="sum", the operation computes

\[\mathrm{out}_i = \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1] - 1}~\mathrm{src}_j.\]

Due to the use of index pointers, segment_csr() is the fastest method to apply for grouped reductions.

Note
In contrast to scatter() and segment_coo(), this operation is fully deterministic.

Parameters
src – The source tensor.
indptr – The index pointers between elements to segment. The number of dimensions of indptr needs to be less than or equal to that of src.
out – The destination tensor.
reduce – The reduce operation ("sum", "mean", "min" or "max"). (default: "sum")
Return type
Tensor
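
To make the one-dimensional formula concrete, the following is a minimal plain-Python sketch of the same reduction over the half-open ranges from indptr[i] up to (but not including) indptr[i+1]. The name segment_csr_reference is hypothetical and not part of torch_scatter. It works on Python lists rather than tensors, and returning 0.0 for an empty range is an assumption of this sketch, not documented library behavior.

```python
from typing import List, Sequence


def segment_csr_reference(
    src: Sequence[float], indptr: Sequence[int], reduce: str = "sum"
) -> List[float]:
    """Plain-Python sketch of a 1-D segment reduction over index pointers.

    Hypothetical helper for illustration only; not the torch_scatter code.
    """
    reducers = {
        "sum": sum,
        "mean": lambda seg: sum(seg) / len(seg),
        "min": min,
        "max": max,
    }
    if reduce not in reducers:
        raise ValueError(f"unsupported reduce: {reduce!r}")

    # The pointers must be ascending and lie between 0 and len(src).
    if any(a > b for a, b in zip(indptr, indptr[1:])):
        raise ValueError("indptr must be in ascending order")
    if len(indptr) > 0 and (indptr[0] < 0 or indptr[-1] > len(src)):
        raise ValueError("indptr values must lie in [0, len(src)]")

    out = []
    for start, end in zip(indptr, indptr[1:]):
        segment = src[start:end]  # elements start, ..., end - 1
        if len(segment) == 0:
            # Assumption of this sketch: an empty range yields 0.0.
            out.append(0.0)
        else:
            out.append(float(reducers[reduce](segment)))
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
indptr = [0, 2, 5, 6]
print(segment_csr_reference(src, indptr, "sum"))  # [3.0, 12.0, 6.0]
print(segment_csr_reference(src, indptr, "max"))  # [2.0, 5.0, 6.0]
```

Four pointers delimit three ranges, which is why the output has one element fewer than indptr.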
import torch
from torch_scatter import segment_csr

src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

out = segment_csr(src, indptr, reduce="sum")

print(out.size())
torch.Size([10, 3, 64])
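
The size rule for out can also be checked by hand with a small helper. This is a sketch under the assumption that the reduced dimension is indptr.dim() - 1, as described above. The name expected_out_size is hypothetical and not a torch_scatter function, and it operates on plain tuples of sizes rather than on tensors.

```python
from typing import Sequence, Tuple


def expected_out_size(
    src_size: Sequence[int], indptr_size: Sequence[int]
) -> Tuple[int, ...]:
    """Sketch of the output size rule; hypothetical helper for illustration."""
    dim = len(indptr_size) - 1  # dimension along which ranges are reduced
    if dim < 0 or dim >= len(src_size):
        raise ValueError("indptr must have between 1 and src.dim() dimensions")

    # Leading dimensions of indptr must match src or be 1 (broadcast).
    for s, p in zip(src_size[:dim], indptr_size[:dim]):
        if p not in (1, s):
            raise ValueError("leading dimensions of indptr do not broadcast to src")

    num_segments = indptr_size[-1] - 1  # y pointers delimit y - 1 ranges
    if num_segments < 0:
        raise ValueError("indptr needs at least one pointer")

    return tuple(src_size[:dim]) + (num_segments,) + tuple(src_size[dim + 1:])


print(expected_out_size((10, 6, 64), (1, 4)))  # (10, 3, 64)
```

With src of size (10, 6, 64) and indptr viewed as size (1, 4), the ranges run along dimension 1, so its 6 entries collapse into 3 segments while the other dimensions are unchanged.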