Skip to content

Instantly share code, notes, and snippets.

@nagadomi
Last active February 18, 2019 05:41
Show Gist options
  • Save nagadomi/23cc73e6e8c7d8ede0662b509acbecbc to your computer and use it in GitHub Desktop.
Save nagadomi/23cc73e6e8c7d8ede0662b509acbecbc to your computer and use it in GitHub Desktop.
import numpy as np
from timeit import timeit
from collections import deque
# Online (incremental) running mean / standard deviation over a time window.
# Reference implementation of the rolling update: https://github.com/ajcr/rolling/tree/master/rolling
class TimeRunningMeanStd(object):
    """Running mean and standard deviation over a sliding time window.

    Samples are kept in a FIFO together with their timestamps.  Each call to
    :meth:`update` appends one sample and evicts every sample whose timestamp
    is older than ``newest_timestamp - ttl``.  ``mean`` and ``sslm`` (the sum
    of squared deviations from the mean) are maintained incrementally with
    Welford-style add/remove updates, and are recomputed from the stored
    samples every 10000 updates to cancel accumulated rounding error.

    Attributes:
        ttl: Window length, in the same unit as the timestamps.
        mean: Per-element mean of the samples in the window (float64).
        std: Per-element population standard deviation plus ``1e-13``.
        sslm: Per-element sum of squared deviations from ``mean``.
        update_count: Number of updates, wrapped to 0 after 100000000.
    """

    def __init__(self, shape, ttl):
        """Create an empty window for samples of ``shape`` lasting ``ttl``."""
        self.ttl = ttl
        self.mean = np.zeros(shape, dtype=np.float64)
        self.std = np.ones(shape, dtype=np.float64)
        self.sslm = np.zeros(shape, dtype=np.float64)
        self._x = deque()  # samples, oldest first
        self._t = deque()  # timestamps, parallel to _x
        self.update_count = 0

    def _add_new(self, new, t):
        """Append sample ``new`` with timestamp ``t`` and fold it into the stats."""
        # Keep a private float64 copy: storing the caller's object by
        # reference would corrupt the window if the caller reuses its buffer.
        new = np.array(new, dtype=np.float64)
        self._x.append(new)
        self._t.append(t)
        delta = new - self.mean
        self.mean += delta / len(self._t)
        self.sslm += delta * (new - self.mean)

    def _remove_old(self):
        """Drop the oldest sample and subtract its contribution from the stats."""
        old = self._x.popleft()
        self._t.popleft()
        remaining = len(self._t)
        if remaining == 0:
            # Nothing left to average over; reset instead of dividing by zero.
            self.mean.fill(0)
            self.sslm.fill(0)
            return
        delta = old - self.mean
        self.mean -= delta / remaining
        self.sslm -= delta * (old - self.mean)

    def get_error(self):
        """Return ``(mean_error, std_error)`` against a full recomputation.

        Each value is the summed absolute difference between the incrementally
        maintained statistic and the one computed directly from the stored
        samples with ``np.mean`` / ``np.std``.
        """
        mean_error = np.abs(self.mean - np.mean(self._x, axis=0)).sum()
        std_error = np.abs(self.std - np.std(self._x, axis=0)).sum()
        return mean_error, std_error

    def update(self, x, timestamp):
        """Add sample ``x`` observed at ``timestamp`` and refresh the stats.

        Samples older than ``timestamp - ttl`` are evicted.  If the window
        ends up empty, ``mean``/``sslm`` are reset to 0 and ``std`` to 1.
        """
        self.update_count += 1
        if self.update_count > 100000000:
            self.update_count = 0
        self._add_new(x, timestamp)
        limit = self._t[-1] - self.ttl
        while self._t and self._t[0] < limit:
            self._remove_old()
        if self._t:
            count = len(self._t)
            if self.update_count % 10000 == 0:
                # fix cumulative error
                self.mean = np.mean(self._x, axis=0, dtype=np.float64)
                self.sslm = np.var(self._x, axis=0, dtype=np.float64) * count
            # Rounding in the add/remove updates can push sslm a hair below
            # zero, which would turn the square root into NaN.
            np.maximum(self.sslm, 0.0, out=self.sslm)
            self.std = np.sqrt(self.sslm / count) + 1e-13
        else:
            self.mean.fill(0)
            self.sslm.fill(0)
            self.std.fill(1)
if __name__ == "__main__":
    # Demo: stream 64-dimensional Gaussian samples at random intervals for
    # 300 simulated minutes through a 15-minute window, and print how far
    # the incremental statistics have drifted from a full recomputation
    # roughly once per simulated minute.
    tracker = TimeRunningMeanStd(shape=(64,), ttl=60*15)
    now = 0.0
    last_report = 0
    while now < 300*60:
        # Advance the clock first, then draw element i from N(i, i).
        now += np.random.uniform(0.001, 0.1)
        sample = np.array([np.random.normal(i, i) for i in range(64)])
        tracker.update(sample, now)
        if now - last_report <= 60:
            continue
        last_report = now
        mean_err, std_err = tracker.get_error()
        print(round(now/60.0), "mean_error", mean_err, "std_error", std_err)
@nagadomi
Copy link
Author

nagadomi commented Feb 4, 2019

mean error 8.641218300908804e-10 std error 6.382599191567806e-09
TimeRunningMeanStd 3.9924477376043797 sec
TrueTimeRunningMeanStd 29.189583705738187 sec

@nagadomi
Copy link
Author

nagadomi commented Feb 8, 2019

1 mean_error 2.3835378115677486e-12 std_error 6.3307936588013035e-12
2 mean_error 2.8983482280864337e-12 std_error 5.825198093387007e-12
3 mean_error 4.1944225870338414e-12 std_error 6.223546114622513e-12
4 mean_error 4.525269048372138e-12 std_error 6.388969345291662e-12
5 mean_error 5.170086581074429e-12 std_error 6.375868713601085e-12
6 mean_error 4.963363053889225e-12 std_error 6.586255976767552e-12
7 mean_error 5.9685589803848416e-12 std_error 6.757341344862289e-12
8 mean_error 5.46274137036562e-12 std_error 7.29935222548429e-12
10000
9 mean_error 2.705724533313969e-12 std_error 6.2597393852252935e-12
10 mean_error 4.235500838944972e-12 std_error 6.601688076809842e-12
11 mean_error 5.0672799289941395e-12 std_error 6.786762255014855e-12
12 mean_error 6.147748976559342e-12 std_error 6.943414723789465e-12
13 mean_error 7.35378424820965e-12 std_error 6.992264536872972e-12
14 mean_error 7.453593298123451e-12 std_error 7.08485713712671e-12
15 mean_error 7.41695593831082e-12 std_error 6.9937078268049845e-12
16 mean_error 9.556355706763497e-12 std_error 7.61709805513201e-12
20000
17 mean_error 2.5270896486517813e-12 std_error 7.368741164523362e-12
18 mean_error 4.867883873771461e-12 std_error 9.092362410253918e-12
19 mean_error 7.042588734407218e-12 std_error 7.945279981211206e-12
20 mean_error 7.719824779428563e-12 std_error 7.638636381809738e-12
21 mean_error 9.132472555961613e-12 std_error 8.460423464637279e-12
22 mean_error 8.804623696789804e-12 std_error 8.703007195517875e-12
23 mean_error 7.849276784099857e-12 std_error 7.961711281975658e-12
24 mean_error 9.793943434033281e-12 std_error 8.279235067018453e-12
25 mean_error 1.0900391700374712e-11 std_error 8.652047958687581e-12
30000
26 mean_error 3.92119670067359e-12 std_error 6.831282198302324e-12
27 mean_error 6.221689829999377e-12 std_error 7.411262706366506e-12
28 mean_error 8.647527138805344e-12 std_error 7.15502323228302e-12
29 mean_error 1.0222711566143516e-11 std_error 7.219860256921129e-12
30 mean_error 1.1964207402570537e-11 std_error 8.044866986520083e-12
31 mean_error 1.194977450325041e-11 std_error 8.11358979174438e-12
32 mean_error 1.2385426018113321e-11 std_error 8.45564950563139e-12
33 mean_error 1.3370082818653373e-11 std_error 8.814584609492704e-12
40000
34 mean_error 3.17446069431071e-12 std_error 8.963798584002325e-12
35 mean_error 5.449418694070118e-12 std_error 8.943370480349222e-12
36 mean_error 6.88205048504642e-12 std_error 8.279457111623378e-12
37 mean_error 8.575695709112097e-12 std_error 9.32328879937595e-12
38 mean_error 9.97002480573883e-12 std_error 1.0025615884753824e-11
39 mean_error 1.0648149029179876e-11 std_error 9.248459767516215e-12
40 mean_error 1.070787902790471e-11 std_error 9.905489753489383e-12
41 mean_error 1.13473674900888e-11 std_error 8.706559909196676e-12
42 mean_error 1.0900946811887025e-11 std_error 7.807612326157687e-12
50000
43 mean_error 5.9894311732477945e-12 std_error 7.008806859939887e-12
44 mean_error 8.000156093146416e-12 std_error 8.707781154523764e-12
45 mean_error 8.47255599012442e-12 std_error 7.0895200738301355e-12
46 mean_error 1.0702772001991434e-11 std_error 8.494507311493271e-12
47 mean_error 1.0888068224801373e-11 std_error 9.0518392698551e-12
48 mean_error 1.0964451568895583e-11 std_error 8.148783861624997e-12
49 mean_error 1.184474740512087e-11 std_error 7.622760192557598e-12
50 mean_error 1.2645884339690383e-11 std_error 8.051639346970296e-12

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment