Skip to content

Instantly share code, notes, and snippets.

@pkhuong
Last active January 9, 2023 21:03
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pkhuong/549106fc8194c0d1fce85b00c9e192d5 to your computer and use it in GitHub Desktop.
Save pkhuong/549106fc8194c0d1fce85b00c9e192d5 to your computer and use it in GitHub Desktop.
Fully dynamic variance for a bag of observations
import math
import struct
import unittest
import hypothesis.strategies as st
from hypothesis.stateful import Bundle, RuleBasedStateMachine, consumes, invariant, multiple, precondition, rule
class VarianceStack:
    """Streaming (Welford-style) mean and variance over a bag of values.

    Tracks the count `n`, the running `mean`, and `var_sum`, the sum of
    squared deviations from the mean, so the sample variance is
    var_sum / (n - 1).  `pop` exactly inverts a `push`; since the state is
    symmetric in the observations, any previously pushed value may be
    popped, not just the most recent one.
    """

    def __init__(self):
        self.n = 0        # number of observations currently tracked
        self.mean = 0     # running mean of the observations
        self.var_sum = 0  # sum of squared deviations ("variance of the sum")

    def push(self, v):
        """Add observation `v` with Welford's incremental update."""
        self.n += 1
        if self.n == 1:
            self.mean = v
            return
        prev_mean = self.mean
        self.mean = prev_mean + (v - prev_mean) / self.n
        self.var_sum += (v - prev_mean) * (v - self.mean)

    def pop(self, v):
        """Remove one previously-pushed observation `v` (inverse of push)."""
        assert self.n > 0
        if self.n == 1:
            # Last element out: reset to the empty state.
            self.n, self.mean, self.var_sum = 0, 0, 0
            return
        remaining = self.n - 1
        # Invert the mean update:
        #   m' = m + (v - m) / n  <=>  m = m' + (m' - v) / (n - 1).
        prev_mean = self.mean
        self.mean = prev_mean + (prev_mean - v) / remaining
        # Invert the var_sum update; clamp at 0 against float round-off.
        self.var_sum = max(0, self.var_sum - (v - self.mean) * (v - prev_mean))
        self.n = remaining

    def get_mean(self):
        """Mean of the current observations (0 when empty)."""
        return self.mean

    def get_variance(self):
        """Sample variance (Bessel-corrected); 0 with fewer than 2 points."""
        if self.n <= 1:
            return 0
        return self.var_sum / (self.n - 1)
def float_bits(x: float) -> int:
    """Map a double to an integer that orders the same way as the float.

    The raw IEEE-754 encoding is sign-magnitude; complementing negative
    values turns it into a 2's-complement-style total order, so distances
    between the integers approximate distances in representable doubles.

    >>> float_bits(0.0)
    0
    >>> float_bits(-0.0)
    -1
    >>> float_bits(1.0)
    4607182418800017408
    >>> float_bits(-2.5)
    -4612811918334230529
    >>> -float_bits(math.pi) - 1 == float_bits(-math.pi)
    True
    >>> float_bits(1.0) > 0
    True
    >>> float_bits(-1.0) < 0
    True
    """
    (raw,) = struct.unpack('=q', struct.pack('=d', x))
    magnitude = raw & ((1 << 63) - 1)
    if raw >= 0:
        return magnitude
    # ~magnitude == -1 - magnitude; using the complement (instead of plain
    # negation) keeps -0.0 distinct from +0.0.
    return ~magnitude
# Default tolerances: values compare as "almost equal" when they are within
# 2**32 representable doubles of each other, or within 1e-8 absolutely.
FLOAT_DISTANCE = 2**32
ABSOLUTE_EPS = 1e-8


def assert_almost_equal(x, y, max_delta=FLOAT_DISTANCE, abs_eps=ABSOLUTE_EPS):
    """Assert x ~= y, by ULP-style bit distance or absolute difference."""
    absolute_error = abs(x - y)
    bit_distance = abs(float_bits(x) - float_bits(y))
    close = bit_distance <= max_delta or absolute_error <= abs_eps
    # The message reports the distance as a power of two for readability.
    assert close, '%.18g != %.18g (%f)' % (x, y, math.log(bit_distance, 2))
# Keep the generated values in [-2**12, 2**12] (32-bit floats) so the
# accumulated rounding error stays within assert_almost_equal's tolerance.
MAX_RANGE = 2**12
FLOAT_STRATEGY = st.floats(width=32, min_value=-MAX_RANGE, max_value=MAX_RANGE)


class VarianceStackDriver(RuleBasedStateMachine):
    """Hypothesis state machine checking VarianceStack against brute force.

    Mirrors every push/pop into a plain list and, after each step, compares
    the stack's incremental mean/variance with values recomputed directly
    from the list.
    """

    def __init__(self):
        super(VarianceStackDriver, self).__init__()
        self.values = []  # shadow copy of every value currently pushed
        self.variance_stack = VarianceStack()

    @rule(v=FLOAT_STRATEGY)
    def push(self, v):
        self.variance_stack.push(v)
        self.values.append(v)

    @precondition(lambda self: self.values)
    @rule()
    def pop(self):
        # LIFO pop of the most recently pushed value.
        last = self.values.pop()
        self.variance_stack.pop(last)

    def reference_mean(self):
        # Brute-force mean; 0 for an empty bag, matching get_mean().
        if not self.values:
            return 0
        return sum(self.values) / len(self.values)

    def reference_variance(self):
        # Brute-force sample variance with Bessel's correction.
        count = len(self.values)
        if count <= 1:
            return 0
        center = self.reference_mean()
        return sum((x - center) ** 2 for x in self.values) / (count - 1)

    @invariant()
    def mean_matches(self):
        assert_almost_equal(self.reference_mean(),
                            self.variance_stack.get_mean())

    @invariant()
    def variance_matches(self):
        assert_almost_equal(self.reference_variance(),
                            self.variance_stack.get_variance())


StackTest = VarianceStackDriver.TestCase
class VarianceBag(VarianceStack):
    """VarianceStack plus in-place replacement of one observation."""

    def update(self, old, new):
        """Replace one previously-pushed value `old` with `new`."""
        assert self.n > 0
        if self.n == 1:
            # Singleton bag: the new value is the mean, variance is 0.
            self.mean = new
            self.var_sum = 0
            return
        delta = new - old
        prev_mean = self.mean
        delta_mean = delta / self.n
        self.mean = prev_mean + delta_mean
        # Two effects on var_sum = \sum (x_i - mean)^2:
        #  1. The mean shifts by delta_mean.  Expanding
        #     \sum [(x_i - prev_mean) - delta_mean]^2, the cross terms
        #     vanish (deviations from the mean sum to 0), leaving
        #     + n * delta_mean^2 = delta * delta_mean.
        #  2. old's own squared deviation becomes new's:
        #     (new - mean)^2 - (old - mean)^2 = delta * (old + new - 2 mean).
        # Combining and rewriting with new = old + delta and
        # mean = prev_mean + delta_mean gives:
        #     adjustment = delta * (2 (old - prev_mean) + (delta - delta_mean))
        adjustment = delta * (2 * (old - prev_mean) + (delta - delta_mean))
        # Clamp at 0 against float round-off.
        self.var_sum = max(0, self.var_sum + adjustment)
class VarianceBagDriver(RuleBasedStateMachine):
    """Hypothesis state machine checking VarianceBag against brute force.

    Maintains a shadow dict of key -> value; add/del/update rules drive the
    bag, and invariants compare its incremental mean/variance with values
    recomputed directly from the dict.
    """

    keys = Bundle("keys")

    def __init__(self):
        super(VarianceBagDriver, self).__init__()
        self.entries = dict()  # shadow mapping: key -> current value
        self.variance_bag = VarianceBag()

    @rule(target=keys, k=st.binary(), v=FLOAT_STRATEGY)
    def add_entry(self, k, v):
        # Re-adding an existing key is really an update, and must not
        # register the key in the bundle a second time.
        if k in self.entries:
            self.update_entry(k, v)
            return multiple()
        self.entries[k] = v
        self.variance_bag.push(v)
        return k

    @rule(k=consumes(keys))
    def del_entry(self, k):
        self.variance_bag.pop(self.entries.pop(k))

    @rule(k=keys, v=FLOAT_STRATEGY)
    def update_entry(self, k, v):
        self.variance_bag.update(self.entries[k], v)
        self.entries[k] = v

    def reference_mean(self):
        # Brute-force mean; 0 for an empty bag, matching get_mean().
        if not self.entries:
            return 0
        return sum(self.entries.values()) / len(self.entries)

    def reference_variance(self):
        # Brute-force sample variance with Bessel's correction.
        count = len(self.entries)
        if count <= 1:
            return 0
        center = self.reference_mean()
        return sum((x - center) ** 2
                   for x in self.entries.values()) / (count - 1)

    @invariant()
    def mean_matches(self):
        assert_almost_equal(self.reference_mean(),
                            self.variance_bag.get_mean())

    @invariant()
    def variance_matches(self):
        assert_almost_equal(self.reference_variance(),
                            self.variance_bag.get_variance())


BagTest = VarianceBagDriver.TestCase
# Run the hypothesis-backed unittest suites (StackTest / BagTest) when this
# file is executed directly.  NOTE(review): unittest.main() exits the
# process, so the z3 verification code below never runs from here —
# presumably it is a separate file of the gist; confirm against the gist's
# file list.
if __name__ == '____main__'.replace('__', '') or __name__ == '__main__':
    unittest.main()
from z3 import *
def appended_expressions(vars, new_x):
    """Symbolic streaming update for appending `new_x` to the bag `vars`.

    Builds (new_mean, new_var_sum) from the incremental (Welford) push
    formulas, as arithmetic expressions over the inputs (plain numbers or
    z3 terms), for comparison against the direct definitions.
    """
    count = len(vars)
    prior_mean = sum(vars) / count
    prior_var_sum = sum((v - prior_mean) * (v - prior_mean) for v in vars)
    grown = count + 1
    mean = prior_mean + (new_x - prior_mean) / grown
    var_sum = prior_var_sum + (new_x - prior_mean) * (new_x - mean)
    return mean, var_sum
def removed_expressions(vars):
    """Symbolic streaming downdate for popping vars[0] out of the bag.

    Treats `vars` as the full bag, removes its first element, and returns
    (mean, var_sum) of the remainder built from the incremental pop
    formulas (inputs may be plain numbers or z3 terms).
    """
    victim = vars[0]
    remaining = len(vars) - 1
    full_mean = sum(vars) / len(vars)
    # Invert the mean update, then undo victim's var_sum contribution.
    mean = full_mean + (full_mean - victim) / remaining
    var_sum = sum((v - full_mean) * (v - full_mean) for v in vars)
    var_sum = var_sum - (victim - mean) * (victim - full_mean)
    return mean, var_sum
def updated_expressions(vars, new_x):
    """Symbolic in-place replacement of vars[0] by `new_x`.

    Returns (new_mean, new_var_sum) built from the incremental update
    formulas used by VarianceBag.update (inputs may be plain numbers or
    z3 terms).
    """
    replaced = vars[0]
    count = len(vars)
    mean = sum(vars) / count
    var_sum = sum((v - mean) * (v - mean) for v in vars)
    delta = new_x - replaced
    delta_mean = delta / count
    shifted_mean = mean + delta_mean
    # Same adjustment as VarianceBag.update's derivation.
    bump = delta * (2 * (replaced - mean) + (delta - delta_mean))
    return shifted_mean, var_sum + bump
def test_num_var(num_var):
    """Prove the streaming update identities for bags of size `num_var`.

    Uses z3 to show that, over symbolic reals, the incremental expressions
    (update / append / remove) are equal to the direct mean / sum-of-squared-
    deviations definitions: each "!=" query must be unsatisfiable.

    Returns True when every identity is proven, False (after printing the
    solver's counterexample model) on the first failure.

    The original body repeated the same push/add/check/print/pop boilerplate
    six times; it is factored into the local `check_unsat` helper.
    """
    assert num_var > 0
    vars = [Real('x_%i' % i) for i in range(0, num_var)]
    new_x = Real('new_x')
    s = Solver()

    def check_unsat(label, *constraints):
        # Prove the conjunction of `constraints` unsatisfiable.  Prints
        # '<label> <result>'; on failure also prints the model and leaves
        # the constraints pushed (harmless: the caller aborts).
        s.push()
        for constraint in constraints:
            s.add(constraint)
        result = s.check()
        print('%s %s' % (label, result))
        if result != unsat:
            print(s.model())
            return False
        s.pop()
        return True

    # Identity 1: in-place replacement of vars[0] by new_x.
    new_mean, new_var_sum = updated_expressions(vars, new_x)
    new_vars = [new_x] + vars[1:]
    if not check_unsat('updated mean', new_mean != sum(new_vars) / num_var):
        return False
    # Check the variance identity under the (already proven) mean identity.
    if not check_unsat('updated variance',
                       new_mean == sum(new_vars) / num_var,
                       new_var_sum != sum((v - new_mean) * (v - new_mean)
                                          for v in new_vars)):
        return False

    # Identity 2: appending new_x to the bag.
    new_mean, new_var_sum = appended_expressions(vars, new_x)
    if not check_unsat('append mean',
                       new_mean != (new_x + sum(vars)) / (1 + num_var)):
        return False
    if not check_unsat('append variance',
                       new_mean == (new_x + sum(vars)) / (1 + num_var),
                       new_var_sum != sum((v - new_mean) * (v - new_mean)
                                          for v in [new_x] + vars)):
        return False

    # Identity 3: removing vars[0]; only meaningful with 2+ elements.
    if num_var > 1:
        new_mean, new_var_sum = removed_expressions(vars)
        if not check_unsat('removed mean',
                           new_mean != sum(vars[1:]) / (num_var - 1)):
            return False
        if not check_unsat('removed variance',
                           new_mean == sum(vars[1:]) / (num_var - 1),
                           new_var_sum != sum((v - new_mean) * (v - new_mean)
                                              for v in vars[1:])):
            return False
    return True
# Verify the streaming identities for bag sizes 1 through 10, stopping at
# the first size for which a proof fails.
for size in range(1, 11):
    print('testing n=%i' % size)
    if not test_num_var(size):
        print('FAIL %i' % size)
        break
    print('OK')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment