Created
March 8, 2017 08:11
-
-
Save nicolasdespres/81689421f56b86a315a81f19d301508a to your computer and use it in GitHub Desktop.
Iterate over shuffled batches of sliding windows on a sequence.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class iter_shuffle_batch_window(Iterator):
    """Iterate batches of windows over an in-memory sequence of data.

    Each step pulls one batch of window positions from an
    `iter_shuffle_batch_range` iterator and, for every value in `shifts`,
    packs the windows taken at those positions offset by that shift.

    Args:
      `data`: The in-memory sequence to take windows from (must support
              `len`).
      `batch_size`, `shuffle`, `allow_smaller_final_batch`, `num_cycles`:
              Forwarded unchanged to `iter_shuffle_batch_range`.
      `window_size`: The size of the window (must be a positive int and
                     smaller than the data).
      `window_alignment`: How to align the window around the point: "left",
                          "right", or "center".
      `shifts`: A list of non-negative offsets to shift the windows by.
                By default it is `[0]`, but you can set it to `[0, 1]` to
                generate batches of inputs and targets windows.
      `packer`: A function to pack each batch of windows into a container.
                By default it is `list` but you can set it to `np.stack`
                to get a numpy array.

    Raises:
      TypeError: if `window_size` is not an int.
      ValueError: if `window_size` is not positive, or if `shifts` is empty
                  or contains a negative value.
    """

    def __init__(self, data,
                 batch_size=None,
                 window_size=None,
                 shuffle=True,
                 allow_smaller_final_batch=False,
                 num_cycles=1,
                 window_alignment="left",
                 shifts=None,
                 packer=list):
        self.data = data
        if not isinstance(window_size, int):
            raise TypeError("window_size must be int, not {}"
                            .format(type(window_size).__name__))
        if window_size <= 0:
            raise ValueError("window_size must be positive")
        self._window_size = window_size
        self._window_alignment = window_alignment
        self._shifts = [0] if shifts is None else shifts
        # Explicit raise rather than `assert`: assertions are removed when
        # Python runs with -O, which would let negative shifts through.
        if not all(s >= 0 for s in self._shifts):
            raise ValueError("shifts value must be all positive or null")
        # `default=None` turns the cryptic "max() arg is an empty sequence"
        # into an error that names the offending argument.
        largest_shift = max(self._shifts, default=None)
        if largest_shift is None:
            raise ValueError("shifts must contain at least one value")
        # The range of valid positions is computed for a window widened by
        # the largest shift, so the most shifted window is covered too.
        self._range = window_range(len(self.data),
                                   size=self._window_size + largest_shift,
                                   alignment=self._window_alignment)
        self._take = partial(window_at,
                             size=self._window_size,
                             alignment=self._window_alignment)
        self._pack = packer
        # Batching options are kept so that `reset` can rebuild the
        # underlying iterator with exactly the constructor's settings.
        self._it_kwargs = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            allow_smaller_final_batch=allow_smaller_final_batch,
            num_cycles=num_cycles)
        self.reset()

    def reset(self):
        """Replace the underlying batch iterator with a fresh one.

        The new `iter_shuffle_batch_range` is built over the same range and
        with the batching options given to the constructor.
        """
        self._it = iter_shuffle_batch_range(self._range, **self._it_kwargs)

    def __len__(self):
        """Return `len` of the underlying batch iterator."""
        return len(self._it)

    @property
    def batch_size(self):
        """The size of each batch."""
        return self._it.batch_size

    @property
    def window_size(self):
        """The window size given to the constructor."""
        return self._window_size

    @property
    def window_alignment(self):
        """The window alignment given to the constructor."""
        return self._window_alignment

    @property
    def shifts(self):
        """The shifts applied to each window position (`[0]` by default)."""
        return self._shifts

    @property
    def allow_smaller_final_batch(self):
        """`allow_smaller_final_batch` of the underlying batch iterator."""
        return self._it.allow_smaller_final_batch

    @property
    def size(self):
        """`size` of the underlying batch iterator."""
        return self._it.size

    @property
    def num_cycles(self):
        """`num_cycles` of the underlying batch iterator."""
        return self._it.num_cycles

    @property
    def shuffle(self):
        """`shuffle` of the underlying batch iterator."""
        return self._it.shuffle

    @property
    def steps_per_epoch(self):
        """`steps_per_epoch` of the underlying batch iterator."""
        return self._it.steps_per_epoch

    @property
    def epoch(self):
        """`epoch` of the underlying batch iterator."""
        return self._it.epoch

    @property
    def step(self):
        """`step` of the underlying batch iterator."""
        return self._it.step

    def __next__(self):
        """Return the next batch as a list with one packed entry per shift.

        Whatever `next` raises on the underlying batch iterator (including
        `StopIteration` once it is exhausted) propagates to the caller.
        """
        # Take a batch of indices, each one locating a window. For every
        # shift, the window at each (index + shift) is sliced out of the
        # data and the resulting windows are packed together.
        batch = next(self._it)
        return [self._pack([self._take(self.data, i + s) for i in batch])
                for s in self._shifts]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment