Created
April 14, 2021 11:46
-
-
Save scravy/107833c7d8b9670ab762d26393dc1922 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Given a DataFrame like this:

         date name         mx         my        mz
0  2021-01-01  foo  46.676136  54.017562  1.882048
1  2021-01-01  bar  81.669122  51.525996  8.062095
2  2021-01-01  qux  54.584057  53.621155  9.475698
3  2021-01-02  foo  84.945618  51.289542  9.132064
4  2021-01-02  bar  87.438244  50.643675  1.881121
5  2021-01-02  qux  91.891133  50.347743  8.809469

we're interested in various aggregations over a sliding window of N days
for each name and metric, e.g. for N=3 days and min, max, avg aggregations:

    date, name, mx_d3_avg, mx_d3_min, mx_d3_max, my_d3_avg, ...

The resulting DataFrame would have the same number of rows as the input
DataFrame, but considerably more columns.
""" | |
import random as rd | |
from datetime import date | |
from typing import Dict, Callable, List, Tuple, Union | |
import numpy as np | |
import pandas as pd | |
# Any plain Python number accepted by the sampling helpers.
Num = Union[int, float]


def rand(scale: Num, y: Num) -> float:
    """Return a pseudo-random float of the form ``u * scale + y``.

    ``u`` is a single draw from ``random.random()``, i.e. a float in
    ``[0.0, 1.0)``, so ``scale`` stretches the unit interval and ``y``
    shifts it.

    Args:
        scale: Width of the interval the sample is stretched to.
        y: Offset added to the stretched sample.

    Returns:
        The scaled and shifted sample.
    """
    unit_sample = rd.random()
    return unit_sample * scale + y
def r(scale, y) -> Callable[[], float]:
    """Return a zero-argument sampler bound to ``scale`` and ``y``.

    Each call of the returned function forwards to ``rand(scale, y)`` and
    therefore produces a fresh random value.
    """

    def draw() -> float:
        return rand(scale, y)

    return draw
# Group labels; the sample frame gets one row per label for every date.
names = ['foo', 'bar', 'qux', 'zee']

# Ten consecutive days: 2021-01-01 through 2021-01-10.
dates = [date(2021, 1, day_of_month) for day_of_month in range(1, 11)]

# Metric column name -> zero-argument sampler yielding one random value.
metrics: Dict[str, Callable[[], float]] = {
    'mx': r(90, 10),  # scale 90, offset 10
    'my': r(5, 50),   # scale 5, offset 50
    'mz': r(10, 0),   # scale 10, offset 0
}

# Aggregations applied to every metric over each rolling window.
aggs = [np.min, np.max, np.average]
def mk_frame() -> pd.DataFrame:
    """Build a random sample frame with one row per (date, name) pair.

    Rows are ordered by date first and by name within each date. Every row
    holds a freshly drawn value for each entry of ``metrics``; the columns
    are ``date``, ``name`` and then the metric names in ``metrics`` order.
    """
    samplers = metrics.values()

    def sample_row(day: date, label: str) -> tuple:
        # Draw the metrics in dict order so values line up with the columns.
        return (day, label, *(draw() for draw in samplers))

    # noinspection PyTypeChecker
    rows: List[Tuple[date, str, float, float, float]] = [
        sample_row(day, label) for day in dates for label in names
    ]
    return pd.DataFrame(rows, columns=['date', 'name', *metrics])
def main(window: int = 3) -> None:
    """Build the sample frame and print rolling aggregates per name and metric.

    The frame from ``mk_frame()`` is grouped by ``name``; within each group
    every aggregation in ``aggs`` is computed for every column in ``metrics``
    over a rolling window, and the resulting frame is printed.

    Args:
        window: Number of consecutive rows (within one name) that form a
            window. Defaults to 3, the value that was previously hard-coded.
    """
    df: pd.DataFrame = mk_frame()

    # Same list of aggregations for every metric column.
    agg_spec = {metric: aggs for metric in metrics}

    # NOTE: an integer window counts rows, not calendar days. The two coincide
    # here only because mk_frame() emits exactly one row per name for each of
    # the consecutive dates.
    g: pd.DataFrame = (
        df.groupby('name')
        .rolling(window=window, on='date')
        .agg(agg_spec)
    )

    # Sanity check: rolling yields one output row per input row.
    assert len(df) == len(g)
    print(g)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment