Skip to content

Instantly share code, notes, and snippets.

@mattsgithub
Created October 1, 2019 17:49
Show Gist options
  • Save mattsgithub/12ac44606587d18da18471df03593a74 to your computer and use it in GitHub Desktop.
Save mattsgithub/12ac44606587d18da18471df03593a74 to your computer and use it in GitHub Desktop.
Cross validation is a bit more complicated when dealing with time series. This class provides dates to train, validate, and test on.
class SampleForwardChainCV(object):
    """Forward-chaining train/validate/test date splitter for time series.

    Each split consists of three consecutive, non-overlapping date windows:
    a training window that always starts at the first usable date and grows
    by one date per split, followed by a validation window and then a test
    window.  Every window is extended until it holds at least the configured
    minimum number of observations.

    Args:
        dates: Iterable of comparable, hashable date-like values.  They are
            de-duplicated and sorted before use.
        obs_count: Sequence of observation counts, sliced positionally.
            NOTE(review): presumably ``obs_count[k]`` is the count for the
            k-th *sorted unique* date; the constructor does not reorder it,
            so the caller must supply it in that order -- confirm.
        min_start_date: Earliest date a training window may start on.
            Defaults (when falsy) to the first date.
        max_end_date: Latest date any window may end on.  Defaults (when
            falsy) to the last date.
        n_min_train_obs: Minimum observations in a training window.
        n_min_validate_obs: Minimum observations in a validation window.
        n_min_test_obs: Minimum observations in a test window.

    Raises:
        IndexError: If ``dates`` is empty and a bound has to be defaulted.
    """

    def __init__(self,
                 dates,
                 obs_count,
                 min_start_date=None,
                 max_end_date=None,
                 n_min_train_obs=20,
                 n_min_validate_obs=20,
                 n_min_test_obs=20):
        self.dates = sorted(set(dates))
        self.obs_count = obs_count
        self.min_start_date = min_start_date or self.dates[0]
        self.max_end_date = max_end_date or self.dates[-1]
        self.n_min_train_obs = n_min_train_obs
        self.n_min_validate_obs = n_min_validate_obs
        self.n_min_test_obs = n_min_test_obs

    def split(self):
        """Yield successive ``(train, validate, test)`` date ranges.

        Yields:
            Tuples ``((train_start, train_end), (valid_start, valid_end),
            (test_start, test_end))`` of inclusive date bounds.  Iteration
            stops as soon as a window can no longer reach its minimum number
            of observations without going past ``max_end_date``.

        Each distinct training window is yielded exactly once, and no window
        ends after ``max_end_date``.  If ``min_start_date`` is not one of the
        dates, the first later date is used as the start.
        """
        # Imported locally so the class is self-contained in this file.
        from bisect import bisect_left, bisect_right

        dates = self.dates

        # prefix[k] holds the total of obs_count[:k], so any window total is
        # a single subtraction instead of re-summing a slice on every step.
        prefix = [0]
        for count in self.obs_count:
            prefix.append(prefix[-1] + count)
        n_counts = len(prefix) - 1

        # First usable index (first date >= min_start_date) and last usable
        # index (last date <= max_end_date).
        first_index = bisect_left(dates, self.min_start_date)
        last_index = bisect_right(dates, self.max_end_date) - 1

        def window_obs(start_index, end_index):
            """Total observations in positions start_index..end_index."""
            # Clamp like slicing does, so a short obs_count contributes 0.
            low = min(start_index, n_counts)
            high = min(end_index + 1, n_counts)
            return prefix[high] - prefix[low]

        def get_end_index(start_index, i, min_obs):
            """Return the first index >= i at which the window starting at
            start_index holds at least min_obs observations.

            Raises IndexError when that would require going past the last
            usable index.
            """
            while True:
                if i > last_index:
                    raise IndexError()
                if not window_obs(start_index, i) < min_obs:
                    return i
                i += 1

        previous_train_end = None
        for i in range(first_index, last_index + 1):
            # Only "ran out of dates" ends the iteration; any other error
            # (e.g. non-numeric counts) propagates to the caller.
            try:
                train_end_index = get_end_index(first_index, i,
                                                self.n_min_train_obs)
                if train_end_index == previous_train_end:
                    # Still inside the minimal training window: this would
                    # repeat the split that was just produced.
                    continue
                valid_start_index = train_end_index + 1
                valid_end_index = get_end_index(valid_start_index,
                                                valid_start_index,
                                                self.n_min_validate_obs)
                test_start_index = valid_end_index + 1
                test_end_index = get_end_index(test_start_index,
                                               test_start_index,
                                               self.n_min_test_obs)
            except IndexError:
                break

            previous_train_end = train_end_index
            yield (dates[first_index], dates[train_end_index]), \
                (dates[valid_start_index], dates[valid_end_index]), \
                (dates[test_start_index], dates[test_end_index])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment