Last active
January 3, 2022 13:53
-
-
Save adoankim/e03dc6a82b83c1eaa41bd735734a023c to your computer and use it in GitHub Desktop.
Python generator to create slices of indices for batching
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BatchArgumentError(ValueError, AssertionError):
    """Raised when ``batches_generator`` gets a non-positive length or batch size.

    Subclasses AssertionError as well as ValueError so callers written against
    the earlier assert-based validation still catch it.
    """


def batches_generator(list_to_batch_lenght, batch_size):
    """Return an iterator of ``(start, end)`` slice bounds that batch a sequence.

    Args:
        list_to_batch_lenght: Length of the sequence to split; must be > 0.
            (The misspelled name is kept so keyword callers keep working.)
        batch_size: Maximum number of items per batch; must be > 0.

    Returns:
        An iterator of ``(start, end)`` tuples. Each full batch has
        ``end == start + batch_size``. If the length is not a multiple of
        ``batch_size``, one final ``(start, None)`` tuple covers the leftover
        items, so ``seq[start:end]`` is always a valid slice.

    Raises:
        BatchArgumentError: If either argument is not positive. The error is
            raised when this function is called, not on the first iteration.

    >>> list(batches_generator(130, 50))
    [(0, 50), (50, 100), (100, None)]
    """
    # Explicit checks instead of `assert`: asserts vanish under `python -O`.
    if list_to_batch_lenght <= 0:
        raise BatchArgumentError(
            "list_to_batch_lenght should be a non-zero positive number"
        )
    if batch_size <= 0:
        raise BatchArgumentError("batch_size should be a non-zero positive number")
    # Delegate to a generator so the validation above runs eagerly; code in a
    # generator function body would only execute on the first next().
    return _iter_batch_bounds(list_to_batch_lenght, batch_size)


def _iter_batch_bounds(length, batch_size):
    """Yield slice bounds for ``length`` items in batches of ``batch_size`` (pre-validated)."""
    full_batches, remainder = divmod(length, batch_size)
    for i in range(full_batches):
        start = i * batch_size
        yield start, start + batch_size
    if remainder:
        # Open-ended final slice picks up the items that don't fill a whole batch.
        yield full_batches * batch_size, None
def tests():
    """Self-check ``batches_generator``, printing one progress line per scenario.

    Raises AssertionError on the first failing check and prints
    "all right!" when every scenario passes.
    """
    # (sequence length, batch size, expected length of each yielded batch)
    scenarios = [
        (150, 50, [50, 50, 50]),
        (130, 50, [50, 50, 30]),
        (4, 50, [4]),
    ]
    for length, batch_size, expected_sizes in scenarios:
        print(f"With a list of {length} elements and a batch size of {batch_size}:")
        x = range(1, length + 1)
        batches = [x[start:end] for start, end in batches_generator(len(x), batch_size)]
        sizes = [len(batch) for batch in batches]
        assert sizes == expected_sizes, (
            f"Batch sizes should be {expected_sizes}, got {sizes}"
        )
        # The batches must also tile the sequence: in order, no gaps, no overlap.
        flattened = [item for batch in batches for item in batch]
        assert flattened == list(x), "Batches should cover every element exactly once"
        print(" pass")

    print("Passing negative numbers to the generator shoud fail:")
    # Each invalid call is checked on its own so one raising cannot mask the
    # other, and a call that does NOT raise now fails the test.
    for bad_args in ((-1, 50), (4, -1)):
        try:
            list(batches_generator(*bad_args))
        except (AssertionError, ValueError):
            continue
        raise AssertionError(f"batches_generator{bad_args} should have raised")
    print(" pass")
    print("all right!")
if __name__ == "__main__":
    # Run the self-checks only when executed as a script, not on import.
    tests()
def sample():
    """Demonstrate batching: print each 5000-item slice of ``range(10500)``.

    The slices are printed as ``range`` objects, e.g. ``range(0, 5000)``.
    """
    numbers = range(10500)
    for lo, hi in batches_generator(len(numbers), 5000):
        print(numbers[lo:hi])
if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    sample()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment