Skip to content

Instantly share code, notes, and snippets.

@orhermansaffar
Last active July 19, 2023 14:28
Show Gist options
  • Save orhermansaffar/2bd2342c81026de1c09c97d66226eb46 to your computer and use it in GitHub Desktop.
Save orhermansaffar/2bd2342c81026de1c09c97d66226eb46 to your computer and use it in GitHub Desktop.
import pandas as pd
import datetime
from datetime import datetime as dt
from dateutil.relativedelta import *
class TimeBasedCV(object):
'''
Parameters
----------
train_period: int
number of time units to include in each train set
default is 30
test_period: int
number of time units to include in each test set
default is 7
freq: string
frequency of input parameters. possible values are: days, months, years, weeks, hours, minutes, seconds
possible values designed to be used by dateutil.relativedelta class
deafault is days
'''
def __init__(self, train_period=30, test_period=7, freq='days'):
self.train_period = train_period
self.test_period = test_period
self.freq = freq
def split(self, data, validation_split_date=None, date_column='record_date', gap=0):
'''
Generate indices to split data into training and test set
Parameters
----------
data: pandas DataFrame
your data, contain one column for the record date
validation_split_date: datetime.date()
first date to perform the splitting on.
if not provided will set to be the minimum date in the data after the first training set
date_column: string, deafult='record_date'
date of each record
gap: int, default=0
for cases the test set does not come right after the train set,
*gap* days are left between train and test sets
Returns
-------
train_index ,test_index:
list of tuples (train index, test index) similar to sklearn model selection
'''
# check that date_column exist in the data:
try:
data[date_column]
except:
raise KeyError(date_column)
train_indices_list = []
test_indices_list = []
if validation_split_date==None:
validation_split_date = data[date_column].min().date() + eval('relativedelta('+self.freq+'=self.train_period)')
start_train = validation_split_date - eval('relativedelta('+self.freq+'=self.train_period)')
end_train = start_train + eval('relativedelta('+self.freq+'=self.train_period)')
start_test = end_train + eval('relativedelta('+self.freq+'=gap)')
end_test = start_test + eval('relativedelta('+self.freq+'=self.test_period)')
while end_test < data[date_column].max().date():
# train indices:
cur_train_indices = list(data[(data[date_column].dt.date>=start_train) &
(data[date_column].dt.date<end_train)].index)
# test indices:
cur_test_indices = list(data[(data[date_column].dt.date>=start_test) &
(data[date_column].dt.date<end_test)].index)
print("Train period:",start_train,"-" , end_train, ", Test period", start_test, "-", end_test,
"# train records", len(cur_train_indices), ", # test records", len(cur_test_indices))
train_indices_list.append(cur_train_indices)
test_indices_list.append(cur_test_indices)
# update dates:
start_train = start_train + eval('relativedelta('+self.freq+'=self.test_period)')
end_train = start_train + eval('relativedelta('+self.freq+'=self.train_period)')
start_test = end_train + eval('relativedelta('+self.freq+'=gap)')
end_test = start_test + eval('relativedelta('+self.freq+'=self.test_period)')
# mimic sklearn output
index_output = [(train,test) for train,test in zip(train_indices_list,test_indices_list)]
self.n_splits = len(index_output)
return index_output
def get_n_splits(self):
"""Returns the number of splitting iterations in the cross-validator
Returns
-------
n_splits : int
Returns the number of splitting iterations in the cross-validator.
"""
return self.n_splits
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment