Skip to content

Instantly share code, notes, and snippets.

@xmodar
Last active February 10, 2023 17:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xmodar/dc95444faf3624ed979b4d0b2088fdf1 to your computer and use it in GitHub Desktop.
Save xmodar/dc95444faf3624ed979b4d0b2088fdf1 to your computer and use it in GitHub Desktop.
Seamless running stats for (native python, numpy.ndarray, torch.tensor). Also see: https://gist.github.com/davidbau/00a9b6763a260be8274f6ba22df9a145
"""Seamless running stats for (native python, numpy.ndarray, torch.tensor)."""
from collections import namedtuple
class MeanMeter:
    """Running (optionally weighted) mean estimator for a stream of values.

    Works with native numbers, numpy arrays and torch tensors alike, since
    it only relies on `+` and `*`.
    """

    def __init__(self):
        """Start from an empty stream."""
        self.count = 0
        self.mass = 0  # running sum of `value * count` contributions

    def reset(self):
        """Discard everything accumulated so far."""
        self.count = 0
        self.mass = 0

    @property
    def mean(self):
        """The current mean (NaN before any update has been seen)."""
        # Multiplying by NaN (rather than returning it directly) keeps the
        # result's type/shape consistent with the accumulated `mass`.
        scale = float('nan') if self.count == 0 else 1 / self.count
        return self.mass * scale

    def update(self, value, count=1):
        """Fold `value` into the stream with multiplicity `count` (> 0 int)."""
        assert isinstance(count, int) and count > 0
        self.count += count
        self.mass += value * count
class VarianceMeter:
    """Estimate the variance/covariance for a stream of values.

    Implements Welford's online algorithm, extended to merge whole-batch
    statistics (Chan et al.'s parallel update). Works with native numbers,
    numpy arrays and torch tensors.
    """

    # Record of per-batch statistics: accepted by `update` and produced by
    # `get_stats`.
    Stats = namedtuple('VarianceMeterStats',
                       ('mean', 'variance', 'count', 'unbiased'))

    def __init__(self, unbiased=True, full=False):
        """Initialize the meter (covariance mode if `full`).

        Args:
            unbiased: Whether reported (co)variance uses Bessel's correction.
            full: Track the full covariance matrix; inputs must then be 1-D
                with more than one element.
        """
        self._full = bool(full)
        self.unbiased = bool(unbiased)
        # `mean` is the running mean; `mass` is the running sum of squared
        # deviations (the "M2" accumulator in Welford's algorithm).
        self.count = self.mean = self.mass = 0

    def reset(self):
        """Reset the meter."""
        self.__init__(self.unbiased, self.full)

    def factor(self, count=None, unbiased=None):
        """Get inverted count or Bessel's correction factor.

        Returns `1 / count` (or `1 / (count - 1)` when `unbiased`); NaN when
        the denominator is zero, so empty meters yield NaN statistics.
        """
        if count is None:
            count = self.count
        if unbiased is None:
            unbiased = self.unbiased
        count = count - int(bool(unbiased))
        return float('nan') if count == 0 else 1 / count

    @property
    def full(self):
        """Get whether this meter tracks covariances."""
        return self._full

    @property
    def variance(self):
        """Get the variance (diagonal of the covariance in `full` mode)."""
        mass = (self.mass.diagonal() if self.full else self.mass)
        return mass * self.factor()

    @property
    def covariance(self):
        """Get the covariance if `self.full` else `None`."""
        return self.mass * self.factor() if self.full else None

    def update(self, value):
        """Update the meter with a single value or a `Stats` record."""
        if isinstance(value, self.Stats):
            kwargs = value._asdict()
        else:
            # A lone value is a batch of size one with no variance of its own.
            kwargs = dict(mean=value, variance=None, count=1, unbiased=None)
        return self.stats_update(**kwargs)

    def batch_update(self, batch, inplace=False):
        """Update the meter with a data batch (first dimension = samples)."""
        # Biased batch variance is requested so that `variance * count` is
        # exactly the batch's sum of squared deviations in `stats_update`.
        stats = self.get_stats(batch, inplace, unbiased=False)
        return self.update(stats)

    def stats_update(self, mean, variance, count, unbiased):
        """Update the meter with batch statistics.
        Sources:
        www.johndcook.com/blog/standard_deviation/
        stats.stackexchange.com/a/389925
        notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html
        stats.stackexchange.com/a/10445
        https://stackoverflow.com/a/38324464 # for covariance
        Args:
            mean: Batch mean (must be a single dimension > 1 if `self.full`).
            variance: Batch variance (or covariance if `self.full`).
                May be `None` when `count == 1` and `unbiased` is falsy.
            count: Batch size (must be `int`).
            unbiased: Whether `variance` is unbiased.
        """
        assert isinstance(count, int) and count > 0
        assert not self.full or (mean.ndim == 1 and mean.shape[0] > 1)
        # NOTE: statement order below is essential. `diff` is first taken
        # against the OLD running mean, the mean is then advanced, and the
        # product is formed with the NEW mean: (m - old)*(m - new) equals
        # diff^2 * n_old / (n_old + count), the Chan merge term.
        diff = mean - self.mean
        self.mean += diff * (count / (self.count + count))
        if self.full:
            # Outer product of old/new deviations -> covariance merge term.
            diff = diff[:, None] @ (mean - self.mean)[None, :]
        else:
            diff *= mean - self.mean
        if count > 1:
            # Add the batch's own (biased) per-sample (co)variance.
            diff += variance
        # Scale to a sum of squared deviations (M2 units).
        diff *= count
        if unbiased:
            # variance * (count - 1) is the batch M2 when `variance` was
            # computed with Bessel's correction; undo the extra `variance`.
            diff -= variance
        self.mass += diff
        self.count += count

    def get_stats(self, batch, inplace=False, unbiased=None):
        """Estimate the mean and variance over the first dimension.
        Args:
            batch: Data batch.
            inplace: Whether to subtract the mean from `batch` if `self.full`.
            unbiased: Whether to compute unbiased variance.
        Returns:
            namedtuple(mean, variance, count=batch.shape[0], `unbiased`).
        """
        unbiased = bool(self.unbiased if unbiased is None else unbiased)
        count = batch.shape[0]
        mean = batch.mean(0)
        if self.full:
            assert batch.ndim == 2 and batch.shape[1] > 1
            if inplace:
                batch -= mean
            else:
                batch = batch - mean
            # Centered Gram matrix scaled by 1/n (or 1/(n-1) if unbiased).
            variance = (batch.T @ batch) * self.factor(count, unbiased)
        else:
            if hasattr(batch, 'is_cuda'):  # pytorch
                variance = batch.var(0, unbiased=unbiased)
            else:  # numpy (we can use try-except TypeError instead)
                variance = batch.var(0, ddof=int(unbiased))
        return self.Stats(mean, variance, count, unbiased)
if __name__ == '__main__':
    from itertools import product
    from argparse import Namespace

    import torch

    # Cartesian grid of test configurations.
    test_args = {
        'meter': ['mean', 'variance', 'covariance'],
        'unbiased': [True, False],
        'count': [1, 10000],
        'dim': [3],
        'batched': [True, False],
    }

    def fail(actual, expected):
        """Return True when the two tensors disagree (NaNs compare equal)."""
        return not torch.allclose(actual, expected, equal_nan=True)

    for test_case in product(*test_args.values()):
        arg = Namespace(**dict(zip(test_args.keys(), test_case)))
        if arg.meter == 'mean':
            # BUG FIX: was `arg.meter is MeanMeter`, comparing a string to a
            # class with `is` — always False, so MeanMeter was never tested.
            if not arg.unbiased:
                continue  # MeanMeter has no bias notion; run it only once
            m = MeanMeter()
        else:
            full = arg.meter == 'covariance'
            m = VarianceMeter(unbiased=arg.unbiased, full=full)
        x = torch.randn(arg.count, arg.dim, dtype=torch.float64)
        if arg.batched:
            while m.count < x.shape[0]:
                batch_size = torch.randint(1, 1000, []).item()
                batch = x[m.count:m.count + batch_size]
                if isinstance(m, MeanMeter):
                    # MeanMeter has no batch_update; fold in the batch mean
                    # weighted by the batch size instead.
                    m.update(batch.mean(0), batch.shape[0])
                else:
                    m.batch_update(batch)
        else:
            for y in x:
                m.update(y)
        passed = True
        if fail(x.mean(0), m.mean):
            print('MeanTest', arg)
            passed = False
            continue
        if arg.meter.endswith('variance'):  # 'variance' and 'covariance'
            if fail(x.var(0, m.unbiased), m.variance):
                print('VarianceTest', arg)
                passed = False
                continue
        if arg.meter == 'covariance':
            if fail(m.get_stats(x).variance, m.covariance):
                print('CovarianceTest', arg)
                passed = False
                continue
        if passed:
            print('Passed!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment