Skip to content

Instantly share code, notes, and snippets.

@thomasbrandon
Created September 26, 2019 06:37
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thomasbrandon/ad5b1218fc573c10ea4e1f0c63658469 to your computer and use it in GitHub Desktop.
Save thomasbrandon/ad5b1218fc573c10ea4e1f0c63658469 to your computer and use it in GitHub Desktop.
Collect running statistics (mean/std) efficiently in PyTorch
import torch
from torch import Tensor
from typing import Iterable
from fastprogress import progress_bar
class RunningStatistics:
    '''Records mean and variance of the final `n_dims` dimensions over all other dimensions across items. So collecting across
    `(l,m,n,o)` sized items with `n_dims=1` will collect `(l,m,n)` sized statistics while with `n_dims=2` the collected
    statistics will be of size `(l,m)`.

    Uses the pairwise update from Chan, Golub, and LeVeque in "Algorithms for computing the sample variance: analysis and
    recommendations", which combines the sums of squared deviations `S1` and `S2` of two blocks:
        `S = S1 + S2 + m/(n*(m+n)) * pow((n/m)*t1 - t2, 2)`
    where block 1 (the statistics accumulated so far) has `m` elements summing to `t1` and block 2 (the new item) has `n`
    elements summing to `t2`. The reported `var`/`std` are population statistics, i.e. normalised by the total element count.

    Note that collecting minimum and maximum values is reasonably inefficient, adding about 80% to the running time, and hence
    is disabled by default.
    '''
    def __init__(self, n_dims:int=2, record_range=False):
        self._n_dims,self._range = n_dims,record_range
        self.n,self.sum,self.min,self.max = 0,None,None,None
        # Sum of squared deviations from the running mean, and the expected leading (non-reduced) shape of every item.
        self._nvar,self._shape = None,None

    def update(self, data:Tensor):
        '''Fold `data` into the running statistics.

        The final `n_dims` dimensions of `data` are flattened and reduced over; the remaining leading dimensions must match
        those of the first item seen. Non-floating-point input is converted to float before reducing.

        Raises:
            ValueError: if the leading dimensions of `data` differ from those of previously collected items.
        '''
        with torch.no_grad():
            # `reshape` (unlike `view`) also handles non-contiguous tensors such as transposed/permuted inputs.
            data = data.reshape(*data.shape[:-self._n_dims], -1)
            if not data.is_floating_point(): data = data.float()
            new_n = data.shape[-1]
            if new_n == 0: return  # nothing to add; also avoids a division by zero in the combination step
            new_sum = data.sum(-1)
            # Sum of squared deviations of this item: population variance * count. (Unbiased variance * count would
            # overestimate it by a factor of new_n/(new_n-1) and is NaN for single-element items.)
            new_nvar = data.var(-1, unbiased=False).mul_(new_n)
            if self.n == 0:
                self.n,self._shape = new_n,data.shape[:-1]
                self.sum,self._nvar = new_sum,new_nvar
                if self._range:
                    self.min,self.max = data.min(-1)[0],data.max(-1)[0]
                return
            if data.shape[:-1] != self._shape:
                raise ValueError(f"Mismatched shapes, expected {self._shape} but got {data.shape[:-1]}.")
            # Chan et al. combination with m = self.n and n = new_n: ratio = m/n, t = ((n/m)*t1 - t2)^2,
            # coefficient = m/(n*(m+n)) = ratio/(m+n).
            ratio = self.n / new_n
            t = (self.sum / ratio).sub_(new_sum).pow_(2)
            self._nvar.add_(new_nvar).add_(t, alpha=ratio / (self.n + new_n))
            self.sum.add_(new_sum)
            self.n += new_n
            if self._range:
                self.min = torch.min(self.min, data.min(-1)[0])
                self.max = torch.max(self.max, data.max(-1)[0])

    @property
    def mean(self): return self.sum / self.n if self.n > 0 else None
    @property
    def var(self): return self._nvar / self.n if self.n > 0 else None
    @property
    def std(self): return self.var.sqrt() if self.n > 0 else None

    def __repr__(self):
        def _fmt_nested(t:Tensor):
            # Recursively render a small tensor as nested bracketed lists with 3 significant digits.
            if t.ndim == 0: return f"{t.item():.3g}"
            return '[' + ','.join(_fmt_nested(v) for v in t) + ']'
        def _fmt_t(t):
            # Statistics are None before any data has been collected.
            if t is None: return "None"
            if t.numel() > 5: return f"tensor of ({','.join(map(str,t.shape))})"
            return _fmt_nested(t)
        rng_str = f", min={_fmt_t(self.min)}, max={_fmt_t(self.max)}" if self._range else ""
        return f"RunningStatistics(n={self.n}, mean={_fmt_t(self.mean)}, std={_fmt_t(self.std)}{rng_str})"
def collect_stats(items:Iterable, n_dims:int=2, record_range:bool=False):
    '''Accumulate a `RunningStatistics` over every element of `items`, displaying a progress bar while iterating.

    An element exposing a `data` attribute contributes that attribute; any other element is passed to `update` as-is.
    Returns the populated `RunningStatistics`.
    '''
    result = RunningStatistics(n_dims, record_range)
    for element in progress_bar(items):
        result.update(getattr(element, 'data', element))
    return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment