Created April 27, 2017 14:45 (gist msalvaris/ffa960e62af68ab3c70dfb1848aa5755)
A simple function to generate variable window sequences for LSTMs
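The helper `_generate_start_and_end` in the code builds its slice bounds by chaining iterators. As a sketch for intuition, the same bounds can be written in closed form, assuming `num_elements >= sequence_length`. `window_bounds` is a hypothetical name used only for this illustration. For 5 rows and a maximum length of 3, the window lengths go 1, 2, 3, 3, 3, 2, 1: expand, roll, shrink.

```python
def window_bounds(num_elements, sequence_length):
    """Return the (start, stop) slice bounds of every window, in order (hypothetical helper)."""
    bounds = []
    # One step per row entering the window, plus sequence_length - 1 shrinking steps
    for step in range(num_elements + sequence_length - 1):
        start = max(0, step - sequence_length + 1)  # stays 0 while the window is still growing
        stop = min(step + 1, num_elements)          # capped once the end of the array is reached
        bounds.append((start, stop))
    return bounds


bounds = window_bounds(5, 3)
print(bounds)                                   # [(0, 1), (0, 2), (0, 3), (1, 4), (2, 5), (3, 5), (4, 5)]
print([stop - start for start, stop in bounds])  # [1, 2, 3, 3, 3, 2, 1]
```

The iterator-based version in the gist yields the same pairs lazily, without computing `max`/`min` at each step.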
from itertools import chain, repeat, zip_longest


def _generate_start_and_end(num_elements, sequence_length):
    """Yield (start, stop) slice bounds for an expanding, rolling, then shrinking window."""
    # Start index stays at 0 while the window grows, then advances by one per step
    start_gen = chain(repeat(0, times=sequence_length),
                      range(1, num_elements))
    # Stop index advances by one per step while new rows enter the window
    end_gen = chain(range(1, sequence_length),
                    range(sequence_length, num_elements))
    # end_gen is exhausted first; zip_longest pads it with num_elements, so stop stays
    # at the end of the array while start keeps advancing and the window shrinks
    for start, stop in zip_longest(start_gen, end_gen, fillvalue=num_elements):
        yield start, stop


def generate_variable_window(timeseries_array, sequence_length):
    """Generate a variable-length rolling window over timeseries_array.

    The window first expands one row at a time until it is sequence_length rows long.
    It then slides forward one row at a time, delivering windows of exactly
    sequence_length rows. When it reaches the end of timeseries_array, it shrinks
    one row at a time until only the last row remains.

    Parameters
    ----------
    timeseries_array : numpy.ndarray
        2D array where the rows are the samples and the columns are the features.
    sequence_length : int
        Maximum number of rows the window is allowed to grow to.

    Yields
    ------
    numpy.ndarray
        The slice timeseries_array[start:stop, :] for each window position.
    """
    num_elements = timeseries_array.shape[0]
    for start, stop in _generate_start_and_end(num_elements, sequence_length):
        yield timeseries_array[start:stop, :]
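Because the windows vary in length, they usually have to be brought to a common length before they can be stacked into one batch for an LSTM. The gist does not say how its windows are consumed. One common option, shown here as an assumption, is to left-pad each window with zeros and keep a boolean mask of the real rows. `pad_windows` is a hypothetical helper, and the `windows` list below is written out by hand to match what `generate_variable_window(x, 3)` yields for a 5-row array.

```python
import numpy as np


def pad_windows(windows, sequence_length, pad_value=0.0):
    """Left-pad variable-length 2D windows into one 3D batch (hypothetical helper).

    Returns (batch, mask): batch has shape (num_windows, sequence_length, num_features)
    and mask is True where a row holds real data rather than padding.
    """
    windows = list(windows)
    num_features = windows[0].shape[1]
    batch = np.full((len(windows), sequence_length, num_features), pad_value,
                    dtype=windows[0].dtype)
    mask = np.zeros((len(windows), sequence_length), dtype=bool)
    for i, window in enumerate(windows):
        length = window.shape[0]
        # Place the real rows at the end so the most recent sample is always last
        batch[i, sequence_length - length:, :] = window
        mask[i, sequence_length - length:] = True
    return batch, mask


x = np.arange(10, dtype=float).reshape(5, 2)  # 5 samples, 2 features
windows = [x[0:1], x[0:2], x[0:3], x[1:4], x[2:5], x[3:5], x[4:5]]
batch, mask = pad_windows(windows, 3)
print(batch.shape)              # (7, 3, 2)
print(mask.sum(axis=1).tolist())  # [1, 2, 3, 3, 3, 2, 1]
```

Left-padding keeps the newest sample in the final time step, which suits models that read the last LSTM output. Right-padding combined with a masking layer is an equally valid choice.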