@agastidukare
Created April 1, 2020 02:09
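
import numpy.random as npr

# Assumes train_images and train_labels are NumPy arrays loaded earlier,
# with one example per row along axis 0.
num_train = train_images.shape[0]

batch_size = 128
# Count the full batches, plus one smaller batch if examples are left over.
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

def data_stream():
    """Yield (images, labels) minibatches forever, reshuffling every epoch."""
    rng = npr.RandomState(0)  # fixed seed, so the shuffles are reproducible
    while True:
        perm = rng.permutation(num_train)  # new random order for each epoch
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield train_images[batch_idx], train_labels[batch_idx]

batches = data_stream()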
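The batch arithmetic can be checked with a small self-contained sketch. It wraps the same shuffling and slicing in a hypothetical helper, `make_data_stream`, and runs it on synthetic stand-in arrays instead of the real training set. With 1000 examples and `batch_size = 128`, `divmod` returns 7 full batches and 104 leftover examples, so each epoch yields 8 batches and the last one is smaller. The labels are simply the example indices, which lets the final check confirm that every example appears exactly once per epoch.

```python
import numpy as np
import numpy.random as npr


def make_data_stream(images, labels, batch_size=128, seed=0):
    """Hypothetical helper: return (num_batches, endless shuffled batch generator)."""
    num_train = images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        rng = npr.RandomState(seed)
        while True:
            perm = rng.permutation(num_train)  # reshuffle at the start of each epoch
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield images[batch_idx], labels[batch_idx]

    return num_batches, data_stream()


# Synthetic stand-in data: 1000 flattened 28x28 "images"; labels are example indices.
images = np.zeros((1000, 784), dtype=np.float32)
labels = np.arange(1000)

num_batches, batches = make_data_stream(images, labels)
print(num_batches)  # 8, because divmod(1000, 128) == (7, 104)

# Pull exactly one epoch's worth of batches from the endless stream.
epoch = [next(batches) for _ in range(num_batches)]
sizes = [x.shape[0] for x, _ in epoch]
print(sizes)  # [128, 128, 128, 128, 128, 128, 128, 104]

# Every example index should appear exactly once in the epoch.
seen = np.sort(np.concatenate([y for _, y in epoch]))
print(np.array_equal(seen, np.arange(1000)))  # True
```

The helper takes the arrays and batch size as arguments rather than reading module-level globals. That choice only makes the demo self-contained; the shuffling and slicing logic matches the snippet above.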