Last active
February 10, 2023 17:49
-
-
Save xmodar/dc95444faf3624ed979b4d0b2088fdf1 to your computer and use it in GitHub Desktop.
Seamless running stats for (native python, numpy.ndarray, torch.tensor). Also see: https://gist.github.com/davidbau/00a9b6763a260be8274f6ba22df9a145
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Seamless running stats for (native python, numpy.ndarray, torch.tensor).""" | |
from collections import namedtuple | |
class MeanMeter:
    """Running (optionally weighted) mean of a stream of values.

    Works with anything that supports `*` and `+=` with numbers, e.g.
    native Python numbers, `numpy.ndarray` or `torch.Tensor`.

    Attributes:
        count: Total weight (number of samples) accumulated so far.
        mass: Weighted sum of all the values accumulated so far.
    """

    def __init__(self):
        """Start with an empty meter (no samples, zero mass)."""
        self.mass = 0
        self.count = 0

    def reset(self):
        """Forget everything that was accumulated so far."""
        self.__init__()

    @property
    def mean(self):
        """Return `mass / count`, or NaN while the meter is empty."""
        if self.count == 0:
            scale = float('nan')
        else:
            scale = 1 / self.count
        # Multiply by the reciprocal (rather than divide) so the result
        # type follows whatever type `mass` has.
        return self.mass * scale

    def update(self, value, count=1):
        """Accumulate `value` as if it had been observed `count` times.

        Args:
            value: New value (or the mean of a batch of values).
            count: Number of samples `value` stands for (positive `int`).

        Raises:
            AssertionError: If `count` is not a positive `int`.
        """
        assert isinstance(count, int)
        assert count > 0
        weighted = value * count
        self.mass += weighted
        self.count += count
class VarianceMeter:
    """Estimate the variance/covariance for a stream of values.

    The meter keeps three running quantities:

    * `count`: number of samples seen so far,
    * `mean`: running mean of the samples,
    * `mass`: accumulated sum of squared deviations from the mean (the
      matrix of co-deviations when the meter is `full`).

    Variance and covariance are derived from `mass` on demand by
    multiplying it with `factor()`.
    """

    # Batch statistics container accepted by `update` and returned by
    # `get_stats`.
    Stats = namedtuple('VarianceMeterStats',
                       ('mean', 'variance', 'count', 'unbiased'))

    def __init__(self, unbiased=True, full=False):
        """Initialize the meter (covariance mode if `full`).

        Args:
            unbiased: Whether to apply Bessel's correction by default.
            full: Whether to track the full covariance matrix.
        """
        self._full = bool(full)
        self.unbiased = bool(unbiased)
        self.count = self.mean = self.mass = 0

    def reset(self):
        """Reset the meter, keeping its `unbiased` and `full` settings."""
        self.__init__(self.unbiased, self.full)

    def factor(self, count=None, unbiased=None):
        """Get inverted count or Bessel's correction factor.

        Args:
            count: Sample count (defaults to `self.count`).
            unbiased: Whether to use `count - 1` (defaults to
                `self.unbiased`).

        Returns:
            `1 / count` (or `1 / (count - 1)` if `unbiased`), or NaN when
            that denominator is not positive.
        """
        if count is None:
            count = self.count
        if unbiased is None:
            unbiased = self.unbiased
        count = count - int(bool(unbiased))
        # A non-positive denominator (e.g. an empty unbiased meter, where
        # it is -1) has no meaningful estimate, so report NaN.
        return 1 / count if count > 0 else float('nan')

    @property
    def full(self):
        """Get whether this meter tracks covariances."""
        return self._full

    @property
    def variance(self):
        """Get the variance (NaN when there are too few samples)."""
        mass = self.mass
        # Before the first update `mass` is still the plain number 0,
        # which has no `diagonal()`.
        if self.full and self.count > 0:
            mass = mass.diagonal()
        return mass * self.factor()

    @property
    def covariance(self):
        """Get the covariance if `self.full` else `None`."""
        return self.mass * self.factor() if self.full else None

    def update(self, value):
        """Update the meter.

        Args:
            value: Either a single sample or a `Stats` tuple describing a
                whole batch.
        """
        if isinstance(value, self.Stats):
            kwargs = value._asdict()
        else:
            kwargs = dict(mean=value, variance=None, count=1, unbiased=None)
        return self.stats_update(**kwargs)

    def batch_update(self, batch, inplace=False):
        """Update the meter with a data batch.

        Args:
            batch: Data batch (samples along the first dimension).
            inplace: Forwarded to `get_stats`.
        """
        stats = self.get_stats(batch, inplace, unbiased=False)
        return self.update(stats)

    def stats_update(self, mean, variance, count, unbiased):
        """Update the meter with batch statistics.

        Sources:
            www.johndcook.com/blog/standard_deviation/
            stats.stackexchange.com/a/389925
            notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html
            stats.stackexchange.com/a/10445
            https://stackoverflow.com/a/38324464  # for covariance

        Args:
            mean: Batch mean (must be a single dimension > 1 if `self.full`).
            variance: Batch variance (or covariance if `self.full`).
                Ignored when `count == 1`.
            count: Batch size (must be `int`).
            unbiased: Whether `variance` is unbiased.

        Raises:
            AssertionError: If `count` is not a positive `int`, or if the
                meter is `full` and `mean` is not 1-D with length > 1.
        """
        assert isinstance(count, int) and count > 0
        assert not self.full or (mean.ndim == 1 and mean.shape[0] > 1)
        diff = mean - self.mean
        self.mean += diff * (count / (self.count + count))
        # `diff` is built with out-of-place arithmetic so the inputs are
        # never mutated and an integer-typed `mean` does not hit in-place
        # casting errors.
        if self.full:
            diff = diff[:, None] @ (mean - self.mean)[None, :]
        else:
            diff = diff * (mean - self.mean)
        # The batch's own sum of squared deviations is
        # `count * variance` (biased) or `(count - 1) * variance`
        # (unbiased). A single sample contributes none, and its
        # `variance` may be `None` or NaN, so it is skipped entirely.
        if count > 1:
            diff = diff + variance
        diff = diff * count
        if unbiased and count > 1:
            diff = diff - variance
        self.mass += diff
        self.count += count

    def get_stats(self, batch, inplace=False, unbiased=None):
        """Estimate the mean and variance over the first dimension.

        Args:
            batch: Data batch.
            inplace: Whether to subtract the mean from `batch` if `self.full`.
            unbiased: Whether to compute unbiased variance (defaults to
                `self.unbiased`).

        Returns:
            namedtuple(mean, variance, count=batch.shape[0], `unbiased`).

        Raises:
            AssertionError: If the meter is `full` and `batch` is not 2-D
                with more than one column.
        """
        unbiased = bool(self.unbiased if unbiased is None else unbiased)
        count = batch.shape[0]
        mean = batch.mean(0)
        if self.full:
            assert batch.ndim == 2 and batch.shape[1] > 1
            if inplace:
                batch -= mean
            else:
                batch = batch - mean
            variance = (batch.T @ batch) * self.factor(count, unbiased)
        else:
            if hasattr(batch, 'is_cuda'):  # pytorch
                variance = batch.var(0, unbiased=unbiased)
            else:  # numpy (we can use try-except TypeError instead)
                variance = batch.var(0, ddof=int(unbiased))
        return self.Stats(mean, variance, count, unbiased)
if __name__ == '__main__':
    # Self-test: compare the streaming meters against torch's own
    # statistics for every combination of the options below.
    from itertools import product
    from argparse import Namespace

    import torch

    test_args = {
        'meter': ['mean', 'variance', 'covariance'],
        'unbiased': [True, False],
        'count': [1, 10000],
        'dim': [3],
        'batched': [True, False],
    }

    def fail(x, y):
        """Return whether two tensors differ (NaNs compare equal)."""
        return not torch.allclose(x, y, equal_nan=True)

    # Tracks the outcome of the whole run, so it is set once up front
    # rather than per test case.
    passed = True
    for test_case in product(*test_args.values()):
        arg = Namespace(**dict(zip(test_args.keys(), test_case)))
        # `arg.meter` is one of the strings above, so compare by value.
        is_mean = arg.meter == 'mean'
        if is_mean:
            # `unbiased` is not an option of MeanMeter; run it only once.
            if not arg.unbiased:
                continue
            m = MeanMeter()
        else:
            full = arg.meter == 'covariance'
            m = VarianceMeter(unbiased=arg.unbiased, full=full)
        x = torch.randn(arg.count, arg.dim, dtype=torch.float64)
        if arg.batched:
            while m.count < x.shape[0]:
                batch_size = torch.randint(1, 1000, []).item()
                batch = x[m.count:m.count + batch_size]
                if is_mean:
                    # MeanMeter has no `batch_update`; feed it the batch
                    # mean weighted by the actual batch size.
                    m.update(batch.mean(0), batch.shape[0])
                else:
                    m.batch_update(batch)
        else:
            for y in x:
                m.update(y)
        if fail(x.mean(0), m.mean):
            print('MeanTest', arg)
            passed = False
            continue
        if arg.meter.endswith('variance'):
            if fail(x.var(0, unbiased=m.unbiased), m.variance):
                print('VarianceTest', arg)
                passed = False
                continue
        if arg.meter == 'covariance':
            if fail(m.get_stats(x).variance, m.covariance):
                print('CovarianceTest', arg)
                passed = False
                continue
    if passed:
        print('Passed!')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment