@chricke
Created March 26, 2019 10:08
Create batches
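import numpy as np

def get_batches(int_text, batch_size, seq_length):
    """Split word IDs into an array of shape (n_batches, 2, batch_size, seq_length).

    batches[i, 0] holds the inputs of batch i; batches[i, 1] holds the targets,
    i.e. the inputs shifted by one word (the last target wraps to the first word).
    """
    # Only full batches are kept; leftover words at the end are dropped.
    n_batches = len(int_text) // (batch_size * seq_length)
    words = np.asarray(int_text[:n_batches * (batch_size * seq_length)])
    # dtype=words.dtype keeps integer word IDs (np.zeros defaults to float64).
    batches = np.zeros(shape=(n_batches, 2, batch_size, seq_length), dtype=words.dtype)
    # Cut the text into consecutive sequences; targets are the next word of each input.
    input_sequences = words.reshape(-1, seq_length)
    target_sequences = np.roll(words, -1).reshape(-1, seq_length)
    # Sequence idx goes to batch idx % n_batches, row idx // n_batches, so each
    # row continues in the next batch where it left off.
    for idx in range(input_sequences.shape[0]):
        batch_idx = idx % n_batches
        row_idx = idx // n_batches
        batches[batch_idx, 0, row_idx, :] = input_sequences[idx, :]
        batches[batch_idx, 1, row_idx, :] = target_sequences[idx, :]
    return batches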
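The loop in get_batches places sequence `idx` in batch `idx % n_batches` at row `idx // n_batches`. The same layout can probably be produced without a Python loop by reshaping the sequences to `(batch_size, n_batches, seq_length)` and swapping the first two axes. Below is a minimal sketch of that idea, assuming NumPy. `get_batches_vectorized` is a hypothetical name that does not appear in the original gist, and the input `1..20` with `batch_size=3`, `seq_length=2` is only an illustrative example.

```python
import numpy as np


def get_batches_vectorized(int_text, batch_size, seq_length):
    """Loop-free sketch of get_batches (hypothetical name, not from the gist).

    Returns an array of shape (n_batches, 2, batch_size, seq_length), where
    [:, 0] holds inputs and [:, 1] holds targets shifted by one word.
    """
    words_per_batch = batch_size * seq_length
    n_batches = len(int_text) // words_per_batch
    # Keep only enough words to fill complete batches.
    words = np.asarray(int_text[:n_batches * words_per_batch])
    # Targets are the inputs shifted left by one; the final target wraps
    # around to the first kept word.
    targets = np.roll(words, -1)

    def arrange(seq):
        # After the reshape, element [r, b] is sequence r * n_batches + b;
        # swapping axes gives [b, r], the same placement as the loop version.
        return seq.reshape(batch_size, n_batches, seq_length).swapaxes(0, 1)

    # Stack inputs and targets along axis 1 to get the (n, 2, rows, cols) layout.
    return np.stack([arrange(words), arrange(targets)], axis=1)


batches = get_batches_vectorized(list(range(1, 21)), batch_size=3, seq_length=2)
print(batches.shape)           # (3, 2, 3, 2)
print(batches[0, 0].tolist())  # [[1, 2], [7, 8], [13, 14]]
print(batches[2, 1].tolist())  # [[6, 7], [12, 13], [18, 1]]
```

With 20 words and 3 × 2 = 6 words per batch, only the first 18 words are used (3 batches). The last target, `[18, 1]`, shows the wrap-around produced by `np.roll`. The reshape/swapaxes approach avoids per-sequence Python assignments, which may matter for long texts. The trade-off is that the placement rule is less explicit than in the loop.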