Gist edumunozsala/54393e669f9f67d14310b717f3324652 (last active September 25, 2020 18:18)
Batch data generator for sequences of characters
import numpy as np


def batch_generator_sequence(features_seq, label_seq, batch_size, seq_len):
    """Generator function that yields batches of data (input and target).

    The sequences are cut into batch_size contiguous streams. Each batch holds
    the next seq_len steps of every stream, so an RNN can carry its hidden
    state from one batch to the next.

    Args:
        features_seq (array-like): 1-D sequence of encoded input characters.
        label_seq (array-like): 1-D sequence of encoded targets, at least as
            long as features_seq.
        batch_size (int): number of parallel streams (rows) per batch.
        seq_len (int): number of time steps (columns) per batch.

    Yields:
        tuple: (x_batch, y_batch), each a NumPy array of shape
            (batch_size, seq_len). The generator never stops.

    Raises:
        ValueError: if the sequences are too short to fill a single batch.
    """
    # Calculate the number of full batches we can supply
    num_batches = len(features_seq) // (batch_size * seq_len)
    if num_batches == 0:
        raise ValueError("No batches created. Use smaller batch size or sequence length.")
    # Effective length of text to use; the tail that cannot fill a batch is dropped
    rounded_len = num_batches * batch_size * seq_len
    # Reshape the features into batch_size rows of num_batches * seq_len steps
    x = np.reshape(features_seq[:rounded_len], [batch_size, num_batches * seq_len])
    # Reshape the targets the same way
    y = np.reshape(label_seq[:rounded_len], [batch_size, num_batches * seq_len])
    epoch = 0
    while True:
        # Shift the rows by one stream per epoch so each row resumes the text
        # where it stopped; this avoids resetting RNN states between epochs
        x_epoch = np.split(np.roll(x, -epoch, axis=0), num_batches, axis=1)
        y_epoch = np.split(np.roll(y, -epoch, axis=0), num_batches, axis=1)
        for batch in range(num_batches):
            yield x_epoch[batch], y_epoch[batch]
        epoch += 1
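To see why the reshape-then-split layout suits a stateful RNN, here is a small sketch that repeats only those two steps. The toy sequence `np.arange(26)` stands in for encoded characters, and the sizes `batch_size=2`, `seq_len=4` are assumptions chosen for illustration.

```python
import numpy as np

# Toy stand-in for an encoded character sequence (illustrative assumption)
seq = np.arange(26)
batch_size, seq_len = 2, 4

# Same arithmetic as the generator: 26 // (2 * 4) = 3 full batches
num_batches = len(seq) // (batch_size * seq_len)
rounded_len = num_batches * batch_size * seq_len  # 24; the last 2 items are dropped

# One row per parallel stream: each row holds a contiguous slice of the text
x = np.reshape(seq[:rounded_len], [batch_size, num_batches * seq_len])

# Cut the columns into num_batches consecutive windows of seq_len steps
batches = np.split(x, num_batches, axis=1)

for k, b in enumerate(batches):
    print(k, b.tolist())
# 0 [[0, 1, 2, 3], [12, 13, 14, 15]]
# 1 [[4, 5, 6, 7], [16, 17, 18, 19]]
# 2 [[8, 9, 10, 11], [20, 21, 22, 23]]
```

Row 0 walks through items 0 to 11 and row 1 through 12 to 23, four steps per batch, so the state left by row i of one batch is the right starting state for row i of the next. Consecutive batches are not consecutive chunks of the text, which is why shuffling them would break this continuity.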
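The `np.roll(x, -epoch, axis=0)` step shifts the rows by one stream per epoch. As a result, each row resumes the text right where it stopped at the end of the previous epoch, and only the last row wraps from the end of the text back to the start. The sketch below checks this on toy data (an assumption for illustration); `epoch_batches` is a hypothetical helper that repeats the generator's roll-then-split step.

```python
import numpy as np

# Toy stand-in for encoded characters (illustrative assumption):
# 3 streams, each 2 batches of 2 steps long
batch_size, seq_len, num_batches = 3, 2, 2
seq = np.arange(batch_size * num_batches * seq_len)  # 0..11
x = np.reshape(seq, [batch_size, num_batches * seq_len])
# x rows: [0..3], [4..7], [8..11]


def epoch_batches(epoch):
    # Hypothetical helper: the roll-then-split step done at the start of each epoch
    return np.split(np.roll(x, -epoch, axis=0), num_batches, axis=1)


last_of_epoch0 = epoch_batches(0)[-1]  # [[2, 3], [6, 7], [10, 11]]
first_of_epoch1 = epoch_batches(1)[0]  # [[4, 5], [8, 9], [0, 1]]

for row in range(batch_size):
    end, start = last_of_epoch0[row, -1], first_of_epoch1[row, 0]
    print(row, int(end), "->", int(start))
# 0 3 -> 4
# 1 7 -> 8
# 2 11 -> 0   (the last row wraps to the start of the text)
```

Without the roll, every row would jump back to the start of its own stream (row 0 from 3 to 0), so the hidden state carried over from the previous epoch would no longer match the text that follows.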