Segment CSR
torch_scatter.segment_csr(src, indptr, out=None, reduce='sum')

Reduces all values from the src tensor into out within the ranges specified in the indptr tensor along the last dimension of indptr. For each value in src, its output index is specified by its index in src for dimensions outside of indptr.dim() - 1 and by the corresponding range index in indptr for dimension indptr.dim() - 1. The applied reduction is defined via the reduce argument.

Formally, if src and indptr are \(n\)-dimensional and \((m+1)\)-dimensional tensors with size \((x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})\) and \((x_0, ..., x_{m-1}, y)\), respectively, then out must be an \(n\)-dimensional tensor with size \((x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})\). Moreover, the values of indptr must lie between \(0\) and \(x_m\) in non-decreasing order. The indptr tensor supports broadcasting in case its dimensions do not match those of src.

For one-dimensional tensors with reduce="sum", the operation computes

\[\mathrm{out}_i = \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.\]

Each segment \(i\) thus covers the half-open range \([\mathrm{indptr}[i], \mathrm{indptr}[i+1])\) of src, and out has one entry fewer than indptr.

Because segment boundaries are read directly from the index pointers, segment_csr() is the fastest method to apply for grouped reductions.

Note: In contrast to scatter() and segment_coo(), this operation is fully deterministic.

Parameters:
- src (Tensor) – The source tensor.
- indptr (Tensor) – The index pointers delimiting the segments to reduce. The number of dimensions of indptr needs to be less than or equal to that of src.
- out (Optional[Tensor]) – The destination tensor.
- reduce (str) – The reduce operation ("sum", "mean", "min" or "max"). (default: "sum")

Return type: Tensor
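The shape rule from the formal description can be captured in a few lines. The sketch below is a hypothetical helper (segment_csr_out_shape is not part of torch_scatter) that computes the expected size of out from the sizes of src and indptr. It assumes that each leading indptr dimension either matches src or has size 1 and is broadcast, which is how the example below uses a (1, 4) indptr.

```python
from typing import Sequence, Tuple


def segment_csr_out_shape(src_shape: Sequence[int],
                          indptr_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: expected size of `out` for segment_csr.

    The reduced dimension is d = len(indptr_shape) - 1. Leading indptr
    dimensions are assumed to either match src or have size 1 (broadcast).
    """
    if not 1 <= len(indptr_shape) <= len(src_shape):
        raise ValueError("indptr needs between 1 and src.dim() dimensions")
    d = len(indptr_shape) - 1
    for k in range(d):
        if indptr_shape[k] not in (1, src_shape[k]):
            raise ValueError(f"indptr dim {k} cannot broadcast against src")
    if indptr_shape[d] < 1:
        raise ValueError("indptr needs at least one pointer")
    # y pointers delimit y - 1 segments along dimension d; other dims are kept.
    num_segments = indptr_shape[d] - 1
    return tuple(src_shape[:d]) + (num_segments,) + tuple(src_shape[d + 1:])


print(segment_csr_out_shape((10, 6, 64), (1, 4)))  # (10, 3, 64)
```

Here the reduced dimension is 1 (since indptr has two dimensions), its size 6 is replaced by 4 - 1 = 3 segments, and the trailing dimension of size 64 is carried through unchanged.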
    import torch
    from torch_scatter import segment_csr

    src = torch.randn(10, 6, 64)
    indptr = torch.tensor([0, 2, 5, 6])
    indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

    out = segment_csr(src, indptr, reduce="sum")

    print(out.size())  # torch.Size([10, 3, 64])
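To make the 1-D formula and the four reduce options concrete, here is a plain-Python reference sketch. It is not torch_scatter's implementation: the name segment_csr_reference is hypothetical, and filling empty segments with 0.0 is an illustrative assumption. Segment i covers the half-open range [indptr[i], indptr[i+1]).

```python
from typing import List, Sequence


def segment_csr_reference(src: Sequence[float], indptr: Sequence[int],
                          reduce: str = "sum") -> List[float]:
    """Hypothetical pure-Python reference for 1-D segment_csr semantics.

    Empty segments are filled with 0.0 (an assumption for illustration).
    """
    if any(a > b for a, b in zip(indptr, indptr[1:])):
        raise ValueError("indptr must be non-decreasing")
    if indptr and (indptr[0] < 0 or indptr[-1] > len(src)):
        raise ValueError("indptr values must lie in [0, len(src)]")

    out: List[float] = []
    # Consecutive pointer pairs (start, end) delimit one segment each.
    for start, end in zip(indptr, indptr[1:]):
        segment = src[start:end]
        if not segment:
            out.append(0.0)
        elif reduce == "sum":
            out.append(float(sum(segment)))
        elif reduce == "mean":
            out.append(sum(segment) / len(segment))
        elif reduce == "min":
            out.append(float(min(segment)))
        elif reduce == "max":
            out.append(float(max(segment)))
        else:
            raise ValueError(f"unknown reduce: {reduce!r}")
    return out


print(segment_csr_reference([1, 2, 3, 4, 5, 6], [0, 2, 5, 6]))  # [3.0, 12.0, 6.0]
```

With indptr = [0, 2, 5, 6], the three segments cover src positions {0, 1}, {2, 3, 4} and {5}, which is the same grouping the tensor example above applies along dimension 1.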
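Data often arrives as a sorted per-element group index (the layout that segment_coo() works with) rather than as index pointers. The conversion amounts to counting the elements in each segment and taking prefix sums that start at 0. The plain-Python sketch below uses a hypothetical helper name, index_to_indptr, which is not a torch_scatter function.

```python
from itertools import accumulate
from typing import List, Sequence


def index_to_indptr(index: Sequence[int], num_segments: int) -> List[int]:
    """Hypothetical helper: turn a sorted group index into CSR index pointers.

    Segment k then covers positions [indptr[k], indptr[k + 1]) of src.
    """
    if any(a > b for a, b in zip(index, index[1:])):
        raise ValueError("index must be sorted in non-decreasing order")
    counts = [0] * num_segments
    for group in index:
        if not 0 <= group < num_segments:
            raise ValueError(f"group {group} out of range")
        counts[group] += 1  # number of elements in each segment
    # Prefix sums of the counts, starting at 0, give the segment boundaries.
    return list(accumulate(counts, initial=0))


print(index_to_indptr([0, 0, 1, 1, 1, 2], num_segments=3))  # [0, 2, 5, 6]
```

In PyTorch, the same counts-then-prefix-sum idea could be expressed with torch.bincount(index, minlength=num_segments) followed by torch.cumsum and a prepended zero. This is a suggested equivalent, not a torch_scatter API.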