Skip to content

Instantly share code, notes, and snippets.

@MCRE-BE
Last active November 30, 2023 05:21
Show Gist options
  • Save MCRE-BE/42f699b2a1631cbdd4b9095b100ee92e to your computer and use it in GitHub Desktop.
Save MCRE-BE/42f699b2a1631cbdd4b9095b100ee92e to your computer and use it in GitHub Desktop.
Simple TS_Clustering class with some plotting options
"""TimeSeries clustering algorithm transformer."""
# %%
####################
# IMPORT STATEMENT #
####################
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Self

import numpy as np
import pandas as pd

if TYPE_CHECKING:
    import holidays
class TSClustering:
"""Simple timeseries clustering class.
Parameters
----------
data : pd.Series
Series to cluster
Attributes
----------
MIN_OBSERVATIONS : int
The minimum amount of observations in a cluster
CLUSTERING_PARAM : Dict[str, Any]
The clustering parameters passed to tslearn
"""
MIN_OBSERVATIONS = 6
CLUSTERING_PARAM = {
'n_clusters': 9, # 9 is found to be arbitrarely the best
'metric': "dtw",
'random_state': 42,
}
WEEK_MAPPING = {
'0000000': "NORMAL",
'0000001': "SUNDAY",
'0000010': 'SATURDAY',
'0000100': 'FRIDAY',
'0001000': 'THURSDAY',
'0010000': 'WEDNESDAY',
'0100000': 'TUESDAY',
'1000000': 'MONDAY',
}
def __init__(self: Self, data: pd.Series):
# Save the attributes that should not change.
self.data = data
self.fd = data.first_valid_index()
# Fix to the next sunday as we cluster per week, so we need to extend the
# output to the next available sunday
last_id = data.last_valid_index()
self.sd = self.get_next_weekday(last_id, 6)
@property
def index(self: Self) -> pd.DatetimeIndex:
"""Return the final index."""
return pd.date_range(self.fd, self.sd, freq="D", name="ds")
# ... Public Facing Methods ...
def compute_clusters(self: Self) -> Self:
"""Compute the timeseries clusters.
See Also
--------
_transform_input
_clean_clusters
_compute_clusters
_compute_bankholidays
"""
self._transform_input()
self._clean_clusters()
self._compute_clusters()
self._compute_bankholidays()
return self
def create_output(self: Self) -> Self:
"""Wrangle the output to the correct shape.
Attributes
----------
output : pd.DataFrame
Output dataframe with ds / Cluster
"""
# --- Variables ---
key = ["year", "week"]
# --- Script ---
# Get the final index
date = self.index.isocalendar()
date = date[key]
# Get the computed data
data = pd.concat([self.labels, self.holidays])
# Join together to get "ds" back in the idnex
data = date.join(data, on=key)
data = data.drop(columns=key)
# Fill the missing
# First fill all days at the end with the last known label
# This to prevent the last label to be a bankholiday label
# This can happen in EOY situations
check = np.bitwise_and(
data["Cluster"].isna(),
data.index > data.last_valid_index(),
)
last_label = self.labels.iloc[-1].values[0]
data.loc[check, "Cluster"] = last_label
# Then fill the rst
data = data.ffill()
self.output = data.copy()
return self
# ... Private API ...
def _transform_input(self: Self) -> Self:
"""Transfrom the input dataframe into a usable format.
Compute for the provided pd.Series the % of volume per weekday
and save that. Apply a basic filtering to drop completly empty
weeks from the output. Finally, split the regular weeks without
a bank holiday from the other weeks.
Attributes
----------
data: pd.Series
The prepared data
"""
# --- Variables ---
df = self.data.copy()
groupby = ['year', 'week', 'htype']
# --- Script ---
# First reindex to get a full week
df = df.reindex(self.index, fill_value=0)
# Compute the percentage per day of the week
index = df.index.isocalendar()
df = df.groupby(by=[index["year"], index["week"]])
df = df.apply(self.week_average)
df = df.to_frame()
# Add the weektype to be able to filter bank holidays
t = df.index.get_level_values("ds")
df["htype"] = [self.get_holidayweek(x) for x in t]
df["htype"] = df["htype"].map(self.WEEK_MAPPING)
# Group all values of a week in a list
df = df.groupby(by=groupby)
df = df["stores"].apply(list)
self.data = df.copy()
return self
def _clean_clusters(self: Self) -> Self:
"""Clean and split the prepared clustering.
Attributes
----------
feat: pd.DataFrame
The data to cluster
holidays: pd.DataFrame
Data that does not need to be clustered as one of the days
is a holiday week.
"""
# --- Variables ---
df = self.data.copy()
temp = df.index.get_level_values('htype')
# --- Script ---
# Split the data in to cluster and bank holidays
self.holidays = df[temp != "NORMAL"].copy()
# Drop all weeks that sum to zero.
to_cluster = df[temp == "NORMAL"].copy()
filter = to_cluster.apply(lambda x: np.nansum(x))
to_cluster = to_cluster[filter != 0]
to_cluster = to_cluster.reset_index("htype", drop=True)
self.feat = to_cluster.copy()
return self
def _compute_clusters(self: Self) -> Self:
"""Compute the time series clusters.
See Also
--------
tslearn.utils.to_time_series_dataset
tslearn.clustering.TimeSeriesKMeans
clean_output
Attributes
----------
labels : pd.DataFrame
The provided output labels per week
"""
# --- Import ---
from tslearn.clustering import TimeSeriesKMeans
from tslearn.utils import to_time_series_dataset
# --- Script ---
# Wrap the input to the correct shape
to_cluster = self.feat.copy()
mySeries = to_time_series_dataset(to_cluster.values)
km = TimeSeriesKMeans(**self.CLUSTERING_PARAM)
labels = km.fit_predict(mySeries)
# Prepare output
labels = pd.Series(
labels,
index=to_cluster.index,
name="Cluster",
)
labels = self.clean_output(labels, self.MIN_OBSERVATIONS)
self.labels = labels.to_frame()
return self
def _compute_bankholidays(self: Self) -> Self:
"""Compute the ts clusters corresponding to weeks with bank holidays.
Attributes
----------
holidays : pd.DataFrame
"""
df = self.holidays.copy()
df = df.to_frame()
df = df[[]]
df = df.reset_index("htype")
df.columns = ["Cluster"]
self.holidays = df.copy()
return self
# ... Helper Methods ...
@staticmethod
def week_average(x: pd.Series) -> float:
"""Compute the weekly average."""
if x.sum().sum() == 0: return x
return np.round(x * 100 / x.sum(), 2)
@staticmethod
def clean_output(
labels: pd.Series,
min_observations: int = 6,
) -> pd.Series:
"""Clean the computed timeseries clusters.
Clean the computed clusters to remove all the very small clusters
and fill all the missing "regular" weeks with the last known cluster
next to it.
Parameters
----------
labels : pd.Series
The series of labels coming out of the timeseries clustering algorithm
min_observations: int, by default=6
The minimum number of observations in a cluster before we keep it.
"""
df = labels.copy()
# Find all cluster labels with low occurence rate
# 6 as lowest amount of bank holidays
low = df.value_counts() / 7
low = low[low < min_observations]
low = [x for x in low.index]
# Replace all small values with Other
replace = {
k: "Other"
for k in low
}
df = df.replace(replace)
# Replace all values to get an increasing number
labels = df.unique().tolist()
replace = {
v: f"NORMAL_{k}"
for k, v in enumerate(labels)
}
df = df.replace(replace)
return df
@staticmethod
def get_holidayweek(
d: datetime,
hds: Optional["holidays.countries.belgium.Belgium"] = None,
) -> str:
"""Function to retrieve holidays of the week as string for a day.
Parameters
----------
d : datetime
A date in datetime format
hds : holidays dict
Output from holidays package with all holidays
for a specific country
Returns
-------
output : str
String in format "0000000"
with 0 = No bank holiday
with 1 = A bank holiday
"""
# --- Import ---
from datetime import timedelta
import holidays
# --- Checks ---
if hds is None: hds = holidays.Belgium()
# --- Script ---
wd = d.weekday()
wds = [d + timedelta(days=int(i)) for i in range(-wd, 7 - wd)]
wdhs = [str(int(i in hds)) for i in wds]
wdhs = "".join(wdhs)
return wdhs
@staticmethod
def get_next_weekday(
ori: datetime,
weekday: int,
) -> datetime:
"""Get the next weekday from the origin.
Parameters
----------
ori : datetime,
The point of origin from which to advance
weekday : int
Which weekday to retrieve
"""
return ori + timedelta(days=weekday + ori.weekday())
# ... Plotting ...
def plot_summary_table(
self: Self,
color: bool = True,
) -> None:
"""Plot the found clusters.
Parameters
----------
color : bool, by default=True
Whether the output needs to be with color
"""
# --- Import ---
from IPython.display import display
# --- Function ---
def color_unique(df):
"""Make a dataframe with unique colors for each value."""
# --- Import ---
import seaborn as sns
# --- Script ---
# get unique values
a, unique = pd.factorize(df.stack())
# generate a color palette with as many colors
# as there are unique values
palette = sns.color_palette("pastel", len(unique)).as_hex()
out = pd.DataFrame(
a.reshape(df.shape),
index=df.index,
columns=df.columns,
)
out = out.replace(dict(enumerate(palette)))
return out.radd('background-color: ')
# --- Script ---
test = self.output
test = pd.pivot_table(
data=test,
index=test.index.isocalendar()["week"],
columns=test.index.isocalendar()["year"],
values="Cluster",
aggfunc=set,
)
test = test.astype(str) # convert everything to string
if color: display(test.style.apply(color_unique, axis=None))
else: display(test)
def plot_clusters(self: Self) -> None:
"""Plot in subplots all the different weeks.
Function to plot all the week's and how the orders are spread in the
week, with subplots with the final result. It's interesting to use and
compare the results with the expectations.
Plots a Matplotlib subplot.
"""
# --- Import ---
import matplotlib.pyplot as plt
# --- Variables ---
clusters = self.data.to_frame()
output = self.output
# --- Script ---
# Get all the needed data in one single dataframe
index = [
output.index.isocalendar()["week"],
output.index.isocalendar()["year"],
]
output = output.pivot_table(
index=index,
values="Cluster",
aggfunc="first",
)
mySeries = clusters.join(output)
# Convert to a dict of lists
mySeries = mySeries.groupby(by="Cluster").agg(list)
mySeries = mySeries.to_dict()["stores"]
# Define amount of plots to make
labels = list(mySeries.keys())
amnt_options = len(labels)
width = 4
depth = np.ceil(amnt_options / width)
fig, axs = plt.subplots(
ncols=width,
nrows=int(depth),
figsize=(25, 25),
)
fig.suptitle('Clusters')
# For every label, plot every series with that label
row_i, column_j = 0, 0
for label in mySeries.keys():
# Plot each series
series = mySeries[label]
series = np.array([row for row in series if len(row) == 7])
for i in range(len(series)):
axs[row_i, column_j].plot(series[i], c="gray", alpha=0.4)
# Plot the average line
if len(series) > 0:
avg = np.average(np.vstack(series), axis=0)
axs[row_i, column_j].plot(avg, c="red")
# Set title
axs[row_i, column_j].set_title(label)
# Go to the next plot
column_j += 1
if column_j % width == 0:
row_i += 1
column_j = 0
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment