Create batches
def get_batches(int_text, batch_size, seq_length):
    """Split a sequence of word ids into input/target training batches.

    The text is truncated to a whole number of batches
    (``n_batches = len(int_text) // (batch_size * seq_length)``). It is then
    cut into consecutive ``seq_length``-long sequences, which are dealt
    round-robin across batches: sequence ``i`` lands in batch
    ``i % n_batches`` at row ``i // n_batches``. Targets are the inputs
    shifted one position left. The very last target wraps around to the
    first word of the truncated text (``np.roll`` semantics).

    Args:
        int_text: sequence of integer word ids (anything ``len()`` and
            ``np.asarray`` accept).
        batch_size: number of sequences per batch.
        seq_length: number of words per sequence.

    Returns:
        ``np.ndarray`` of shape ``(n_batches, 2, batch_size, seq_length)``.
        ``[:, 0]`` holds the inputs and ``[:, 1]`` the targets. The array
        keeps the dtype NumPy infers for ``int_text``, so integer ids stay
        integers rather than being converted to float64.

    Raises:
        ZeroDivisionError: if ``batch_size * seq_length`` is 0.
    """
    words_per_batch = batch_size * seq_length
    n_batches = len(int_text) // words_per_batch

    # Convert before slicing so the ids keep their inferred integer dtype.
    words = np.asarray(int_text)[:n_batches * words_per_batch]
    # Each target is the next word; the final one wraps to the first word.
    next_words = np.roll(words, -1)

    # Sequence i sits at row i // n_batches of batch i % n_batches, i.e.
    # i == row * n_batches + batch. Reshaping to (row, batch, seq) and
    # swapping the first two axes yields (batch, row, seq).
    inputs = words.reshape(batch_size, n_batches, seq_length).transpose(1, 0, 2)
    targets = next_words.reshape(batch_size, n_batches, seq_length).transpose(1, 0, 2)

    # Interleave along a new axis 1: (n_batches, 2, batch_size, seq_length).
    return np.stack((inputs, targets), axis=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment