Skip to content

Instantly share code, notes, and snippets.

@adoankim
Last active January 3, 2022 13:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save adoankim/e03dc6a82b83c1eaa41bd735734a023c to your computer and use it in GitHub Desktop.
Save adoankim/e03dc6a82b83c1eaa41bd735734a023c to your computer and use it in GitHub Desktop.
Python generator that yields (start, end) index pairs for slicing a list into batches
def batches_generator(list_to_batch_lenght, batch_size):
    """Yield ``(start, end)`` index pairs that split a sequence into batches.

    Each full batch yields ``(i * batch_size, (i + 1) * batch_size)``. If the
    length is not a multiple of ``batch_size``, one final ``(start, None)``
    pair is yielded for the shorter trailing batch. A ``None`` slice end means
    "to the end", so ``seq[start:end]`` works for every pair.

    Args:
        list_to_batch_lenght: Length of the sequence to batch; must be > 0.
            The misspelled name is kept so keyword callers keep working.
        batch_size: Maximum number of elements per batch; must be > 0.

    Yields:
        ``(start, end)`` tuples, where ``end`` is ``None`` only for a final
        partial batch.

    Raises:
        AssertionError: If either argument is not a positive number. Because
            this is a generator, the error surfaces when iteration starts.
    """
    # Explicit raises (instead of ``assert`` statements) keep the validation
    # active under ``python -O``; AssertionError is kept so existing callers
    # that catch it keep working.
    if not (list_to_batch_lenght > 0):
        raise AssertionError("list_to_batch_lenght should be a non-zero positive number")
    if not (batch_size > 0):
        raise AssertionError("batch_size should be a non-zero positive number")

    full_batches = list_to_batch_lenght // batch_size
    for i in range(full_batches):
        yield i * batch_size, (i + 1) * batch_size

    # Whatever lies past the last full batch forms one shorter, open-ended batch.
    remainder_start = full_batches * batch_size
    if remainder_start < list_to_batch_lenght:
        yield remainder_start, None
def _check_batches(length, batch_size, expected_sizes):
    """Batch ``range(1, length + 1)`` and assert each batch has the expected size.

    Also asserts that the batches, concatenated in order, reproduce the input
    exactly (no gaps, no overlaps).
    """
    x = range(1, length + 1)
    y = list(batches_generator(len(x), batch_size))
    assert len(y) == len(expected_sizes), (
        f"Number of batches should be {len(expected_sizes)}"
    )
    batches = [x[start:end] for start, end in y]
    for number, (batch, expected) in enumerate(zip(batches, expected_sizes), 1):
        assert len(batch) == expected, f"Batch {number} size should be {expected}"
    assert [item for batch in batches for item in batch] == list(x), (
        "Batches should cover every element exactly once, in order"
    )


def tests():
    """Run the self-checks for ``batches_generator``, printing progress."""
    print("With a list of 150 elements and a batch size of 50:")
    _check_batches(150, 50, [50, 50, 50])
    print(" pass")

    print("With a list of 130 elements and a batch size of 50:")
    _check_batches(130, 50, [50, 50, 30])
    print(" pass")

    print("With a list of 4 elements and a batch size of 50:")
    _check_batches(4, 50, [4])
    print(" pass")

    print("Passing negative numbers to the generator should fail:")
    # Each invalid call gets its own try/else, so every call actually runs and
    # a generator that silently accepts bad input makes this check fail.
    for bad_args in ((-1, 50), (4, -1)):
        try:
            list(batches_generator(*bad_args))
        except (AssertionError, ValueError):
            pass
        else:
            raise AssertionError(f"batches_generator{bad_args} should have failed")
    print(" pass")
    print("all right!")


tests()
def sample():
    """Demo: print a 10,500-element range in batches of 5,000.

    Slicing a ``range`` returns a ``range``, so each printed batch appears as a
    ``range(...)`` object rather than an expanded list.
    """
    values = range(10500)
    size = 5000
    batches = batches_generator(len(values), size)
    for lo, hi in batches:
        print(values[lo:hi])


sample()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment