Instantly share code, notes, and snippets.

# pkhuong/dynamic-variance.py

Last active January 9, 2023 21:03
Show Gist options
• Save pkhuong/549106fc8194c0d1fce85b00c9e192d5 to your computer and use it in GitHub Desktop.
Fully dynamic variance for a bag of observations
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 import math import struct import unittest import hypothesis.strategies as st from hypothesis.stateful import Bundle, RuleBasedStateMachine, consumes, invariant, multiple, precondition, rule class VarianceStack: def __init__(self): self.n = 0 self.mean = 0 self.var_sum = 0 # variance of the sum def push(self, v): self.n += 1 if self.n == 1: self.mean = v return old_mean = self.mean self.mean += (v - old_mean) / self.n self.var_sum += (v - old_mean) * (v - self.mean) def pop(self, v): assert self.n > 0 if self.n == 1: self.n = 0 self.mean = 0 self.var_sum = 0 return next_n = self.n - 1 old_mean = self.mean # m' = m + (v - m) / n # <-> nm' = nm + v - m # = v + (n - 1) m # <-> m = (nm' - v) / (n - 1) # = n/(n - 1) m' - v/(n - 1) # n/(n - 1)m' - v/(n - 1) = m' + (m' - v)/(n-1) self.mean = old_mean + (old_mean - v) / next_n self.var_sum = max(0, self.var_sum - (v - self.mean) * (v - old_mean)) self.n -= 1 def get_mean(self): return self.mean def get_variance(self): return self.var_sum / (self.n - 1) if self.n > 1 else 0 def float_bits(x: float) -> int: """Convert float to sign-magnitude bits, then to 2's complement. >>> float_bits(0.0) 0 >>> float_bits(-0.0) -1 >>> float_bits(1.0) 4607182418800017408 >>> float_bits(-2.5) -4612811918334230529 >>> -float_bits(math.pi) - 1 == float_bits(-math.pi) True >>> float_bits(1.0) > 0 True >>> float_bits(-1.0) < 0 True """ bits = struct.unpack('=q', struct.pack('=d', x))[0] significand = bits % (1 << 63) # ~significand = -1 - significand. We need that instead of just # -significand to handle signed zeros. return significand if bits >= 0 else ~significand FLOAT_DISTANCE = 2**32 ABSOLUTE_EPS = 1e-8 def assert_almost_equal(x, y, max_delta=FLOAT_DISTANCE, abs_eps=ABSOLUTE_EPS): delta = abs(x - y) distance = abs(float_bits(x) - float_bits(y)) assert distance <= max_delta or delta <= abs_eps, '%.18g != %.18g (%f)' % ( x, y, math.log(distance, 2)) MAX_RANGE = 2**12 FLOAT_STRATEGY = st.floats(width=32, min_value=-MAX_RANGE, max_value=MAX_RANGE) class VarianceStackDriver(RuleBasedStateMachine): def __init__(self): super(VarianceStackDriver, self).__init__() self.values = [] self.variance_stack = VarianceStack() @rule(v=FLOAT_STRATEGY) def push(self, v): self.variance_stack.push(v) self.values.append(v) @precondition(lambda self: self.values) @rule() def pop(self): self.variance_stack.pop(self.values[-1]) self.values.pop() def reference_mean(self): if self.values: return sum(self.values) / len(self.values) return 0 def reference_variance(self): n = len(self.values) if n <= 1: return 0 mean = self.reference_mean() return sum(pow(x - mean, 2) for x in self.values) / (n - 1) @invariant() def mean_matches(self): assert_almost_equal(self.reference_mean(), self.variance_stack.get_mean()) @invariant() def variance_matches(self): assert_almost_equal(self.reference_variance(), self.variance_stack.get_variance()) StackTest = VarianceStackDriver.TestCase class VarianceBag(VarianceStack): def update(self, old, new): assert self.n > 0 if self.n == 1: self.mean = new self.var_sum = 0 return delta = new - old old_mean = self.mean delta_mean = delta / self.n self.mean += delta_mean # we have \sum (x_i - initial_mean)^2 # we want \sum (x_i - mean)^2 # = \sum [(x_i - initial_mean) - delta_mean]^2 # = \sum [(x_i - initial_mean)^2 - 2 delta_mean (x_i - initial_mean) + delta_mean^2] # = [\sum (x_i - initial_mean)^2] + n delta_mean^2 -- \sum (x_i - initial_mean) = (\sum x_i) - n(sum x_i / n) = 0 # # delta_mean = delta / n # n delta_mean^2 = delta delta_mean # we have (old - mean)^2 # we want (new - mean)^2 # = (old + delta - mean)^2 # = [(old - mean) + delta]^2 # = (old - mean)^2 + 2 delta (old - mean) + delta^2 # = (old - mean)^2 + delta [2(old - mean) + delta] # = (old - mean)^2 + delta [old + new - 2 mean] # Total adjustment = delta delta_mean + delta (old + new - 2 mean) # = delta [(old - mean + delta_mean) + (new - mean)] # = delta [(old - old_mean) + (new - mean)] # # mean = old_mean + delta_mean # new = old + delta # -> new - mean = (old - old_mean) + (delta - delta_mean) # Total adjustment = delta [2 (old - old_mean) + (delta - delta_mean)] adjustment = delta * (2 * (old - old_mean) + (delta - delta_mean)) self.var_sum = max(0, self.var_sum + adjustment) class VarianceBagDriver(RuleBasedStateMachine): keys = Bundle("keys") def __init__(self): super(VarianceBagDriver, self).__init__() self.entries = dict() self.variance_bag = VarianceBag() @rule(target=keys, k=st.binary(), v=FLOAT_STRATEGY) def add_entry(self, k, v): if k in self.entries: self.update_entry(k, v) return multiple() self.entries[k] = v self.variance_bag.push(v) return k @rule(k=consumes(keys)) def del_entry(self, k): self.variance_bag.pop(self.entries[k]) del self.entries[k] @rule(k=keys, v=FLOAT_STRATEGY) def update_entry(self, k, v): self.variance_bag.update(self.entries[k], v) self.entries[k] = v def reference_mean(self): if self.entries: return sum(self.entries.values()) / len(self.entries) return 0 def reference_variance(self): n = len(self.entries) if n <= 1: return 0 mean = self.reference_mean() return sum(pow(x - mean, 2) for x in self.entries.values()) / (n - 1) @invariant() def mean_matches(self): assert_almost_equal(self.reference_mean(), self.variance_bag.get_mean()) @invariant() def variance_matches(self): assert_almost_equal(self.reference_variance(), self.variance_bag.get_variance()) BagTest = VarianceBagDriver.TestCase if __name__ == '__main__': unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters