Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active November 14, 2023 15:09
  • Star 9 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save wassname/a9502f562d4d3e73729dc5b184db2501 to your computer and use it in GitHub Desktop.
Running statistics (mean, standard deviation) for Python, PyTorch, etc.
import numpy as np

# Handle pytorch tensors, lists, etc. by reusing tensorboardX's converter
# when available; otherwise fall back to a plain numpy conversion.
try:
    from tensorboardX.x2num import make_np
except ImportError:
    def make_np(x):
        """Fallback: convert *x* (scalar, list, array, ...) to a float numpy array.

        Uses float64 instead of the original float16 — float16 carries only
        ~3 significant digits and silently degraded the running statistics
        accumulated downstream.
        """
        return np.array(x, dtype='float64')
class RunningStats(object):
    """Computes a running mean and standard deviation (Welford's algorithm).

    Url: https://gist.github.com/wassname/a9502f562d4d3e73729dc5b184db2501
    Adapted from:
    * <http://stackoverflow.com/questions/1174984/how-to-efficiently-\
calculate-a-running-standard-deviation>
    * <http://mathcentral.uregina.ca/QQ/database/QQ.09.02/carlos1.html>
    * <https://gist.github.com/fvisin/5a10066258e43cf6acfa0a474fcdb59f>

    Usage:
        rs = RunningStats()
        for i in range(10):
            rs += np.random.randn()
        print(rs)
        print(rs.mean, rs.std)
    """

    def __init__(self, n=0., m=None, s=None):
        # n: number of samples seen so far
        # m: running mean (None until the first sample)
        # s: running sum of squared deviations; sample variance = s / (n - 1)
        self.n = n
        self.m = m
        self.s = s

    def clear(self):
        """Reset to the empty state (also drops the stale mean/spread)."""
        self.n = 0.
        self.m = None
        self.s = None

    def push(self, x, per_dim=False):
        """Accumulate a sample.

        With per_dim=True the statistics are kept element-wise over x's
        shape; otherwise every element of x is folded in as a scalar sample.
        """
        x = make_np(x)
        if per_dim:
            self.update_params(x)
        else:
            for el in x.flatten():
                self.update_params(el)

    def update_params(self, x):
        """Welford single-sample update; x may be a scalar or an array."""
        self.n += 1
        if self.n == 1:
            self.m = x
            self.s = 0.
        else:
            # Out-of-place arithmetic: works for plain Python floats (which
            # have no .copy()) and avoids the aliasing that in-place '+='
            # caused when m is a numpy array.
            new_m = self.m + (x - self.m) / self.n
            self.s = self.s + (x - self.m) * (x - new_m)
            self.m = new_m

    def __add__(self, other):
        """Merge two RunningStats (parallel/Chan formula), or push a sample."""
        if isinstance(other, RunningStats):
            # Guard empty operands: the merge formula below would otherwise
            # divide by zero (both empty) or do arithmetic on m=None.
            if self.n == 0:
                return RunningStats(other.n, other.m, other.s)
            if other.n == 0:
                return RunningStats(self.n, self.m, self.s)
            sum_ns = self.n + other.n
            prod_ns = self.n * other.n
            delta2 = (other.m - self.m) ** 2.
            return RunningStats(sum_ns,
                                (self.m * self.n + other.m * other.n) / sum_ns,
                                self.s + other.s + delta2 * prod_ns / sum_ns)
        else:
            # Anything else is treated as a sample (mutates self).
            self.push(other)
            return self

    @property
    def mean(self):
        return self.m if self.n else 0.0

    def variance(self):
        # Sample variance needs n >= 2; the original guard ('if self.n')
        # let n == 1 through and raised ZeroDivisionError.
        return self.s / (self.n - 1) if self.n > 1 else 0.0

    @property
    def std(self):
        return np.sqrt(self.variance())

    def __repr__(self):
        # Guard the empty state: '{:f}'.format(None) raises. Name fixed
        # from 'RunningMean' to match the class.
        m = self.m if self.m is not None else 0.0
        s = self.s if self.s is not None else 0.0
        return '<RunningStats(mean={: 2.4f}, std={: 2.4f}, n={: 2f}, m={: 2.4f}, s={: 2.4f})>'.format(
            self.mean, self.std, self.n, m, s)

    def __str__(self):
        return 'mean={: 2.4f}, std={: 2.4f}'.format(self.mean, self.std)
@DrSkippy
Copy link

DrSkippy commented Aug 5, 2020

The add for the case of combining two RunningStats should be something like this:

def __add__(self, other):
    if isinstance(other, RunningStats):
        sum_ns = self.n + other.n
        prod_ns = self.n * other.n
        delta2 = (other.m - self.m) ** 2.
        return RunningStats(sum_ns,
                            (self.m * self.n + other.m * other.n) / sum_ns,
                            self.s + other.s + delta2 * prod_ns / sum_ns)
    else:
        self.push(other)
        return self

@wassname
Copy link
Author

wassname commented Dec 3, 2020

@DrSkippy true! I added it, thanks.

@TomerAntman
Copy link

According to this source: https://www.johndcook.com/blog/standard_deviation/
in the calculation of the variance you should have self.s / (self.n -1 )
"For 2 ≤ k ≤ n, the kth estimate of the variance is s^2 = Sk/(k – 1)."

@wassname
Copy link
Author

Thanks Tomer, I added that too.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment