# Source code for torch_scatter.std

import torch

from torch_scatter import scatter_add
from torch_scatter.utils.gen import gen


def scatter_std(src, index, dim=-1, out=None, dim_size=None, unbiased=True):
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/master/docs/source/_figures/std.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Computes the standard-deviation from all values from the :attr:`src`
    tensor into :attr:`out` at the indices specified in the :attr:`index`
    tensor along a given axis :attr:`dim` (`cf.`
    :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \sqrt{\frac{\sum_j {\left( x_j - \overline{x}_i
        \right)}^2}{N_i - 1}}

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`. :math:`N_i` and :math:`\overline{x}_i`
    indicate the number of indices referencing :math:`i` and their mean
    value, respectively. With :obj:`unbiased=False` the denominator is
    :math:`N_i` instead of :math:`N_i - 1`.

    The denominator is clamped to be at least :math:`1`. As a result, a
    position referenced by exactly one index yields :math:`0` instead of
    ``NaN``.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor.
            (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension
            :attr:`dim`. If :attr:`dim_size` is not given, a minimal sized
            output tensor is returned. (default: :obj:`None`)
        unbiased (bool, optional): If set to :obj:`False`, then the
            standard-deviation will be calculated via the biased estimator.
            (default: :obj:`True`)

    :rtype: :class:`Tensor`
    """
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value=0)

    # Per-position sum of the source values. It is accumulated into a zeroed
    # copy of ``out`` so that any existing contents of ``out`` do not
    # contribute to the mean.
    tmp = None if out is None else out.clone().fill_(0)
    tmp = scatter_add(src, index, dim, tmp, dim_size)

    # Per-position element count N_i, accumulated the same way.
    count = None if out is None else out.clone().fill_(0)
    count = scatter_add(torch.ones_like(src), index, dim, count, dim_size)

    # clamp(min=1) avoids 0 / 0 at positions that no index refers to.
    mean = tmp / count.clamp(min=1)

    # Squared deviation of every source element from the mean of its group.
    var = src - mean.gather(dim, index)
    var = var * var

    # NOTE(review): this passes ``out`` itself (not a copy) to scatter_add.
    # If scatter_add accumulates in place, a caller-supplied ``out`` ends up
    # holding the summed squared deviations rather than the std. The std is
    # returned as a new tensor below.
    out = scatter_add(var, index, dim, out, dim_size)

    # Divide by N_i - 1 (unbiased) or N_i (biased). Clamping keeps
    # single-element and empty groups from dividing by zero or by a
    # negative number.
    out = out / (count - 1 if unbiased else count).clamp(min=1)
    out = torch.sqrt(out)

    return out