Last active
November 4, 2021 23:41
-
-
Save heiner/6287fc81dde85cbd36dbb7b26d3c6578 to your computer and use it in GitHub Desktop.
Various equivalent implementations of Welford's Algorithm, including its "parallel" version
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
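#
# Run the tests with:  pytest -svx welford_test.py
#
"""
Equivalent implementations of running mean/variance accumulators (Welford's
algorithm), including its "parallel" variant that merges the statistics of
separately processed batches.

References:
- Sutton & Barto, "Reinforcement Learning: An Introduction" (1st ed.):
  http://www.incompleteideas.net/book/first/ebook/node19.html
- https://math.stackexchange.com/a/103025/5051
- https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
- https://en.wikipedia.org/wiki/Pooled_variance#Aggregation_of_standard_deviation_data
"""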
import math | |
import numpy as np | |
import pytest | |
class MeanVar0:
    """Online mean and population variance, one sample at a time.

    Both statistics are stored directly and refreshed on every ``add`` via the
    recurrence from https://math.stackexchange.com/a/103025/5051, so no raw
    sums are kept. A fresh accumulator reports ``mean == var == 0``.
    """

    def __init__(self):
        self.mean = 0
        self.var = 0
        self.count = 0

    def add(self, value):
        """Fold ``value`` into the running mean and variance."""
        prev_count = self.count
        new_count = prev_count + 1
        prev_mean = self.mean
        new_mean = prev_mean + (value - prev_mean) / new_count
        # prev_count * var is the old sum of squared deviations about the old
        # mean; adding prev_count * shift**2 re-centres it on the new mean,
        # then the new sample's own squared deviation is added.
        shift = prev_mean - new_mean
        self.var = (
            prev_count * self.var + prev_count * shift ** 2 + (value - new_mean) ** 2
        ) / new_count
        self.mean = new_mean
        self.count = new_count

    @property
    def std(self):
        """Population standard deviation (square root of ``var``)."""
        return self.var ** 0.5
class MeanVar1:
    """Mean/variance accumulator that ingests data in batches.

    Implements the "parallel" merge from
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    State: ``count``, running ``sum`` (``mean`` is derived from it) and ``m2``,
    the sum of squared deviations from the current mean.
    """

    def __init__(self):
        self.count = 0
        self.sum = 0
        self.m2 = 0

    def add_many(self, value):
        """Merge a batch of values into the running statistics.

        Args:
            value: array-like of numbers. An empty batch is a no-op.
        """
        batch = np.asarray(value)
        count = batch.size
        if count == 0:
            # Without this guard the batch mean is NaN (0/0), which poisons
            # m2; if the accumulator is also empty the merge weight is 0/0
            # and raises ZeroDivisionError.
            return
        sum_ = np.sum(batch)
        mean = sum_ / count
        m2 = np.sum((batch - mean) ** 2)
        # Cross term accounts for the gap between the two group means.
        delta = mean - self.mean
        self.m2 += m2 + (self.count * count) / (self.count + count) * delta ** 2
        self.sum += sum_
        self.count += count

    @property
    def mean(self):
        """Arithmetic mean of all values seen so far (0.0 when empty)."""
        if not self.count:
            return 0.0
        return self.sum / self.count

    @property
    def var(self):
        """Population variance (0.0 when empty, consistent with ``mean``)."""
        if not self.count:
            return 0.0
        return self.m2 / self.count

    @property
    def std(self):
        """Population standard deviation."""
        return self.var ** 0.5
class MeanVar2:
    """Welford's online algorithm: add one value at a time.

    See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
    Keeps ``count``, the running ``sum`` (from which ``mean`` is derived) and
    ``m2``, the sum of squared deviations from the current mean.
    """

    def __init__(self):
        self.count = 0
        self.sum = 0
        self.m2 = 0

    def add(self, value):
        """Incorporate a single ``value``."""
        old_mean = self.mean
        self.count += 1
        self.sum += value
        # Welford update (x - old_mean) * (x - new_mean); after the two
        # increments above, self.mean is the new mean.
        self.m2 += (value - old_mean) * (value - self.mean)

    @property
    def mean(self):
        """Arithmetic mean of the values seen so far (0.0 when empty)."""
        if not self.count:
            return 0.0
        return self.sum / self.count

    @property
    def var(self):
        """Population variance (0.0 when empty, consistent with ``mean``)."""
        if not self.count:
            return 0.0
        return self.m2 / self.count

    @property
    def std(self):
        """Population standard deviation."""
        return self.var ** 0.5
class MeanVar3:
    """Batch accumulator based on pooling (equivalent to the parallel merge).

    See https://en.wikipedia.org/wiki/Pooled_variance#Aggregation_of_standard_deviation_data
    Each batch is summarised as its own ``MeanVar3`` and then pooled into
    ``self``. State: ``count``, running ``sum`` and ``m2`` (sum of squared
    deviations from the current mean).
    """

    def __init__(self):
        self.count = 0
        self.sum = 0
        self.m2 = 0

    def add_many(self, value):
        """Merge a batch of values into the running statistics.

        Args:
            value: array-like of numbers. An empty batch is a no-op.
        """
        batch = np.asarray(value)
        if batch.size == 0:
            # Nothing to merge (and pooling two empty summaries would divide
            # by zero).
            return
        other = MeanVar3()
        other.count = batch.size
        other.sum = np.sum(batch)
        other.m2 = np.sum((batch - other.mean) ** 2)
        self.count, self.sum, self.m2 = self.pool(other)

    def pool(self, other):
        """Return ``(count, sum, m2)`` of the union of ``self`` and ``other``.

        Neither accumulator is modified.
        """
        mvs = (self, other)
        count = sum(mv.count for mv in mvs)
        sum_ = sum(mv.sum for mv in mvs)
        m2 = sum(mv.m2 for mv in mvs)
        if self.count and other.count:
            # Between-group term. It contributes nothing when either side is
            # empty, and would divide by zero when both are.
            m2 += self.count * other.count / count * (self.mean - other.mean) ** 2
        return count, sum_, m2

    @property
    def mean(self):
        """Arithmetic mean of all values seen so far (0.0 when empty)."""
        if not self.count:
            return 0.0
        return self.sum / self.count

    @property
    def var(self):
        """Population variance (0.0 when empty, consistent with ``mean``)."""
        if not self.count:
            return 0.0
        return self.m2 / self.count

    @property
    def std(self):
        """Population standard deviation."""
        return self.var ** 0.5
def pool(t0, t1):
    """Pool the statistics of two MeanVar-like accumulators.

    Cf. https://en.wikipedia.org/wiki/Pooled_variance#Aggregation_of_standard_deviation_data

    Args:
        t0, t1: objects exposing ``count``, ``mean`` and ``var`` (population
            variance).

    Returns:
        ``(count, mean, var)`` of the combined data. Empty accumulators
        contribute nothing; pooling two empty ones gives ``(0, 0.0, 0.0)``.
    """
    ts = (t0, t1)
    count = sum(t.count for t in ts)
    if not count:
        return 0, 0.0, 0.0
    # Skip empty accumulators: their weight is zero anyway, and reading
    # ``var`` on an empty accumulator may divide by zero.
    nonempty = [t for t in ts if t.count]
    mean = sum(t.count * t.mean for t in nonempty) / count
    var = sum(t.count * t.var for t in nonempty) / count
    if len(nonempty) == 2:
        # Between-group term from the spread of the two means.
        var += t0.count * t1.count / count ** 2 * (t0.mean - t1.mean) ** 2
    return count, mean, var
def _test_values(meanvar, values, new=None):
    """Feed data into ``meanvar`` and check its statistics against NumPy.

    Args:
        meanvar: accumulator with either ``add_many(batch)`` or ``add(value)``.
        values: every value the accumulator should have seen in total; the
            expected count/mean/var/std are computed from it.
        new: the values to feed in this call (defaults to ``values``).
    """
    if new is None:
        new = values
    # Dispatch on the presence of ``add_many`` instead of catching
    # AttributeError around the call: an AttributeError raised by a bug
    # inside ``add_many`` must surface, not silently fall back to ``add``.
    add_many = getattr(meanvar, "add_many", None)
    if add_many is not None:
        add_many(new)
    else:
        for v in new:
            meanvar.add(v)
    assert meanvar.count == len(values)
    np.testing.assert_almost_equal(meanvar.mean, np.mean(values))
    np.testing.assert_almost_equal(meanvar.var, np.var(values))
    np.testing.assert_almost_equal(meanvar.std, np.std(values))
@pytest.mark.parametrize("MeanVar", [MeanVar0, MeanVar1, MeanVar2, MeanVar3])
class TestMeanVar:
    """Check every MeanVar implementation against NumPy's mean/var/std."""

    # Shared data: the Gamma function sampled on [0.01, 3), which has a wide
    # dynamic range of values.
    xs = np.arange(0.01, 3, 0.01)
    ys = [math.gamma(x) for x in xs]

    def test_simple(self, MeanVar):
        _test_values(MeanVar(), list(range(10)))

    def test_gamma(self, MeanVar):
        _test_values(MeanVar(), TestMeanVar.ys)

    def test_pooled(self, MeanVar):
        ys = TestMeanVar.ys
        n0 = len(ys) // 4  # deliberately unequal split
        ys0, ys1 = ys[:n0], ys[n0:]
        t0, t1 = MeanVar(), MeanVar()
        _test_values(t0, ys0)
        _test_values(t1, ys1)
        count, mean, var = pool(t0, t1)
        assert count == len(ys)
        np.testing.assert_almost_equal(mean, np.mean(ys))
        np.testing.assert_almost_equal(var, np.var(ys))

    def test_random(self, MeanVar):
        # Seeded generator so any failure is reproducible.
        rng = np.random.default_rng(0)
        all_data = []
        mv = MeanVar()
        for _ in range(100):
            new_count = rng.integers(1, 20)  # Number of new data points.
            new_data = rng.random(new_count)  # New data points in [0, 1).
            all_data.extend(new_data)
            _test_values(mv, all_data, new=new_data)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment