Skip to content

Instantly share code, notes, and snippets.

@nicolasdespres
Created March 8, 2017 08:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nicolasdespres/81689421f56b86a315a81f19d301508a to your computer and use it in GitHub Desktop.
Save nicolasdespres/81689421f56b86a315a81f19d301508a to your computer and use it in GitHub Desktop.
Iterate over shuffled batches of sliding windows on a sequence.
class iter_shuffle_batch_window(Iterator):
    """Iterate shuffled batches of sliding windows over an in-memory sequence.

    Each step draws a batch of window start indices from
    ``iter_shuffle_batch_range``. For every shift ``s`` in ``shifts``, it
    extracts from ``data`` the window starting at ``index + s``.

    Args:
        data: The sequence to slide the window over. It must support
            ``len()`` and is passed as-is to ``window_at``.
        batch_size: The size of each batch; forwarded to
            ``iter_shuffle_batch_range``.
        window_size: The size of each window, a positive ``int`` (must be
            smaller than the data).
        shuffle: Forwarded to ``iter_shuffle_batch_range``.
        allow_smaller_final_batch: Forwarded to ``iter_shuffle_batch_range``.
        num_cycles: Forwarded to ``iter_shuffle_batch_range``.
        window_alignment: How to align the window around the point:
            ``"left"``, ``"right"``, or ``"center"``.
        shifts: Non-negative offsets by which to shift the windows. Defaults
            to ``[0]``; set it to ``[0, 1]`` to generate batches of input and
            target windows. Any iterable is accepted and copied into a list.
        packer: A function to pack each batch of windows into a container.
            Defaults to ``list``; set it to ``np.stack`` to get a numpy array.

    Raises:
        TypeError: If ``window_size`` is not an ``int`` (``bool`` rejected).
        ValueError: If ``window_size`` is not positive, or if ``shifts`` is
            empty or contains a negative value.
    """

    def __init__(self, data,
                 batch_size=None,
                 window_size=None,
                 shuffle=True,
                 allow_smaller_final_batch=False,
                 num_cycles=1,
                 window_alignment="left",
                 shifts=None,
                 packer=list):
        self.data = data
        # bool is a subclass of int; reject it so ``True`` is not taken as 1.
        if isinstance(window_size, bool) or not isinstance(window_size, int):
            raise TypeError("window_size must be int, not {}"
                            .format(type(window_size).__name__))
        if window_size <= 0:
            raise ValueError("window_size must be positive")
        self._window_size = window_size
        self._window_alignment = window_alignment
        # Copy into a list for two reasons: a generator would otherwise be
        # exhausted by the validation below before max() runs, and later
        # mutations by the caller must not affect this iterator.
        self._shifts = [0] if shifts is None else list(shifts)
        # Explicit checks rather than ``assert`` so they survive ``python -O``.
        if not self._shifts:
            raise ValueError("shifts must not be empty")
        if any(s < 0 for s in self._shifts):
            raise ValueError("shifts values must all be non-negative")
        # Size the range for a window extended by the largest shift,
        # presumably so every shifted window taken in __next__ still fits
        # inside the data.
        self._range = window_range(len(self.data),
                                   size=self._window_size + max(self._shifts),
                                   alignment=self._window_alignment)
        self._take = partial(window_at,
                             size=self._window_size,
                             alignment=self._window_alignment)
        self._pack = packer
        # Settings needed by reset() to rebuild the batch iterator.
        self._range_kwargs = dict(
            batch_size=batch_size,
            shuffle=shuffle,
            allow_smaller_final_batch=allow_smaller_final_batch,
            num_cycles=num_cycles,
        )
        self.reset()

    def reset(self):
        """Rebuild the underlying batch iterator from the constructor settings.

        This is a regular method, not a per-instance closure, so subclasses
        can override it and no instance <-> bound-method cycle is created.
        """
        self._it = iter_shuffle_batch_range(self._range, **self._range_kwargs)

    def __len__(self):
        """Return the length reported by the underlying batch iterator."""
        return len(self._it)

    @property
    def batch_size(self):
        """The size of each batch."""
        return self._it.batch_size

    @property
    def window_size(self):
        """The size of each window."""
        return self._window_size

    @property
    def window_alignment(self):
        """How windows are aligned ("left", "right", or "center")."""
        return self._window_alignment

    @property
    def shifts(self):
        """The list of offsets applied to each window start index."""
        return self._shifts

    @property
    def allow_smaller_final_batch(self):
        """Value reported by the underlying batch iterator."""
        return self._it.allow_smaller_final_batch

    @property
    def size(self):
        """Size reported by the underlying batch iterator."""
        return self._it.size

    @property
    def num_cycles(self):
        """Number of cycles reported by the underlying batch iterator."""
        return self._it.num_cycles

    @property
    def shuffle(self):
        """Shuffle setting reported by the underlying batch iterator."""
        return self._it.shuffle

    @property
    def steps_per_epoch(self):
        """Steps per epoch reported by the underlying batch iterator."""
        return self._it.steps_per_epoch

    @property
    def epoch(self):
        """Current epoch reported by the underlying batch iterator."""
        return self._it.epoch

    @property
    def step(self):
        """Current step reported by the underlying batch iterator."""
        return self._it.step

    def __next__(self):
        """Return a list holding one packed batch of windows per shift.

        Element ``k`` of the result is ``packer`` applied to the windows
        that start at ``i + shifts[k]``, one for each start index ``i``
        in the batch drawn from the underlying iterator. StopIteration
        from that iterator propagates unchanged.
        """
        batch = next(self._it)
        return [self._pack([self._take(self.data, i + s) for i in batch])
                for s in self._shifts]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment