-
-
Save thomasaarholt/55a6c06f7107b5a4bf7811c842fa22fb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import date | |
import numpy as np | |
from numpy.random import default_rng | |
import pandas as pd | |
import pymc as pm | |
import pymc.math | |
# Seeded generator so the synthetic data below are reproducible.
rng = default_rng(seed=42)

# Daily calendar from first_date to last_date, both endpoints included.
first_date, last_date = date(2023, 1, 1), date(2024, 1, 1)
dates = pd.date_range(start=first_date, end=last_date, freq="D").to_numpy()
def monthly_pattern(t: int | np.ndarray | pd.Series, day_offset: int): | |
"A wavy dummy data." | |
return np.sin(2 * np.pi * (t + day_offset) / 30.5) | |
def make_actual(days, offset, day_offset):
    """Shift the wavy data by a bit and add some noise.

    Evaluates ``monthly_pattern`` at day indices ``0 .. days - 1`` (shifted
    by ``day_offset``), lifts it by the constant ``offset``, and adds
    ``days`` draws of Normal(0, 0.3) noise from the module-level ``rng``.
    """
    day_index = np.arange(days)
    wave = monthly_pattern(day_index, day_offset=day_offset)
    noise = rng.normal(loc=0, scale=0.3, size=days)
    return offset + wave + noise
# Two synthetic daily series on the same calendar: group "A" waves around +2,
# group "B" waves around -2 with its wave shifted by 15 days. The generator
# is consumed in order, so A's noise is drawn before B's.
df1, df2 = (
    pd.DataFrame(
        {
            "date": dates,
            "group": group_name,
            "actual": make_actual(days=len(dates), offset=level, day_offset=shift),
        }
    )
    for group_name, level, shift in (("A", 2, 0), ("B", -2, 15))
)

# Long-format table: all of A's rows followed by all of B's rows.
df_groups = pd.concat([df1, df2], ignore_index=True)

# Per-row integer codes plus the unique labels (in first-seen order); the
# codes are used to gather per-date / per-group quantities in the model.
date_idx, date_labels = pd.factorize(df_groups["date"])
group_idx, group_labels = pd.factorize(df_groups["group"])
# Positional row numbers 0 .. n_rows - 1, used as the "index" coordinate.
index = np.arange(df_groups.shape[0])
def get_fourier_features(
    dates: pd.Series, n_order: int, period: float
) -> tuple[pd.DataFrame, np.ndarray]:
    """Build sine/cosine Fourier features from each date's day of the year.

    For every ``order`` in ``1 .. n_order`` two columns are emitted, in this
    order: ``sin_order_{order}`` and ``cos_order_{order}``, each evaluated at
    ``2 * pi * order * day_of_year / period``.

    Args:
        dates: Series with a ``.dt`` accessor; its values become the row index
            of the returned frame.
        n_order: Number of harmonics to generate.
        period: Cycle length in days.

    Returns:
        A pair ``(features, labels)``: the feature DataFrame (one row per
        entry of ``dates``, ``2 * n_order`` columns) and
        ``np.arange(2 * n_order)`` as integer labels for those columns.
    """
    base_angle = 2 * np.pi * (dates.dt.day_of_year / period)
    columns = {}
    for order in range(1, n_order + 1):
        for func in (np.sin, np.cos):
            columns[f"{func.__name__}_order_{order}"] = func(base_angle * order)
    features = pd.DataFrame(columns).set_index(dates)
    return features, np.arange(2 * n_order)
# Monthly Fourier series (period 30.5 days, 6 harmonics -> 6 sin + 6 cos
# columns). One row per entry of `dates` (2023-01-01 through 2024-01-01
# inclusive), so this is a 366 x 12 DataFrame indexed by date.
fourier, fourier_labels = get_fourier_features(
    pd.Series(dates), n_order=6, period=30.5
)
coords = {
    "index": index,  # 732 rows (366 dates x 2 groups)
    "date": date_labels,  # 366 days
    "group": group_labels,  # 2 groups
    "fourier": fourier_labels,  # 12 fourier features
}
with pm.Model(coords=coords):
    # Per-group baseline level.
    intercept = pm.Normal("intercept", mu=0, sigma=1, dims="group")
    # One vector of Fourier coefficients per group, so each group gets its
    # own monthly seasonal shape: dims ("group", "fourier").
    b_fourier_monthly = pm.Normal(
        name="b_fourier_monthly",
        mu=0,
        sigma=1,
        dims=("group", "fourier"),
    )
    # Row-align the Fourier design matrix with the "date" coordinate so that
    # row j corresponds to date_labels[j]. `.loc` raises KeyError instead of
    # silently misaligning if a label were missing.
    fourier_matrix = fourier.loc[date_labels].to_numpy()  # (date, fourier)
    # (group, fourier) @ (fourier, date) -> (group, date)
    seasonality_monthly = pm.Deterministic(
        name="seasonality_monthly",
        var=pymc.math.dot(b_fourier_monthly, fourier_matrix.T),
        dims=("group", "date"),
    )
    pm.Normal(
        name="likelihood",
        # Gather both terms per observation row: the intercept by group, and
        # the seasonality by (group, date). Both are then indexed along
        # "index", matching the observed data.
        mu=(
            intercept[group_idx]
            + seasonality_monthly[group_idx, date_idx]
        ),
        sigma=1,
        observed=df_groups["actual"],  # shape = (732, )
        dims="index",
    )
    trace = pm.sample(1000)
    # With extend_inferencedata=True the posterior-predictive draws are
    # added to `trace`, which is returned.
    trace = pm.sample_posterior_predictive(
        trace,
        extend_inferencedata=True,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment