@edumunozsala
Last active September 25, 2020 18:18
Batch data generator for sequences of chars
import numpy as np


def batch_generator_sequence(features_seq, label_seq, batch_size, seq_len):
    """Generator function that yields batches of data (input and target).

    Args:
        features_seq (array-like): 1-D sequence of encoded input characters.
        label_seq (array-like): 1-D sequence of targets aligned with
            features_seq; it must hold at least as many usable elements.
        batch_size (int): number of sequences per batch.
        seq_len (int): number of time steps in each sequence.
    Yields:
        tuple: (x_batch, y_batch), each of shape (batch_size, seq_len).
    Raises:
        ValueError: if the data is too short to form a single batch.
    """
    # calculate the number of batches we can supply
    num_batches = len(features_seq) // (batch_size * seq_len)
    if num_batches == 0:
        raise ValueError("No batches created. Use smaller batch size or sequence length.")
    # calculate effective length of text to use (the remainder is dropped)
    rounded_len = num_batches * batch_size * seq_len
    # Reshape the features matrix to batch_size x (num_batches * seq_len);
    # each row is one contiguous chunk of the text
    x = np.reshape(features_seq[:rounded_len], [batch_size, num_batches * seq_len])
    # Reshape the target matrix to batch_size x (num_batches * seq_len)
    y = np.reshape(label_seq[:rounded_len], [batch_size, num_batches * seq_len])
    epoch = 0
    while True:
        # roll the rows by one per epoch so each row continues where it
        # stopped in the text: no need to reset rnn states over epochs
        x_epoch = np.split(np.roll(x, -epoch, axis=0), num_batches, axis=1)
        y_epoch = np.split(np.roll(y, -epoch, axis=0), num_batches, axis=1)
        for batch in range(num_batches):
            yield x_epoch[batch], y_epoch[batch]
        epoch += 1
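
The reshape, roll and split steps in the generator can be traced on a toy sequence. The sketch below is not part of the original gist: the integer data, the next-character targets and the `epoch_batches` helper are assumptions made for illustration. It repeats the same NumPy calls outside the generator so the batch layout can be inspected directly.

```python
import numpy as np

# Toy data: integers stand in for encoded characters (assumption for illustration).
text = np.arange(25)
features, labels = text[:-1], text[1:]  # assumed next-character targets

batch_size, seq_len = 2, 3
num_batches = len(features) // (batch_size * seq_len)  # 24 // 6 = 4
rounded_len = num_batches * batch_size * seq_len

# Same layout as in the generator: one contiguous chunk of text per row.
x = np.reshape(features[:rounded_len], [batch_size, num_batches * seq_len])
y = np.reshape(labels[:rounded_len], [batch_size, num_batches * seq_len])


def epoch_batches(arr, epoch):
    """Hypothetical helper: the list of batches for one epoch."""
    return np.split(np.roll(arr, -epoch, axis=0), num_batches, axis=1)


first_x = epoch_batches(x, 0)
first_y = epoch_batches(y, 0)
second_x = epoch_batches(x, 1)

print(first_x[0])   # first batch of epoch 0
print(first_x[-1])  # last batch of epoch 0
print(second_x[0])  # first batch of epoch 1
```

In this toy run, row 0 of the last batch in epoch 0 ends at element 11, and row 0 of the first batch in epoch 1 starts at element 12. Each row therefore carries on reading the text from where it stopped, which is why a stateful RNN would not need its hidden state reset between epochs. The last row wraps around to the start of the text.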