Created
March 8, 2017 08:13
-
-
Save nicolasdespres/d380cc43ebb1eb2d780b816fde294869 to your computer and use it in GitHub Desktop.
Iterator over batches of shuffled indices from a range.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class iter_shuffle_batch_range(Iterator):
    """Iterate over batches of (potentially shuffled) indices in a range.

    The range is seen as a list of indices (potentially shuffled) that is
    repeated `num_cycles` times. This iterator sequentially returns chunks
    of `batch_size` items of this sequence.

    This object tries to follow the same API as the `tf.train.shuffle_batch`
    functions so that it can easily be plugged into other TensorFlow
    routines. The `num_epochs` argument is a notable exception: it has been
    renamed `num_cycles` for clarity.

    Example:

    >>> it = iter_shuffle_batch_range(5, batch_size=2, num_cycles=3,
    ...                               shuffle=False)
    >>> for batch in it:
    ...     print("{}|{:.02f}| {!r}".format(it.step, it.epoch, batch))
    0|0.40| [0, 1]
    1|0.80| [2, 3]
    2|1.20| [4, 0]
    3|1.60| [1, 2]
    4|2.00| [3, 4]
    5|2.40| [0, 1]
    6|2.80| [2, 3]
    >>> list(iter_shuffle_batch_range(5, batch_size=2, num_cycles=3,
    ...                               shuffle=False,
    ...                               allow_smaller_final_batch=True))
    [[0, 1], [2, 3], [4, 0], [1, 2], [3, 4], [0, 1], [2, 3], [4]]

    Args:
        `r`: The range to iterate on (e.g. `5` yields indices 0..4, as in
            the examples above); forwarded to `cycle_range`.
        `batch_size`: Number of rows in each batch. Must be given
            explicitly: the default `None` cannot be used to compute
            `steps_per_epoch`.
        `num_cycles`: How many times to iterate over the entire range.
        `shuffle`: Whether to shuffle the items (use `random.seed` to set
            the random seed).
        `allow_smaller_final_batch`: Whether to return a smaller final
            batch or to discard the last items.

    Output:
        A list of `batch_size` indices.
    """

    def __init__(self, r, batch_size=None, shuffle=True,
                 allow_smaller_final_batch=False, num_cycles=1):
        # Remember the constructor arguments so that `reset` can rebuild the
        # pipeline as a regular method. (Binding a closure onto the instance
        # with `types.MethodType` creates an instance -> bound method ->
        # instance cycle, and a shallow copy would keep a `reset` bound to
        # the original object.)
        self._config = (r, batch_size, shuffle, allow_smaller_final_batch,
                        num_cycles)
        self.reset()

    def reset(self):
        """Rewind this iterator to the beginning of the repeated sequence.

        Rebuilds the underlying `cycle_range` -> (optional `shuffle_iter`)
        -> `batch_iter` pipeline from the constructor arguments.
        """
        (r, batch_size, shuffle, allow_smaller_final_batch,
         num_cycles) = self._config
        self._src = cycle_range(r, times=num_cycles)
        self._shuf = shuffle_iter(self._src) if shuffle else None
        self._it = batch_iter(
            self._shuf if shuffle else self._src,
            batch_size=batch_size,
            allow_smaller_final_batch=allow_smaller_final_batch)
        # True division on purpose: an epoch need not be a whole number of
        # batches (5 items / 2 per batch = 2.5 steps per epoch).
        # NOTE(review): with the default batch_size=None this division
        # raises TypeError (unless batch_iter already rejected None above).
        self._steps_per_epoch = self.size / batch_size

    def __next__(self):
        """Return the next batch from the underlying batch iterator."""
        return next(self._it)

    @property
    def shuffle(self):
        """Whether samples are shuffled between each iteration."""
        return self._shuf is not None

    @property
    def batch_size(self):
        """The size of each batch."""
        return self._it.batch_size

    @property
    def size(self):
        """The length of the sequence as passed to the constructor."""
        return self._src.size

    @property
    def allow_smaller_final_batch(self):
        """Whether the final batch is allowed to be smaller than `batch_size`.

        The final batch is the last batch at the end of the repeated sequence.
        """
        return self._it.allow_smaller_final_batch

    @property
    def num_cycles(self):
        """How many times the range is repeated."""
        return self._src.times

    def __len__(self):
        """Return the total number of iterations (batches)."""
        return len(self._it)

    @property
    def steps_per_epoch(self):
        """The number of iterations per epoch (`size / batch_size`).

        May not be an integer.
        """
        return self._steps_per_epoch

    @property
    def epoch(self):
        """Return the current, possibly fractional, epoch.

        Computed as `(step + 1) / steps_per_epoch`, i.e. the number of
        batches consumed so far expressed in epochs.
        """
        return (self.step + 1) / self._steps_per_epoch

    @property
    def step(self):
        """Return the current iteration step.

        In the example above this is 0 after the first batch, matching what
        `enumerate` would yield for the same batch.
        """
        return self._it.step
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment