Created
December 10, 2023 22:23
-
-
Save mzdravkov/9f5e2ffc7abe0a2b2f221b81c3d1534c to your computer and use it in GitHub Desktop.
Build time series sequences
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_sequences(time_series, valid_periods, categories, train_size, test_size):
    """
    Cut every series into (history, target) windows for forecasting.

    Each series is first reduced with ``cut_valid(series, valid_period)``.
    Target windows of length ``test_size`` are then laid back-to-back so that
    the last one ends exactly at the end of the valid series. Each target is
    paired with the ``train_size`` values immediately preceding it. A series
    of length ``L`` therefore yields ``(L - train_size) // test_size``
    non-overlapping targets: as many as fit while the earliest one still has a
    full history in front of it.

    A series too short for even one window (``L < train_size + test_size``)
    is left-padded with float32 zeros to exactly ``train_size + test_size``
    values and contributes a single window.

    Parameters
    ----------
    time_series : iterable of sequences
        The raw series (assumed 1-D; TODO confirm against callers).
    valid_periods : iterable
        Per-series value passed to ``cut_valid`` alongside the series.
    categories : iterable
        Per-series label, repeated once for every window that series yields.
    train_size : int
        Length of each history (input) window; must be >= 0.
    test_size : int
        Length of each target window; must be > 0.

    Returns
    -------
    X : np.ndarray
        Stacked history windows, one row of ``train_size`` values per window.
    y : np.ndarray
        Stacked target windows, one row of ``test_size`` values per window.
    final_categories : np.ndarray
        The category of the source series for each window.

    Raises
    ------
    ValueError
        If ``test_size`` is not positive or ``train_size`` is negative.
    """
    if test_size <= 0:
        raise ValueError(f"test_size must be positive, got {test_size}")
    if train_size < 0:
        raise ValueError(f"train_size must be non-negative, got {train_size}")

    window_len = train_size + test_size
    X = []
    y = []
    final_categories = []
    for ts, valid_period, category in zip(time_series, valid_periods, categories):
        valid_ts = cut_valid(ts, valid_period)
        size = len(valid_ts)
        if size < window_len:
            # Too short for a single window: left-pad with zeros so the real
            # data ends up at the end of the padded series.
            padding = np.zeros(window_len - size, dtype='float32')
            valid_ts = np.concatenate((padding, valid_ts))
            size = window_len

        # Targets tile the tail of the series; the first target starts at
        # size - n_windows * test_size, which is >= train_size, so every
        # history slice below is a full window and never wraps around.
        n_windows = (size - train_size) // test_size
        first_target_start = size - n_windows * test_size
        for target_start in range(first_target_start, size, test_size):
            X.append(valid_ts[target_start - train_size:target_start])
            y.append(valid_ts[target_start:target_start + test_size])
            final_categories.append(category)
    return np.array(X), np.array(y), np.array(final_categories)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment