Skip to content

Instantly share code, notes, and snippets.

@nicolasdespres
Created March 8, 2017 08:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nicolasdespres/d380cc43ebb1eb2d780b816fde294869 to your computer and use it in GitHub Desktop.
Save nicolasdespres/d380cc43ebb1eb2d780b816fde294869 to your computer and use it in GitHub Desktop.
Iterator over batches of shuffled indices from a range.
class iter_shuffle_batch_range(Iterator):
    """Iterate over batches of (potentially shuffled) indices in a range.

    The range is seen as a list of indices (potentially shuffled)
    that is repeated `num_cycles` times. This iterator sequentially returns
    chunks of `batch_size` items of this sequence.

    This object tries to follow the same API as the `tf.train.shuffle_batch`
    functions so that it can easily be plugged into other TensorFlow
    routines. The `num_epochs` argument is a noticeable exception because it
    has been renamed `num_cycles` for clarity.

    Example:

    >>> it = iter_shuffle_batch_range(5, batch_size=2, num_cycles=3,
    ...                               shuffle=False)
    >>> for batch in it:
    ...     print("{}|{:.02f}| {!r}".format(it.step, it.epoch, batch))
    0|0.40| [0, 1]
    1|0.80| [2, 3]
    2|1.20| [4, 0]
    3|1.60| [1, 2]
    4|2.00| [3, 4]
    5|2.40| [0, 1]
    6|2.80| [2, 3]
    >>> list(iter_shuffle_batch_range(5, batch_size=2, num_cycles=3,
    ...                               shuffle=False,
    ...                               allow_smaller_final_batch=True))
    [[0, 1], [2, 3], [4, 0], [1, 2], [3, 4], [0, 1], [2, 3], [4]]

    Args:
      `r`: A range to iterate on (forwarded as-is to `cycle_range`).
      `batch_size`: Number of rows in each batch. Must be provided: it is
        used as a divisor to compute `steps_per_epoch`.
      `num_cycles`: How many times I should iterate over the entire dataset.
      `shuffle`: Whether to shuffle the items (use `random.seed` to set
        the random seed).
      `allow_smaller_final_batch`: whether to return a smaller final batch
        or to discard the last items.

    Output:
      A list of `batch_size` indices.

    Raises:
      TypeError: if `batch_size` is left to `None`.
    """

    def __init__(self, r, batch_size=None, shuffle=True,
                 allow_smaller_final_batch=False, num_cycles=1):
        # Fail early with a readable message: a None batch size could only
        # end in an opaque "int / NoneType" TypeError inside reset().
        if batch_size is None:
            raise TypeError("batch_size must be specified (got None)")
        # Keep the constructor arguments so that reset() can rebuild the
        # underlying iterators. Private names are required because the
        # public ones are read-only properties.
        self._arg_range = r
        self._arg_batch_size = batch_size
        self._arg_shuffle = shuffle
        self._arg_allow_smaller_final_batch = allow_smaller_final_batch
        self._arg_num_cycles = num_cycles
        self.reset()

    def reset(self):
        """Rebuild the underlying iterators from the constructor arguments.

        Called once by the constructor. Calling it again re-creates the
        cycling, shuffling and batching iterators from scratch.
        """
        self._src = cycle_range(self._arg_range, times=self._arg_num_cycles)
        if self._arg_shuffle:
            self._shuf = shuffle_iter(self._src)
            upstream = self._shuf
        else:
            # `_shuf is None` is what the `shuffle` property reports on.
            self._shuf = None
            upstream = self._src
        self._it = batch_iter(
            upstream,
            batch_size=self._arg_batch_size,
            allow_smaller_final_batch=self._arg_allow_smaller_final_batch)
        # True division on purpose: an epoch may span a fractional number
        # of batches (e.g. 5 items / batches of 2 -> 2.5 steps per epoch).
        self._steps_per_epoch = self.size / self._arg_batch_size

    def __next__(self):
        """Return the next batch produced by the batching iterator."""
        return next(self._it)

    @property
    def shuffle(self):
        """Whether samples are shuffled between each iteration."""
        return self._shuf is not None

    @property
    def batch_size(self):
        """The size of each batch."""
        return self._it.batch_size

    @property
    def size(self):
        """The length of the sequence as passed to the constructor."""
        return self._src.size

    @property
    def allow_smaller_final_batch(self):
        """Whether the final batch is allowed to be smaller than `batch_size`.

        The final batch is the last batch at the end of the repeated sequence.
        """
        return self._it.allow_smaller_final_batch

    @property
    def num_cycles(self):
        """How many times the range is repeated."""
        return self._src.times

    def __len__(self):
        """Return the total number of iterations."""
        return len(self._it)

    @property
    def steps_per_epoch(self):
        """The number of iterations per epoch.

        May not be an integer.
        """
        return self._steps_per_epoch

    @property
    def epoch(self):
        """Return the current epoch.

        Computed as `(step + 1) / steps_per_epoch`, so it is fractional
        whenever a batch straddles an epoch boundary.
        """
        return (self.step + 1) / self._steps_per_epoch

    @property
    def step(self):
        """Return the current iteration step.

        This is equivalent to calling `enumerate` on this iterator.
        """
        return self._it.step
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment