@albertlai431
Created February 7, 2019 22:17
import numpy as np


# Generator that yields mini-batches for training
def get_batches(arr, batch_size, seq_length):
    '''Create a generator that returns batches of size
       batch_size x seq_length from arr.

       Arguments
       ---------
       arr: NumPy array you want to make batches from
       batch_size: Batch size, the number of sequences per batch
       seq_length: Number of encoded chars in a sequence
    '''
    # Number of elements consumed by one batch
    batch_size_total = batch_size * seq_length
    # Total number of full batches we can make
    n_batches = len(arr) // batch_size_total

    # Keep only enough characters to make full batches
    arr = arr[:n_batches * batch_size_total]
    # Reshape into batch_size rows
    arr = arr.reshape((batch_size, -1))

    # Iterate through the array, one sequence window at a time
    for n in range(0, arr.shape[1], seq_length):
        # The features
        x = arr[:, n:n + seq_length]
        # The targets: the features shifted by one position
        y = np.zeros_like(x)
        try:
            y[:, :-1], y[:, -1] = x[:, 1:], arr[:, n + seq_length]
        except IndexError:
            # Last window: there is no next column, so wrap around to the first one
            y[:, :-1], y[:, -1] = x[:, 1:], arr[:, 0]
        yield x, y
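
The trim, reshape, and shift-by-one steps can be hard to picture, so here is a minimal sketch that walks through them by hand on a toy array. The names `encoded`, `trimmed`, and `rows` and the sizes used are illustrative assumptions, not part of the gist; the sketch mirrors the steps inside `get_batches` rather than calling it.

```python
import numpy as np

# Hypothetical toy input: the integers 0..24 stand in for encoded characters
encoded = np.arange(25)
batch_size, seq_length = 2, 3

# Keep only full batches: 25 // (2 * 3) = 4 batches, so 24 values survive
n_batches = len(encoded) // (batch_size * seq_length)
trimmed = encoded[:n_batches * batch_size * seq_length]

# One row per sequence in the batch: shape (2, 12)
rows = trimmed.reshape((batch_size, -1))

# First window of features, and its targets (the same window shifted by one)
x = rows[:, 0:seq_length]
y = rows[:, 1:seq_length + 1]

print(x)  # rows [0 1 2] and [12 13 14]
print(y)  # rows [1 2 3] and [13 14 15]
```

Each target is simply the character that follows the corresponding feature in the same row. Only the final window needs special handling, because no column exists after it; the gist handles that case by wrapping around to the first column.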