Simple TS_Clustering class with some plotting options
"""TimeSeries clustering algorithm transformer.""" | |
# %% | |
################### | |
# IMPORT STATMENT # | |
################### | |
from datetime import datetime, timedelta | |
from typing import TYPE_CHECKING, Optional, Self | |
import numpy as np | |
import pandas as pd | |
if TYPE_CHECKING: | |
import holidays | |
class TSClustering:
    """Simple timeseries clustering class.

    Parameters
    ----------
    data : pd.Series
        Series to cluster

    Attributes
    ----------
    MIN_OBSERVATIONS : int
        The minimum number of observations in a cluster
    CLUSTERING_PARAM : Dict[str, Any]
        The clustering parameters passed to tslearn
    """

    MIN_OBSERVATIONS = 6
    CLUSTERING_PARAM = {
        'n_clusters': 9,  # 9 was found, somewhat arbitrarily, to work best
        'metric': "dtw",
        'random_state': 42,
    }
    WEEK_MAPPING = {
        '0000000': "NORMAL",
        '0000001': "SUNDAY",
        '0000010': 'SATURDAY',
        '0000100': 'FRIDAY',
        '0001000': 'THURSDAY',
        '0010000': 'WEDNESDAY',
        '0100000': 'TUESDAY',
        '1000000': 'MONDAY',
    }

    def __init__(self: Self, data: pd.Series):
        # Save the attributes that should not change.
        self.data = data
        self.fd = data.first_valid_index()
        # Fix the end date to the next Sunday: we cluster per week, so the
        # output has to be extended to the next available Sunday.
        last_id = data.last_valid_index()
        self.sd = self.get_next_weekday(last_id, 6)

    @property
    def index(self: Self) -> pd.DatetimeIndex:
        """Return the final index."""
        return pd.date_range(self.fd, self.sd, freq="D", name="ds")

    # ... Public Facing Methods ...
    def compute_clusters(self: Self) -> Self:
        """Compute the timeseries clusters.

        See Also
        --------
        _transform_input
        _clean_clusters
        _compute_clusters
        _compute_bankholidays
        """
        self._transform_input()
        self._clean_clusters()
        self._compute_clusters()
        self._compute_bankholidays()
        return self

    def create_output(self: Self) -> Self:
        """Wrangle the output into the correct shape.

        Attributes
        ----------
        output : pd.DataFrame
            Output dataframe with ds / Cluster
        """
        # --- Variables ---
        key = ["year", "week"]
        # --- Script ---
        # Get the final index
        date = self.index.isocalendar()
        date = date[key]
        # Get the computed data
        data = pd.concat([self.labels, self.holidays])
        # Join together to get "ds" back in the index
        data = date.join(data, on=key)
        data = data.drop(columns=key)
        # Fill the missing values.
        # First fill all days at the end with the last known label.
        # This prevents the last label from being a bank holiday label,
        # which can happen in end-of-year situations.
        check = np.bitwise_and(
            data["Cluster"].isna(),
            data.index > data.last_valid_index(),
        )
        last_label = self.labels.iloc[-1].values[0]
        data.loc[check, "Cluster"] = last_label
        # Then fill the rest
        data = data.ffill()
        self.output = data.copy()
        return self

    # ... Private API ...
    def _transform_input(self: Self) -> Self:
        """Transform the input series into a usable format.

        Compute for the provided pd.Series the % of volume per weekday
        and save that. Apply a basic filtering to drop completely empty
        weeks from the output. Finally, split the regular weeks without
        a bank holiday from the other weeks.

        Attributes
        ----------
        data : pd.Series
            The prepared data
        """
        # --- Variables ---
        df = self.data.copy()
        groupby = ['year', 'week', 'htype']
        # --- Script ---
        # First reindex to get full weeks
        df = df.reindex(self.index, fill_value=0)
        # Compute the percentage per day of the week
        index = df.index.isocalendar()
        df = df.groupby(by=[index["year"], index["week"]])
        df = df.apply(self.week_average)
        df = df.to_frame()
        # Add the week type to be able to filter bank holidays
        t = df.index.get_level_values("ds")
        df["htype"] = [self.get_holidayweek(x) for x in t]
        df["htype"] = df["htype"].map(self.WEEK_MAPPING)
        # Group all values of a week in a list
        df = df.groupby(by=groupby)
        df = df["stores"].apply(list)
        self.data = df.copy()
        return self

    def _clean_clusters(self: Self) -> Self:
        """Clean and split the prepared clustering.

        Attributes
        ----------
        feat : pd.DataFrame
            The data to cluster
        holidays : pd.DataFrame
            Data that does not need to be clustered because the week
            contains a bank holiday.
        """
        # --- Variables ---
        df = self.data.copy()
        temp = df.index.get_level_values('htype')
        # --- Script ---
        # Split the data into weeks to cluster and bank holiday weeks
        self.holidays = df[temp != "NORMAL"].copy()
        # Drop all weeks that sum to zero.
        to_cluster = df[temp == "NORMAL"].copy()
        week_sums = to_cluster.apply(lambda x: np.nansum(x))
        to_cluster = to_cluster[week_sums != 0]
        to_cluster = to_cluster.reset_index("htype", drop=True)
        self.feat = to_cluster.copy()
        return self

    def _compute_clusters(self: Self) -> Self:
        """Compute the time series clusters.

        See Also
        --------
        tslearn.utils.to_time_series_dataset
        tslearn.clustering.TimeSeriesKMeans
        clean_output

        Attributes
        ----------
        labels : pd.DataFrame
            The provided output labels per week
        """
        # --- Import ---
        from tslearn.clustering import TimeSeriesKMeans
        from tslearn.utils import to_time_series_dataset
        # --- Script ---
        # Wrap the input to the correct shape
        to_cluster = self.feat.copy()
        mySeries = to_time_series_dataset(to_cluster.values)
        km = TimeSeriesKMeans(**self.CLUSTERING_PARAM)
        labels = km.fit_predict(mySeries)
        # Prepare output
        labels = pd.Series(
            labels,
            index=to_cluster.index,
            name="Cluster",
        )
        labels = self.clean_output(labels, self.MIN_OBSERVATIONS)
        self.labels = labels.to_frame()
        return self

    def _compute_bankholidays(self: Self) -> Self:
        """Compute the ts clusters corresponding to weeks with bank holidays.

        Attributes
        ----------
        holidays : pd.DataFrame
        """
        df = self.holidays.copy()
        df = df.to_frame()
        # Keep only the index; the holiday type itself becomes the cluster label.
        df = df[[]]
        df = df.reset_index("htype")
        df.columns = ["Cluster"]
        self.holidays = df.copy()
        return self

    # ... Helper Methods ...
    @staticmethod
    def week_average(x: pd.Series) -> pd.Series:
        """Compute each day's share (%) of the weekly total."""
        if x.sum() == 0:
            return x
        return np.round(x * 100 / x.sum(), 2)

    @staticmethod
    def clean_output(
        labels: pd.Series,
        min_observations: int = 6,
    ) -> pd.Series:
        """Clean the computed timeseries clusters.

        Remove all the very small clusters (replacing them with "Other")
        and relabel the remaining clusters with an increasing number.
        The missing "regular" weeks are filled later, in create_output,
        with the last known cluster next to them.

        Parameters
        ----------
        labels : pd.Series
            The series of labels coming out of the timeseries clustering algorithm
        min_observations : int, by default=6
            The minimum number of observations a cluster needs before it is kept.
        """
        df = labels.copy()
        # Find all cluster labels with a low occurrence rate
        # (6 is the lowest number of bank holiday weeks)
        low = df.value_counts() / 7
        low = low[low < min_observations]
        low = list(low.index)
        # Replace all small clusters with "Other"
        replace = {k: "Other" for k in low}
        df = df.replace(replace)
        # Relabel all values to get an increasing number
        labels = df.unique().tolist()
        replace = {v: f"NORMAL_{k}" for k, v in enumerate(labels)}
        df = df.replace(replace)
        return df

    @staticmethod
    def get_holidayweek(
        d: datetime,
        hds: Optional["holidays.countries.belgium.Belgium"] = None,
    ) -> str:
        """Retrieve the bank holidays of a day's week as a string.

        Parameters
        ----------
        d : datetime
            A date in datetime format
        hds : holidays dict
            Output from the holidays package with all holidays
            for a specific country

        Returns
        -------
        output : str
            String in the format "0000000" (Monday to Sunday)
            with 0 = no bank holiday
            and 1 = a bank holiday
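
        Examples
        --------
        Illustrative only; assumes the ``holidays`` package lists
        1 May 2023 (Labour Day, a Monday) as a Belgian public holiday.

        >>> TSClustering.get_holidayweek(datetime(2023, 5, 3))
        '1000000'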
""" | |
# --- Import --- | |
from datetime import timedelta | |
import holidays | |
# --- Checks --- | |
if hds is None: hds = holidays.Belgium() | |
# --- Script --- | |
wd = d.weekday() | |
wds = [d + timedelta(days=int(i)) for i in range(-wd, 7 - wd)] | |
wdhs = [str(int(i in hds)) for i in wds] | |
wdhs = "".join(wdhs) | |
return wdhs | |
    @staticmethod
    def get_next_weekday(
        ori: datetime,
        weekday: int,
    ) -> datetime:
        """Get the next given weekday on or after the origin.

        Parameters
        ----------
        ori : datetime
            The point of origin from which to advance
        weekday : int
            Which weekday to retrieve (0 = Monday ... 6 = Sunday)
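
        Examples
        --------
        Illustrative only: 30 November 2023 is a Thursday, so the next
        Sunday (weekday 6) is 3 December 2023.

        >>> TSClustering.get_next_weekday(datetime(2023, 11, 30), 6)
        datetime.datetime(2023, 12, 3, 0, 0)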
""" | |
return ori + timedelta(days=weekday + ori.weekday()) | |
    # ... Plotting ...
    def plot_summary_table(
        self: Self,
        color: bool = True,
    ) -> None:
        """Plot the found clusters.

        Parameters
        ----------
        color : bool, by default=True
            Whether the output should be colored
        """
        # --- Import ---
        from IPython.display import display

        # --- Function ---
        def color_unique(df):
            """Make a dataframe with a unique color for each value."""
            # --- Import ---
            import seaborn as sns
            # --- Script ---
            # Get the unique values
            a, unique = pd.factorize(df.stack())
            # Generate a color palette with as many colors
            # as there are unique values
            palette = sns.color_palette("pastel", len(unique)).as_hex()
            out = pd.DataFrame(
                a.reshape(df.shape),
                index=df.index,
                columns=df.columns,
            )
            out = out.replace(dict(enumerate(palette)))
            return out.radd('background-color: ')

        # --- Script ---
        test = self.output
        test = pd.pivot_table(
            data=test,
            index=test.index.isocalendar()["week"],
            columns=test.index.isocalendar()["year"],
            values="Cluster",
            aggfunc=set,
        )
        test = test.astype(str)  # convert everything to string
        if color:
            display(test.style.apply(color_unique, axis=None))
        else:
            display(test)

    def plot_clusters(self: Self) -> None:
        """Plot all the different weeks in subplots.

        Plot every week and how the orders are spread over it, with one
        subplot per cluster in the final result. Useful for comparing
        the results with the expectations.

        Plots a Matplotlib subplot.
        """
        # --- Import ---
        import matplotlib.pyplot as plt
        # --- Variables ---
        clusters = self.data.to_frame()
        output = self.output
        # --- Script ---
        # Get all the needed data in one single dataframe
        index = [
            output.index.isocalendar()["week"],
            output.index.isocalendar()["year"],
        ]
        output = output.pivot_table(
            index=index,
            values="Cluster",
            aggfunc="first",
        )
        mySeries = clusters.join(output)
        # Convert to a dict of lists
        mySeries = mySeries.groupby(by="Cluster").agg(list)
        mySeries = mySeries.to_dict()["stores"]
        # Define the number of subplots to make
        labels = list(mySeries.keys())
        amnt_options = len(labels)
        width = 4
        depth = np.ceil(amnt_options / width)
        fig, axs = plt.subplots(
            ncols=width,
            nrows=int(depth),
            figsize=(25, 25),
        )
        fig.suptitle('Clusters')
        # For every label, plot every series with that label
        row_i, column_j = 0, 0
        for label in mySeries.keys():
            # Plot each series
            series = mySeries[label]
            series = np.array([row for row in series if len(row) == 7])
            for i in range(len(series)):
                axs[row_i, column_j].plot(series[i], c="gray", alpha=0.4)
            # Plot the average line
            if len(series) > 0:
                avg = np.average(np.vstack(series), axis=0)
                axs[row_i, column_j].plot(avg, c="red")
            # Set the title
            axs[row_i, column_j].set_title(label)
            # Go to the next plot
            column_j += 1
            if column_j % width == 0:
                row_i += 1
                column_j = 0
        plt.show()
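A minimal usage sketch of the class above. The date range and the random "stores" volumes are made up for illustration, and the snippet assumes tslearn and holidays are installed (plus seaborn, matplotlib and IPython for the plotting helpers):

import numpy as np
import pandas as pd

# Hypothetical daily volume series; the class expects the name "stores"
# because _transform_input and plot_clusters look that column up by name.
idx = pd.date_range("2022-01-03", "2023-11-26", freq="D", name="ds")
rng = np.random.default_rng(42)
data = pd.Series(rng.poisson(100, len(idx)), index=idx, name="stores")

# Fit the weekly clusters and build the daily output frame.
tsc = TSClustering(data)
tsc.compute_clusters().create_output()

print(tsc.output.head())        # daily "ds" index with a "Cluster" label
tsc.plot_summary_table()        # colored week-by-year summary table
tsc.plot_clusters()             # one subplot per cluster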