@thomasbrandon
thomasbrandon / running_stats.py
Created September 26, 2019 06:37
Collect running statistics (mean/std) efficiently in PyTorch
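As a rough illustration of what collecting mean/std "efficiently" can look like, the sketch below keeps only a `(count, mean, M2)` summary per batch and merges summaries pairwise, so the data is read once and never held in full. This is the pairwise update commonly attributed to Chan, Golub, and LeVeque, written on plain Python floats for clarity; it is not the gist's implementation. The names `batch_stats`, `merge_stats`, and `running_mean_std` are hypothetical, and every batch is assumed to be non-empty.

```python
from typing import Iterable, List, Optional, Tuple

# A summary is (count, mean, M2), where M2 is the sum of squared
# deviations from the mean. These names are illustrative only.
Stats = Tuple[int, float, float]


def batch_stats(batch: List[float]) -> Stats:
    """Summarise one non-empty batch as (count, mean, M2)."""
    n = len(batch)
    mean = sum(batch) / n
    m2 = sum((x - mean) ** 2 for x in batch)
    return n, mean, m2


def merge_stats(a: Stats, b: Stats) -> Stats:
    """Combine two summaries without revisiting the underlying data."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    # The cross term accounts for the gap between the two batch means.
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2


def running_mean_std(batches: Iterable[List[float]]) -> Tuple[float, float]:
    """Return (mean, sample std) over all batches in a single pass."""
    total: Optional[Stats] = None
    for batch in batches:
        stats = batch_stats(batch)
        total = stats if total is None else merge_stats(total, stats)
    if total is None or total[0] < 2:
        raise ValueError("need at least two values for a sample std")
    n, mean, m2 = total
    return mean, (m2 / (n - 1)) ** 0.5


mean, std = running_mean_std([[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]])
```

The same arithmetic applies elementwise, so the floats could be swapped for tensors that hold one mean and one `M2` per retained position.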
import torch
from torch import Tensor
from typing import Iterable
from fastprogress import progress_bar
class RunningStatistics:
    '''Records the mean and variance over the final `n_dims` dimensions of each item, accumulated across items. So collecting across
    `(l,m,n,o)` sized items with `n_dims=1` will collect `(l,m,n)` sized statistics, while with `n_dims=2` the collected statistics will be of size `(l,m)`.
    Uses the algorithm from Chan, Golub, and LeVeque in "Algorithms for computing the sample variance: analysis and recommendations":