Source code for torch_scatter.composite.std

from typing import Optional

import torch
from torch_scatter import scatter_sum
from torch_scatter.utils import broadcast


def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,
                unbiased: bool = True) -> torch.Tensor:
    """Return the per-group standard deviation of ``src`` along ``dim``.

    Groups are given by ``index`` and reduced with ``scatter_sum``. For each
    group the function computes the mean, sums the squared deviations from
    that mean, divides by the group size (or size minus one when
    ``unbiased`` is true) plus ``1e-6``, and takes the square root.

    Args:
        src: Values to reduce.
        index: Group assignment for the entries of ``src``; it is broadcast
            to the shape of ``src`` before the value reductions.
        dim: Dimension to reduce over. Negative values are offset by
            ``src.dim()``.
        out: Optional tensor handed to the final ``scatter_sum`` call as its
            output argument. When given, ``out.size(dim)`` replaces
            ``dim_size``.
        dim_size: Size of the reduced dimension, forwarded to
            ``scatter_sum``.
        unbiased: If true, divide by ``count - 1`` (clamped to at least 1)
            instead of ``count``.

    Returns:
        The tensor produced by the final out-of-place ``div(...).sqrt()``.
    """
    # A caller-supplied output tensor dictates the size of the reduced dim.
    if out is not None:
        dim_size = out.size(dim)

    # Turn a negative dim into its non-negative equivalent w.r.t. ``src``.
    if dim < 0:
        dim += src.dim()

    # ``index`` may have fewer dimensions than ``src``; in that case the
    # group sizes are counted along the last dimension of ``index``.
    count_dim = dim if index.dim() > dim else index.dim() - 1

    # Group sizes: scatter-sum a tensor of ones shaped like the raw index.
    unit = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    group_size = scatter_sum(unit, index, count_dim, dim_size=dim_size)

    # Per-group mean. Sizes are clamped to >= 1 so groups that received no
    # entries do not divide by zero.
    full_index = broadcast(index, src, dim)
    group_sum = scatter_sum(src, full_index, dim, dim_size=dim_size)
    group_size = broadcast(group_size, group_sum, dim).clamp(1)
    group_mean = group_sum.div(group_size)

    # Sum of squared deviations of every entry from its own group's mean.
    deviation = src - group_mean.gather(dim, full_index)
    squared_dev_sum = scatter_sum(deviation * deviation, full_index, dim, out,
                                  dim_size)

    # With ``unbiased`` the divisor is (n - 1), clamped so that singleton
    # groups still divide by 1.
    divisor = group_size.sub(1).clamp_(1) if unbiased else group_size
    return squared_dev_sum.div(divisor + 1e-6).sqrt()