Skip to content

Instantly share code, notes, and snippets.

@heiner
Last active November 4, 2021 23:41
Show Gist options
  • Save heiner/6287fc81dde85cbd36dbb7b26d3c6578 to your computer and use it in GitHub Desktop.
Save heiner/6287fc81dde85cbd36dbb7b26d3c6578 to your computer and use it in GitHub Desktop.
Various equivalent implementations of Welford's Algorithm, including its "parallel" version
#
# pytest -svx welford_test.py
#
"""
Cf. Sutton & Barto, "Reinforcement Learning: An Introduction" (1st ed.),
http://www.incompleteideas.net/book/first/ebook/node19.html
and
https://math.stackexchange.com/a/103025/5051
as well as
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
and
https://en.wikipedia.org/wiki/Pooled_variance#Aggregation_of_standard_deviation_data
"""
import math
import numpy as np
import pytest
class MeanVar0:
    """Running mean and population variance, updated one value at a time.

    Recurrence from https://math.stackexchange.com/a/103025/5051: when a
    value arrives, the previous points' squared deviations are re-centred
    on the new mean, the new point's squared deviation is added, and the
    total is divided by the new count.
    """

    def __init__(self):
        self.mean = 0
        self.var = 0
        self.count = 0

    def add(self, value):
        """Fold a single ``value`` into the running statistics."""
        prev_n = self.count
        new_n = prev_n + 1
        prev_mean = self.mean
        new_mean = prev_mean + (value - prev_mean) / new_n
        # Sum of squared deviations of the previous points about new_mean.
        shifted = prev_n * self.var + prev_n * (prev_mean - new_mean) ** 2
        self.var = (shifted + (value - new_mean) ** 2) / new_n
        self.mean = new_mean
        self.count = new_n

    @property
    def std(self):
        """Population standard deviation (square root of ``var``)."""
        return self.var ** 0.5
class MeanVar1:
    """Running mean and population variance, fed in batches.

    Batched ("parallel") variant of Welford's algorithm:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
    Each batch's own sum of squared deviations is merged into the running
    ``m2``, plus a correction term for the gap between the batch mean and
    the running mean.
    """

    def __init__(self):
        self.count = 0
        self.sum = 0
        self.m2 = 0  # sum of squared deviations from the running mean

    def add_many(self, value):
        """Merge a batch of values (any array-like; it is flattened).

        An empty batch is a no-op. Without that guard its mean would be
        0/0 = nan, which would permanently poison ``m2``.
        """
        # Flatten so that count, sum and deviations all refer to the same
        # elements (len() would only count rows of a 2-D input).
        values = np.asarray(value).ravel()
        count = values.size
        if count == 0:
            return
        sum_ = np.sum(values)
        mean = sum_ / count
        m2 = np.sum((values - mean) ** 2)
        delta = mean - self.mean
        self.m2 += m2 + (self.count * count) / (self.count + count) * delta ** 2
        self.sum += sum_
        self.count += count

    @property
    def mean(self):
        """Mean of all values seen so far; 0.0 before any value is added."""
        if not self.count:
            return 0.0
        return self.sum / self.count

    @property
    def var(self):
        """Population variance (``m2 / count``); 0.0 when empty."""
        if not self.count:
            return 0.0
        return self.m2 / self.count

    @property
    def std(self):
        """Population standard deviation (square root of ``var``)."""
        return self.var ** 0.5
class MeanVar2:
    """Running mean and population variance, updated one value at a time.

    Welford's online algorithm:
    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
    """

    def __init__(self):
        self.count = 0
        self.sum = 0
        self.m2 = 0  # sum of squared deviations from the running mean

    def add(self, value):
        """Fold a single ``value`` into the running statistics."""
        old_mean = self.mean
        self.count += 1
        self.sum += value
        # self.mean now reflects the updated count and sum.
        self.m2 += (value - old_mean) * (value - self.mean)

    @property
    def mean(self):
        """Mean of all values seen so far; 0.0 before any value is added."""
        if not self.count:
            return 0.0
        return self.sum / self.count

    @property
    def var(self):
        """Population variance (``m2 / count``); 0.0 when empty."""
        if not self.count:
            return 0.0
        return self.m2 / self.count

    @property
    def std(self):
        """Population standard deviation (square root of ``var``)."""
        return self.var ** 0.5
class MeanVar3:
    """Running mean and population variance, updated by pooling batches.

    Version based on "pooling",
    https://en.wikipedia.org/wiki/Pooled_variance#Aggregation_of_standard_deviation_data
    Basically the same as MeanVar1: each batch is summarised as its own
    MeanVar3 and then merged into ``self``.
    """

    def __init__(self):
        self.count = 0
        self.sum = 0
        self.m2 = 0  # sum of squared deviations from the running mean

    def add_many(self, value):
        """Merge a batch of values (any array-like; it is flattened).

        An empty batch leaves the statistics unchanged.
        """
        # Flatten so that count, sum and deviations all refer to the same
        # elements (len() would only count rows of a 2-D input).
        values = np.asarray(value).ravel()
        other = MeanVar3()
        other.count = values.size
        other.sum = np.sum(values)
        other.m2 = np.sum((values - other.mean) ** 2)
        self.count, self.sum, self.m2 = self.pool(other)

    def pool(self, other):
        """Return ``(count, sum, m2)`` for ``self`` and ``other`` combined.

        Neither object is modified.
        """
        count = self.count + other.count
        sum_ = self.sum + other.sum
        m2 = self.m2 + other.m2
        # Between-group term. It is zero when either side is empty, and
        # computing it when both are empty would divide by zero.
        if self.count and other.count:
            m2 += self.count * other.count / count * (self.mean - other.mean) ** 2
        return count, sum_, m2

    @property
    def mean(self):
        """Mean of all values seen so far; 0.0 before any value is added."""
        if not self.count:
            return 0.0
        return self.sum / self.count

    @property
    def var(self):
        """Population variance (``m2 / count``); 0.0 when empty."""
        if not self.count:
            return 0.0
        return self.m2 / self.count

    @property
    def std(self):
        """Population standard deviation (square root of ``var``)."""
        return self.var ** 0.5
def pool(t0, t1):
    """Pool two MeanVar objects.

    Combines the statistics of two disjoint samples without revisiting the
    data. Cf.
    https://en.wikipedia.org/wiki/Pooled_variance#Aggregation_of_standard_deviation_data

    Args:
        t0, t1: objects exposing ``count``, ``mean`` and population ``var``.

    Returns:
        ``(count, mean, var)`` of the union of both samples. Two empty
        inputs give ``(0, 0.0, 0.0)`` instead of dividing by zero. This
        matches the 0.0 ``mean`` the MeanVar classes report when empty.
    """
    count = t0.count + t1.count
    if not count:
        return 0, 0.0, 0.0
    mean = (t0.count * t0.mean + t1.count * t1.mean) / count
    var = (
        (t0.count * t0.var + t1.count * t1.var) / count
        # Between-sample term: spread of the two means around each other.
        + t0.count * t1.count / count ** 2 * (t0.mean - t1.mean) ** 2
    )
    return count, mean, var
def _test_values(meanvar, values, new=None):
if new is None:
new = values
try:
meanvar.add_many(new)
except AttributeError:
for v in new:
meanvar.add(v)
assert meanvar.count == len(values)
np.testing.assert_almost_equal(meanvar.mean, np.mean(values))
np.testing.assert_almost_equal(meanvar.var, np.var(values))
np.testing.assert_almost_equal(meanvar.std, np.std(values))
@pytest.mark.parametrize("MeanVar", [MeanVar0, MeanVar1, MeanVar2, MeanVar3])
class TestMeanVar:
    """Check every MeanVar implementation against NumPy's mean/var/std."""

    # Shared fixture data: Gamma function values on [0.01, 3).
    xs = np.arange(0.01, 3, 0.01)
    ys = [math.gamma(x) for x in xs]

    def test_simple(self, MeanVar):
        _test_values(MeanVar(), list(range(10)))

    def test_gamma(self, MeanVar):
        _test_values(MeanVar(), TestMeanVar.ys)

    def test_pooled(self, MeanVar):
        """Statistics of two halves, pooled, must equal those of the whole."""
        ys = TestMeanVar.ys
        split = len(ys) // 4
        head, tail = ys[:split], ys[split:]
        t0 = MeanVar()
        t1 = MeanVar()
        _test_values(t0, head)
        _test_values(t1, tail)
        count, mean, var = pool(t0, t1)
        assert count == len(ys)
        np.testing.assert_almost_equal(mean, np.mean(ys))
        np.testing.assert_almost_equal(var, np.var(ys))

    def test_random(self, MeanVar):
        """Feed random batches and re-check against all data after each one."""
        # Seeded generator so that any failure is reproducible; the global
        # np.random state would give different data on every run.
        rng = np.random.default_rng(0)
        all_data = []
        mv = MeanVar()
        for _ in range(100):
            new_count = rng.integers(1, 20)  # Number of new data points.
            new_data = rng.random(new_count)  # New data points in [0, 1).
            all_data.extend(new_data)
            _test_values(mv, all_data, new=new_data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment