Skip to content

Instantly share code, notes, and snippets.

@mehdidc
Last active September 19, 2016 15:25
Show Gist options
  • Save mehdidc/ea52a3524f0ae614f665bd08eff7d38e to your computer and use it in GitHub Desktop.
Save mehdidc/ea52a3524f0ae614f665bd08eff7d38e to your computer and use it in GitHub Desktop.
def minibatcher(fn, batchsize=1000):
    """Wrap ``fn`` so that it is applied to its input in minibatches.

    The returned function ``f(X)`` splits ``X`` into consecutive chunks of at
    most ``batchsize`` elements along its first axis. Each chunk is
    ``X[sl]`` for every slice produced by ``iterate_minibatches``. ``f``
    calls ``fn`` on each chunk and concatenates the per-chunk outputs along
    axis 0.

    Parameters
    ----------
    fn : callable
        Function taking a chunk of the input and returning an output that
        ``np.concatenate`` can join along axis 0. Presumably a NumPy array
        in, array out; confirm against callers.
    batchsize : int
        Maximum number of elements passed to ``fn`` in a single call.

    Returns
    -------
    callable
        A function ``f(X)`` returning the concatenated outputs of ``fn``.
        ``X`` must support ``len`` and slicing.
    """
    def f(X):
        n = len(X)
        if n == 0:
            # np.concatenate([]) raises ValueError ("need at least one array
            # to concatenate"), so hand the empty input to fn directly and
            # let it determine the output's shape and dtype.
            return fn(X)
        results = [fn(X[sl]) for sl in iterate_minibatches(n, batchsize)]
        return np.concatenate(results, axis=0)
    return f
def iterate_minibatches(nb_inputs, batchsize, shuffle=False):
    """Yield indexers that split ``range(nb_inputs)`` into minibatches.

    Parameters
    ----------
    nb_inputs : int
        Total number of inputs to cover.
    batchsize : int
        Maximum size of each minibatch. The last one may be smaller.
    shuffle : bool
        If False, yield ``slice`` objects over consecutive ranges. If True,
        yield integer index arrays taken from a random permutation of
        ``range(nb_inputs)``, drawn with NumPy's global random state.

    Yields
    ------
    slice or numpy.ndarray
        Indexer selecting the next minibatch.

    Raises
    ------
    ValueError
        If ``batchsize`` is not positive. Because this is a generator, the
        error surfaces when iteration starts.
    """
    if batchsize <= 0:
        raise ValueError("batchsize must be positive, got %r" % (batchsize,))
    if shuffle:
        indices = np.arange(nb_inputs)
        np.random.shuffle(indices)
    # Every batch is yielded, including a trailing partial one.
    for start_idx in range(0, nb_inputs, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield excerpt
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment