
@cjlovering
Created April 24, 2018 22:15
Accurately get the number of batches.
for batch in range(num_data // batch_size + (num_data % batch_size > 0)):
    ...  # process one batch here
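
A minimal sketch of the loop in context, assuming the data is a plain Python list and each batch is taken by slicing. The generator name `iterate_batches` is hypothetical, not part of the original snippet:

```python
def iterate_batches(data, batch_size):
    """Yield consecutive slices of `data`, including a final partial batch."""
    num_data = len(data)
    for batch in range(num_data // batch_size + (num_data % batch_size > 0)):
        start = batch * batch_size
        # Slicing past the end is safe in Python; the last slice is just shorter.
        yield data[start:start + batch_size]


batches = list(iterate_batches(list(range(10)), 4))
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

With 10 items and a batch size of 4, the loop runs 3 times and the last batch holds the 2 leftover items.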

Ever run into errors when training models because a loop ran one too many or one too few batches?

This snippet should fix it by counting the final, partial batch only when one actually exists.
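
To make the two failure modes concrete, here is a small comparison of two common (buggy) ways of counting batches against the snippet's formula. The function names are hypothetical, chosen only for illustration:

```python
def floor_only(num_data, batch_size):
    # Drops the final partial batch: one too few when num_data isn't a multiple.
    return num_data // batch_size


def always_plus_one(num_data, batch_size):
    # Adds an empty final batch: one too many when num_data IS a multiple.
    return num_data // batch_size + 1


def num_batches(num_data, batch_size):
    # Adds 1 only when there are leftover data points.
    return num_data // batch_size + (num_data % batch_size > 0)


for n in (8, 10):
    print(n, floor_only(n, 4), always_plus_one(n, 4), num_batches(n, 4))
# 8 2 3 2
# 10 2 3 3
```

Each buggy version is right for one of the two cases and wrong for the other; the snippet's formula is right for both.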

  1. First, count how many full batches fit in the data: num_data // batch_size.
  2. Add 1 if any data points are left over: (num_data % batch_size > 0). In Python this comparison is a bool, and True adds as 1 (False as 0).
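
Together, the two steps compute the ceiling of num_data / batch_size using only integer arithmetic. A sketch that checks this against two other common formulations, `math.ceil` on a float division and negated floor division, assuming a positive batch_size and a non-negative num_data:

```python
import math


def num_batches(num_data, batch_size):
    full = num_data // batch_size          # step 1: number of full batches
    partial = num_data % batch_size > 0    # step 2: True if leftovers remain
    return full + partial                  # the bool adds as 0 or 1


for num_data in range(0, 50):
    for batch_size in range(1, 12):
        expected = math.ceil(num_data / batch_size)
        assert num_batches(num_data, batch_size) == expected
        assert num_batches(num_data, batch_size) == -(-num_data // batch_size)
print("all formulations agree")
```

One reason to prefer the integer-only form: `num_data / batch_size` goes through a float, which can round incorrectly once counts exceed 2**53, while `//` and `%` stay exact for any integer size.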

This assumes Python 3.
