Segment CSR
- torch_scatter.segment_csr(src: Tensor, indptr: Tensor, out: Tensor | None = None, reduce: str = 'sum') → Tensor
Reduces all values from the src tensor into out within the ranges specified in the indptr tensor along the last dimension of indptr. For each value in src, its output index is determined in two ways:
- For dimensions other than indptr.dim() - 1, it is the value's own index in src.
- For dimension indptr.dim() - 1, it is the index of the indptr range that contains the value.
The applied reduction is defined via the reduce argument.

Formally, let src be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\), and let indptr be an \(m\)-dimensional tensor with size \((x_0, ..., x_{m-2}, y)\). Then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-2}, y - 1, x_{m}, ..., x_{n-1})\). Moreover, the values of indptr must lie between \(0\) and \(x_{m-1}\), the size of the segmented dimension of src, and must be in ascending order. The indptr tensor supports broadcasting when its dimensions do not match those of src.

For one-dimensional tensors with reduce="sum", the operation computes

\[\mathrm{out}_i = \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.\]

Because it works on index pointers, segment_csr() is the fastest method for applying grouped reductions.

Note
In contrast to scatter() and segment_coo(), this operation is fully deterministic.

- Parameters:
src – The source tensor.
indptr – The index pointers between elements to segment. The number of dimensions of indptr needs to be less than or equal to that of src.
out – The destination tensor.
reduce – The reduce operation ("sum", "mean", "min" or "max"). (default: "sum")
- Return type:
Tensor
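The summation formula for one-dimensional tensors extends naturally to the other reduce options. What follows is a framework-free sketch of that one-dimensional case, using plain Python lists instead of tensors. The function name segment_csr_1d is hypothetical and not part of torch_scatter. Filling empty ranges (where indptr[i] == indptr[i+1]) with 0 is an assumption of this sketch, not documented library behaviour.

```python
from typing import List, Sequence


def segment_csr_1d(
    src: Sequence[float], indptr: Sequence[int], reduce: str = "sum"
) -> List[float]:
    """Hypothetical reference (not the torch_scatter API): reduce each
    range src[indptr[i]:indptr[i + 1]] into out[i] for a 1-D input."""
    if reduce not in ("sum", "mean", "min", "max"):
        raise ValueError(f"unsupported reduce: {reduce!r}")
    # indptr must be ascending and stay within [0, len(src)].
    if any(a > b for a, b in zip(indptr, indptr[1:])):
        raise ValueError("indptr must be in ascending order")
    if indptr and (indptr[0] < 0 or indptr[-1] > len(src)):
        raise ValueError("indptr values must lie between 0 and len(src)")

    out: List[float] = []
    # y pointers delimit y - 1 consecutive, non-overlapping segments.
    for start, end in zip(indptr, indptr[1:]):
        segment = src[start:end]
        if not segment:
            out.append(0.0)  # assumption of this sketch for empty ranges
        elif reduce == "sum":
            out.append(sum(segment))
        elif reduce == "mean":
            out.append(sum(segment) / len(segment))
        elif reduce == "min":
            out.append(min(segment))
        else:  # "max"
            out.append(max(segment))
    return out


src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
indptr = [0, 2, 5, 6]  # segments: [1, 2], [3, 4, 5], [6]
for op in ("sum", "mean", "min", "max"):
    print(op, segment_csr_1d(src, indptr, reduce=op))
# sum [3.0, 12.0, 6.0]
# mean [1.5, 4.0, 6.0]
# min [1.0, 3.0, 6.0]
# max [2.0, 5.0, 6.0]
```

Each output entry depends only on one contiguous range of src, so the segments can be reduced independently of each other.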
import torch
from torch_scatter import segment_csr

src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

out = segment_csr(src, indptr, reduce="sum")

print(out.size())
torch.Size([10, 3, 64])
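The printed size follows from the formal shape rule. After view(1, -1), indptr has m = 2 dimensions, so dimension 1 of src (size 6) is the one being segmented. Its y = 4 pointers give y - 1 = 3 segments, and dimensions 0 and 2 keep their sizes. The sketch below applies that rule to plain shape tuples. The helper name is hypothetical. It assumes a leading indptr dimension must either equal the matching src size or be 1, following the usual broadcasting convention and the example's comment; it does not fully specify the library's own checks.

```python
from typing import Sequence, Tuple


def segment_csr_out_shape(
    src_shape: Sequence[int], indptr_shape: Sequence[int]
) -> Tuple[int, ...]:
    """Hypothetical helper (not part of torch_scatter): size of out implied
    by the shape rule, with size-1 leading dims of indptr broadcast."""
    n, m = len(src_shape), len(indptr_shape)
    if not 1 <= m <= n:
        raise ValueError("indptr needs between 1 and src.dim() dimensions")
    dim = m - 1  # the dimension of src that is split into segments
    for i in range(dim):
        # Assumed broadcasting convention: sizes must match or be 1.
        if indptr_shape[i] not in (1, src_shape[i]):
            raise ValueError(
                f"indptr size {indptr_shape[i]} does not broadcast to "
                f"src size {src_shape[i]} in dim {i}"
            )
    y = indptr_shape[-1]
    if y < 1:
        raise ValueError("indptr must contain at least one pointer")
    # y pointers delimit y - 1 segments; all other dims keep src's sizes.
    return (*src_shape[:dim], y - 1, *src_shape[dim + 1:])


# The example above: src of size (10, 6, 64), indptr viewed as (1, 4).
print(segment_csr_out_shape((10, 6, 64), (1, 4)))  # (10, 3, 64)
```

A one-dimensional indptr segments dimension 0 instead. For example, the same src with an indptr of size (4,) would produce a (3, 6, 64) result under this rule.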