Skip to content

Instantly share code, notes, and snippets.

@sharmaeklavya2
Last active January 25, 2024 16:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sharmaeklavya2/72939fbcfcb8f870c299ed9555a5bc67 to your computer and use it in GitHub Desktop.
Save sharmaeklavya2/72939fbcfcb8f870c299ed9555a5bc67 to your computer and use it in GitHub Desktop.
Track how a noisy quantity changes over time (e.g. my body weight)
#!/usr/bin/env python3
"""
Visualize how a quantity varies with time.
Reads a CSV file where first column is timestamp.
Plots the data points and a weighted average of a sliding window.
"""
import csv
import math
import argparse
from datetime import datetime
from collections import namedtuple
from typing import Any
import numpy as np
import matplotlib.pyplot as plt
# seaborn is optional; when it imports, makePlot uses it to set the plot theme.
try:
    import seaborn as sns
except ImportError:
    HAS_SNS = False
else:
    HAS_SNS = True

# Added to the variance numerator in BetterSummaryStats.stddev.
EPS = 1e-8
# Not referenced anywhere else in this script.
STDDEV_FACTOR = 1
def readFile(fpath: str, args: Any) -> tuple[list[datetime], list[float]]:
    """Read (timestamp, value) pairs from a CSV file.

    Args:
        fpath: path of the CSV file to read.
        args: object providing ``has_header`` (bool; when true the first row is
            skipped), ``time_index`` (column of the ISO-8601 timestamp) and
            ``data_index`` (column of the numeric value).

    Returns:
        (timex, y): parsed datetimes and float values, in file order.

    Raises:
        ValueError: if a timestamp is not ISO-8601 or a value is not a number.
        IndexError: if a non-blank row has too few columns.
    """
    timex: list[datetime] = []
    y: list[float] = []
    with open(fpath, newline='') as fp:
        for i, row in enumerate(csv.reader(fp)):
            if args.has_header and i == 0:
                continue
            # Blank lines (e.g. a trailing empty line) come back as [] or as
            # all-empty cells; skip them instead of failing on them.
            if not any(cell.strip() for cell in row):
                continue
            # fromisoformat rejects surrounding whitespace (e.g. after a comma).
            timex.append(datetime.fromisoformat(row[args.time_index].strip()))
            y.append(float(row[args.data_index]))
    return (timex, y)
def dtToNpdt(dt):
    """Convert a datetime into a numpy datetime64 with one-second resolution.

    Any tzinfo is simply dropped (no conversion), so the result is the
    wall-clock time as written.
    NOTE(review): aware timestamps with different UTC offsets are therefore not
    put on a common timeline — confirm this is intended.
    """
    wall_clock = dt.replace(tzinfo=None)
    return np.datetime64(wall_clock.isoformat(), 's')
class SummaryStats:
    """Running count, sum and multiplicities of a multiset that supports removal.

    Values are added and removed one at a time so that the mean of a sliding
    window can be maintained incrementally. ``zero`` is the starting value of
    the running sum; pass e.g. a zero timedelta to accumulate time offsets.
    """

    def __init__(self, zero=0):
        self._n = 0
        self._sum = zero
        self._freq = {}  # value -> how many copies of it are currently held

    def mean(self):
        """Return sum / count.

        With no values held this divides by zero; for the default int ``zero``
        that raises ZeroDivisionError.
        """
        return self._sum / self._n

    def add(self, x):
        """Add one occurrence of ``x``."""
        self._n += 1
        self._sum += x
        self._freq[x] = self._freq.get(x, 0) + 1

    def remove(self, x):
        """Remove one occurrence of ``x``.

        Raises:
            KeyError: if ``x`` is not currently held. The check runs before any
                state is touched, so a failed call leaves the stats unchanged.
        """
        if x not in self._freq:
            raise KeyError(x)
        self._n -= 1
        self._sum -= x
        if self._freq[x] == 1:
            del self._freq[x]
        else:
            self._freq[x] -= 1

    def __repr__(self):
        return '{}(n={}, sum={}, freq={})'.format(self.__class__.__name__, self._n,
                                                  self._sum, self._freq)
class BetterSummaryStats(SummaryStats):
    """SummaryStats that also keeps a running sum of squares for a sample stddev."""

    def __init__(self, zero=0):
        super().__init__(zero)
        self._sum2 = zero  # running sum of x*x over the values currently held

    def add(self, x):
        super().add(x)
        self._sum2 += x*x

    def remove(self, x):
        super().remove(x)
        self._sum2 -= x*x

    def stddev(self):
        """Return the sample standard deviation (n - 1 denominator).

        Returns nan when no values are held and inf for a single value. EPS is
        added to the variance numerator, so identical values give a small
        positive result rather than exactly 0.
        """
        if self._n == 0:
            return math.nan
        if self._n == 1:
            return math.inf
        numerator = self._sum2 + EPS - self._sum ** 2 / self._n
        # The running sums pick up rounding error over many add/remove calls,
        # so for near-constant data the numerator can drift below zero; clamp
        # it instead of letting math.sqrt raise a domain error.
        return math.sqrt(max(numerator, 0.0) / (self._n - 1))

    def __repr__(self):
        return '{}(n={}, sum={}, sum2={}, freq={}, mean={}, stddev={})'.format(
            self.__class__.__name__, self._n, self._sum, self._sum2, self._freq,
            self.mean(), self.stddev())
# One smoothed output sample: window-mean time, value mean, value stddev.
Point = namedtuple('Point', 'x y ystd')
def debugPointAddition(point, type, i, j, xStats, yStats, verbose=False):
    """Optionally print diagnostics for a point emitted by the smoothing sweep.

    Does nothing unless ``verbose`` is true, so calls that do not pass it stay
    silent. The parameter name ``type`` (which shadows the builtin) is kept for
    compatibility with keyword callers.

    Args:
        point: the Point just emitted.
        type: window state label at emission time ('A' or 'B' in getPlotData).
        i, j: bounds of the window; it covers indices range(i, j).
        xStats, yStats: statistics objects describing the window contents.
        verbose: print the diagnostics when true.
    """
    if not verbose:
        return
    print('adding point {} of type {} for interval {}'.format(point, type, (i, j)))
    print('xStats: {}'.format(xStats))
    print('yStats: {}'.format(yStats))
    print()
def getPlotData(timex_orig, y_orig, args):
    """Smooth a time series with a fixed-width sliding time window.

    A window of width ``args.delta`` (numpy timedelta unit ``args.delta_unit``)
    slides across the data from left to right. Each time the set of points
    inside it changes, a Point is emitted whose x is the mean timestamp of those
    points and whose y / ystd are the mean and sample stddev of their values.

    Args:
        timex_orig: sequence of datetimes (converted via dtToNpdt).
        y_orig: sequence of numbers, same length as timex_orig.
        args: object providing an int ``delta`` and a unit string ``delta_unit``.

    Returns:
        (timex, y, output): the timestamps as a datetime64 array and the values
        as a float64 array, both sorted chronologically, plus the list of Points.

    Raises:
        ValueError: if the inputs differ in length or ``args.delta`` is not
            positive (a zero-width window can empty mid-sweep and divide by 0).
    """
    n = len(timex_orig)
    if len(y_orig) != n:
        raise ValueError('got {} timestamps but {} values'.format(n, len(y_orig)))
    if args.delta <= 0:
        raise ValueError('window width must be positive, got delta={}'.format(args.delta))
    timex = np.array([dtToNpdt(dt) for dt in timex_orig], dtype=np.datetime64)
    y = np.array(y_orig, dtype=np.float64)
    output: list[Point] = []
    if n == 0:
        return (timex, y, output)
    # The sweep requires non-decreasing timestamps, but CSV rows need not be in
    # chronological order. A stable sort keeps equal-time rows in file order.
    order = np.argsort(timex, kind='stable')
    timex = timex[order]
    y = y[order]
    delta = np.timedelta64(args.delta, args.delta_unit)

    # Window states:
    #   'A': window is [timex[i], timex[i] + delta]; timex[j] is the first point past it.
    #   'B': window is [timex[j-1] - delta, timex[j-1]]; timex[i] is the first point in it.
    # In both states the points inside are exactly indices range(i, j).
    # xStats tracks timex[i:j] - timex[0] (offsets, since datetimes cannot be
    # summed); yStats tracks y[i:j].
    i, j = 0, 0
    window_kind = 'A'
    xStats = SummaryStats(zero=timex[0] - timex[0])
    yStats = BetterSummaryStats()

    # Initial window anchored at the earliest timestamp.
    while j < n and timex[j] - timex[i] <= delta:
        xStats.add(timex[j] - timex[0])
        yStats.add(y[j])
        j += 1
    point = Point(timex[0] + xStats.mean(), yStats.mean(), yStats.stddev())
    debugPointAddition(point, window_kind, i, j, xStats, yStats)
    output.append(point)

    # Slide right until the right edge has passed the last point.
    while j < n:
        gotNewPoint = False
        if window_kind == 'A':
            # Any rightward move drops timex[i] at once. distToA: how far until
            # the left edge reaches the next point; distToB: how far until the
            # right edge reaches timex[j].
            distToA = timex[i+1] - timex[i]
            distToB = timex[j] - timex[i] - delta
            xStats.remove(timex[i] - timex[0])
            yStats.remove(y[i])
            i += 1
            gotNewPoint = True
        else:
            distToA = timex[i] + delta - timex[j-1]
            distToB = timex[j] - timex[j-1]
        if distToB < distToA:
            # Right edge reaches timex[j] first: it enters the window.
            window_kind = 'B'
            xStats.add(timex[j] - timex[0])
            yStats.add(y[j])
            j += 1
            gotNewPoint = True
        else:
            # Left edge reaches timex[i] first (ties included); the point set is
            # unchanged, and timex[i] is dropped on the next iteration.
            window_kind = 'A'
        if gotNewPoint:
            point = Point(timex[0] + xStats.mean(), yStats.mean(), yStats.stddev())
            debugPointAddition(point, window_kind, i, j, xStats, yStats)
            output.append(point)
    return (timex, y, output)
def makePlot(timex, y, points, stddevMul):
    """Draw the raw samples and the smoothed curve, then call plt.show().

    Raw data is drawn as red dots on top of a blue line through the smoothed
    means. A band of +/- ``stddevMul`` standard deviations around the means is
    shaded.
    """
    if HAS_SNS:
        sns.set()
    centres = [pt.x for pt in points]
    means = np.array([pt.y for pt in points])
    halfWidth = stddevMul * np.array([pt.ystd for pt in points])
    plt.plot(centres, means, 'b.-', zorder=1)
    plt.plot(timex, y, 'ro', zorder=5)
    plt.fill_between(centres, means - halfWidth, means + halfWidth,
                     facecolor='#8080ff40')
    plt.tight_layout()
    plt.show()
def main():
    """Command-line entry point: parse arguments, read the CSV, smooth, plot."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('fpath', help='path to CSV file')
    parser.add_argument('--no-header', dest='has_header', action='store_false', default=True,
                        help='CSV file does not contain header')
    parser.add_argument('--time-index', type=int, default=0,
                        help='0-based column of the ISO-8601 timestamp (default: %(default)s)')
    parser.add_argument('--data-index', type=int, default=1,
                        help='0-based column of the numeric value (default: %(default)s)')
    parser.add_argument('--delta-unit', choices=['s', 'm', 'h', 'D'], default='D',
                        help='unit of --delta: s=seconds, m=minutes, h=hours, D=days'
                             ' (default: %(default)s)')
    parser.add_argument('--delta', type=int, default=14,
                        help='width of the sliding window in --delta-unit (default: %(default)s)')
    parser.add_argument('--stddev-mul', type=float, default=1.65,
                        help='half-height of the shaded band, in standard deviations'
                             ' (default: %(default)s)')
    args = parser.parse_args()
    # A zero or negative window cannot be swept; report it as a usage error.
    if args.delta <= 0:
        parser.error('--delta must be a positive integer')
    timex_orig, y_orig = readFile(args.fpath, args)
    timex, y, points = getPlotData(timex_orig, y_orig, args)
    makePlot(timex, y, points, args.stddev_mul)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment