@spezold · Created May 17, 2024
Calculate stable mean and standard deviation over a potentially large sample of values, using Chan et al.'s version of Welford's online algorithm (same as "mean_std.py", but in PyTorch)
"""
Calculate sample mean and std. over all desired axes of given samples, using Chan et al.'s version of Welford's online
algorithm; cf. https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm (20240515) and
equations (2.1a), (2.1b) in the referenced paper [1]_; or likewise, equations (1.5a), (1.5b) in [2]_.
Mind the typo in [1]_, (2.1a): T_{1, m+n} on the right side of the equation should be T_{m+1, m+n} (cf. [2]_, (1.5a)).
References
----------
.. [1] T. F. Chan, G. H. Golub, and R. J. LeVeque, “Updating Formulae and a Pairwise Algorithm for Computing Sample
Variances,” in COMPSTAT 1982 5th Symposium held at Toulouse 1982, Heidelberg, 1982, pp. 30–41,
doi: 10.1007/978-3-642-51461-6_3.
.. [2] T. F. Chan, G. H. Golub, and R. J. LeVeque, “Algorithms for Computing the Sample Variance: Analysis and
Recommendations,” The American Statistician, vol. 37, no. 3, pp. 242–247, 1983,
doi: 10.1080/00031305.1983.10483115.
"""
from collections.abc import Iterable

import torch

try:
    from tqdm import tqdm  # Optional: progress reporting
except ImportError:
    tqdm = None

float_ = torch.float  # Accumulation dtype for ``S``, ``T`` (choose e.g. ``torch.float64`` for higher precision)
torch_device = "cpu"  # Device on which the demo tensors below are created


def n_s_t_new(
    values: torch.Tensor,
    n_s_t_old: tuple[int, torch.Tensor, torch.Tensor] | None = None,
    *,
    axis: int | str | None = None,
) -> tuple[int, torch.Tensor, torch.Tensor]:
"""
Update sample number ``n`` and values ``S``, ``T``, based on chosen ``axis`` and shape plus contents of given
``values``. The arrays in ``n_s_t_old``, if given, will be updated *in-place* to provide the result.
:param values: array that contains values to be used in update
:param n_s_t_old: current values for ``n``, ``S``, ``T`` (if given, assert ``S``, ``T`` are ``float64``)
:param axis: axis over which to reduce; if None (default), reduce over all ``values`` → resulting ``S``, ``T`` will
be scalar (0-dimensional) arrays; if int, reduce over corresponding axis in ``values`` → resulting ``S``, ``T``
will lose this axis; if "pointwise", reduce over (virtual) dimension of consecutive function calls → resulting
``S``, ``T`` will have the same shape as ``values``; in any case, resulting ``n`` will be an int scalar
:return: updated sample number ``n`` and values ``S``, ``T`` (see ``axis`` for shape, ``n_s_t_old`` for dtype)
"""
    # Check shapes of n, S, T
    if n_s_t_old is not None:
        n, s, t = n_s_t_old
        assert isinstance(n, int), f"Sample number `n` should be an integer (is {type(n)})."
        for v, name in zip((s, t), "st"):
            assert isinstance(v, torch.Tensor), f"Value `{name}` should be a Torch tensor (is {type(v)})."
            assert v.dtype == float_, f"Data type of value `{name}` should be {float_} (is {v.dtype})."
        if axis is None:
            for v, name in zip((s, t), "st"):
                assert v.ndim == 0, f"For `axis=None`, {name} should be scalar (0-d) (has {v.ndim} dimensions)."
        elif axis == "pointwise":
            for v, name in zip((s, t), "st"):
                assert v.shape == values.shape, (f"For `axis='pointwise'` and `values` of shape {values.shape}, "
                                                 f"shape of {name} should be {values.shape} (is {v.shape}).")
        elif isinstance(axis, int):
            shape_should = list(values.shape)
            shape_should.pop(axis)
            shape_should = tuple(shape_should)
            for v, name in zip((s, t), "st"):
                assert v.shape == shape_should, (f"For `axis={axis}` and `values` of shape {values.shape}, shape of "
                                                 f"{name} should be {shape_should} (is {v.shape}).")
        else:
            raise ValueError(f"Cannot handle `axis={axis}` of type `{type(axis)}` (provide int, None, or 'pointwise').")

    def s_for(v_, t_, n_):  # Per-element squared deviations (v - T/n)**2; not yet summed
        t_ = torch.clone(t_).to(float_)
        s_ = torch.clone(v_).to(float_)
        return torch.square(torch.subtract(v_, torch.divide(t_, n_, out=t_), out=s_), out=s_)

    def t_diff_for(t_old_, t_spl_, n_old_, n_spl_, n_new_):  # Correction term of eq. (1.5b) in [2]_
        d_ = torch.clone(t_old_).to(float_)
        d_ = torch.square(torch.subtract(torch.multiply(n_spl_ / n_old_, d_, out=d_), t_spl_, out=d_), out=d_)
        return torch.multiply(n_old_ / (n_spl_ * n_new_), d_, out=d_)

    # Calculate n, S, T for sample
    if axis == "pointwise":
        n_spl = 1
        t_spl = torch.clone(values).to(float_)
        s_spl = torch.zeros_like(t_spl)
    elif axis is None:
        n_spl = int(values.numel())
        t_spl = torch.sum(values, dtype=float_)
        s_spl = torch.sum(s_for(values, t_spl, n_spl))
    else:
        n_spl = int(values.shape[axis])
        t_spl = torch.sum(values, dim=axis, dtype=float_)
        s_spl = torch.sum(s_for(values, torch.unsqueeze(t_spl, dim=axis), n_spl), dim=axis)

    # Combine old and new n, S, T via the pairwise update
    if n_s_t_old is None or n_s_t_old[0] == 0:
        # If no old values, only use the new ones (which may also be empty, causing ``n_old == 0`` on next input)
        n_new, s_new, t_new = n_spl, s_spl, t_spl
    else:
        n_old, s_old, t_old = n_s_t_old
        n_new = n_old + n_spl
        t_diff = t_diff_for(t_old, t_spl, n_old, n_spl, n_new)
        t_new = torch.add(t_old, t_spl, out=t_old)
        s_new = torch.add(torch.add(s_old, s_spl, out=s_old), t_diff, out=s_old)
    return n_new, s_new, t_new


def _mean_std(n: int, s: torch.Tensor, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate sample mean and standard deviation from sample number ``n`` and values ``S``, ``T``, as produced by
    :func:`n_s_t_new`. The given arrays ``s``, ``t`` will be updated *in-place* to provide the result.

    :param n: sample number in the axis over which values were reduced
    :param s: intermediate result ``S`` of [1]_, [2]_
    :param t: intermediate result ``T`` of [1]_, [2]_
    :return: resulting mean and standard deviation (population std., i.e. normalized by ``n``)
    """
    mean = torch.divide(t, n, out=t)
    std = torch.sqrt(torch.divide(s, n, out=s), out=s)
    return mean, std
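

# A minimal usage sketch (illustrative only; ``chunks`` is a hypothetical iterable of tensors): accumulate ``n``,
# ``S``, ``T`` chunk by chunk, then derive mean and std once at the end. This is precisely what :func:`mean_std`
# below wraps, adding optional progress reporting via ``tqdm``:
#
#     n_s_t = None
#     for chunk in chunks:
#         n_s_t = n_s_t_new(chunk, n_s_t, axis=None)
#     mean, std = _mean_std(*n_s_t)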


def mean_std(
    values: Iterable[torch.Tensor],
    *,
    axis: int | str | None = None,
    as_scalars: bool = True,
    show_progress: bool = True,
    **progress_args,
) -> tuple[torch.Tensor | float, torch.Tensor | float]:
"""
Calculate sample mean and standard deviation from given values.
The ``axis`` parameter enables control over how values should be reduced:
* if None (default), reduce over all values → results will be scalar (0-dimensional) arrays;
* if int, reduce over corresponding axis in each item of the iterable → results will lose this axis as compared to
the items of the iterable;
* if "pointwise", reduce over item dimension of the iterable → results will have the same shape as the items of the
iterable.
Example:
Assume ``values`` is an ``n``-element list with each item (list entry) being a 2×3×5-shaped array.
* ``mean_std(values, axis=None)``: mean and std. are scalars (reduced over all given values)
* ``mean_std(values, axis=1)``: mean and std. are 2×5-shaped arrays (reduced over all values in axis 1 in all items
of the given list)
* ``mean_std(values, axis="pointwise")``: mean and std. are 2×3×5-shaped arrays (for each item index, reduced over
all items of the given list)
:param values: values from which to calculate the mean and standard deviation; potentially a "huge" iterable of
"regular-sized" arrays
:param axis: axis over which to reduce in each item of the given iterable ``values`` (see above)
:param as_scalars: if True (default), provide 0-dimensional results as native Python ``float`` values; if False,
keep the results as scalar (0-dimensional) Torch tensors; has no effect on ``d``-dimensional results with ``d>0``
:param show_progress: if True (default) show progress (a progress bar if the iterable provides length information,
runtime statistics otherwise); uses ``tqdm``, has no effect if ``tqdm`` is not installed
:param progres_args: further arguments to pass to ``tqdm`` for showing progress
:return: resulting mean and standard deviation (Torch tensors of ``float`` values or scalars, cf. ``as_scalars``)
"""
    n_s_t = None
    for values_item in (values if not show_progress or tqdm is None else tqdm(values, **progress_args)):
        n_s_t = n_s_t_new(values_item, n_s_t, axis=axis)
    mean, std = _mean_std(*n_s_t)
    return (mean.item(), std.item()) if (as_scalars and mean.ndim == 0) else (mean, std)


if __name__ == "__main__":
    import numpy as np  # Benchmark against Numpy, as torch seems to be too imprecise

    torch.random.manual_seed(42)
    # Create 7 items (arrays) of shape 2×3×5 each
    a = torch.normal(mean=0., std=1., size=(7, 2, 3, 5)).to(torch_device)
    a_backup = torch.clone(a)

    # Test 1: reduce over all values
    assert np.allclose((np.mean(a.cpu().numpy()), np.std(a.cpu().numpy())),
                       mean_std(item for item in a))  # Compare to pure Numpy
    assert torch.all(a_backup == a)  # Inputs should remain unaltered
    assert all(isinstance(v, float) for v in mean_std(a))  # Both results should be native floats by default
    assert all(isinstance(v, torch.Tensor) for v in mean_std(a, as_scalars=False))  # … or Torch tensors on request

    # Test 2: reduce over axis 1 in each item, in all items → 2×5 result
    a_np = np.transpose(a.cpu().numpy(), (0, 2, 1, 3)).reshape(-1, a[0].shape[0], a[0].shape[2])
    assert np.allclose((np.mean(a_np, axis=0), np.std(a_np, axis=0)),
                       [ms.cpu().numpy() for ms in mean_std((item for item in a), axis=1)])  # Compare to pure Numpy
    assert torch.all(a_backup == a)  # Inputs should remain unaltered

    # Test 3: for each index, reduce over all items → 2×3×5 result
    assert np.allclose((np.mean(a.cpu().numpy(), axis=0), np.std(a.cpu().numpy(), axis=0)),
                       [ms.cpu().numpy() for ms in mean_std((item for item in a), axis="pointwise")])  # Compare to pure Numpy
    assert torch.all(a_backup == a)  # Inputs should remain unaltered
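
    # Additional sanity check (a minimal sketch, not part of the original tests): feeding the same data in two
    # differently sized chunks must match a single pass over all values, since the pairwise update does not depend
    # on how the samples are split (up to floating-point error).
    chunked = mean_std((a[:3].reshape(-1), a[3:].reshape(-1)), show_progress=False)
    single = mean_std([a.reshape(-1)], show_progress=False)
    assert np.allclose(chunked, single)
    assert torch.all(a_backup == a)  # Inputs should remain unaltered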