Last active: September 25, 2020 18:20
Save edumunozsala/c6ccc14511e1d311f07e6e24ef8e5579 to your computer and use it in GitHub Desktop.
Encode the text and create the input and target datasets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def one_hot_encode(indices, dict_size):
    """Return the one-hot float32 encoding of an array of integer indices.

    Args:
        indices: integer array-like of any shape (numpy array, list, nested
            list, ...). Each value selects a row of a ``dict_size`` x
            ``dict_size`` identity matrix, so values should lie in
            ``[0, dict_size)``. Negative values wrap around (numpy fancy
            indexing), and values ``>= dict_size`` raise ``IndexError``.
        dict_size: length of each one-hot vector (number of classes).

    Returns:
        ``np.float32`` array of shape ``np.shape(indices) + (dict_size,)``.
    """
    # np.asarray lets callers pass plain Python sequences. The previous version
    # called ``.flatten()``/``.shape`` directly and so required an ndarray.
    indices = np.asarray(indices)
    if indices.size == 0:
        # An empty list becomes a float64 array, which numpy rejects as an
        # index. Any integer dtype gives the correct empty result.
        indices = indices.astype(np.intp)
    # Row i of the identity matrix is the one-hot vector for class i. Fancy
    # indexing with an N-d index array appends the row axis, giving shape
    # indices.shape + (dict_size,) with no flatten/reshape round trip.
    # NOTE: np.eye allocates dict_size**2 floats: cheap for small vocabularies,
    # costly for very large ones.
    return np.eye(dict_size, dtype=np.float32)[indices]
def encode_text(input_text, vocab, one_hot=False):
    """Encode each character of ``input_text`` as its integer id in ``vocab``.

    Args:
        input_text: iterable of characters (e.g. a ``str``). Each element is
            looked up as a single key in ``vocab.char2int``.
        vocab: object exposing a ``char2int`` dict mapping characters to ints.
        one_hot: when True, return one-hot vectors instead of integer ids.

    Returns:
        1-D integer ndarray with one id per element of ``input_text``. If
        ``one_hot`` is True, returns the float32 one-hot encoding of those ids,
        of shape ``(len(ids), len(vocab.char2int))``.
    """
    # Characters missing from the vocabulary fall back to id 0.
    # NOTE(review): 0 may also be a real character's id; confirm this
    # collision is intended (e.g. id 0 reserved for padding/unknown).
    # dtype=int keeps the result integer even for empty input, where
    # np.array([]) would otherwise be float64 and unusable as an index.
    ids = np.array([vocab.char2int.get(character, 0) for character in input_text], dtype=int)
    if one_hot:
        # Pass the ndarray, not the raw list: one_hot_encode indexes with it,
        # and the previous list argument crashed on ``.flatten()``.
        return one_hot_encode(ids, len(vocab.char2int))
    return ids
# Integer-encode the training corpus (ids only, no one-hot vectors).
# NOTE(review): encode_text looks up each *element* of its input in
# vocab.char2int, so this assumes `sentences` is a single string (or a flat
# sequence of characters). If it is a list of sentences, each whole sentence
# is looked up as one key and maps to id 0. Confirm where `sentences` is built.
train_data = encode_text(sentences, vocab, one_hot=False)

# Build next-character prediction pairs. The inputs are every id except the
# last; the targets are the same stream advanced by one step, so
# target_seq[t] == train_data[t + 1] is the id that follows input_seq[t].
input_seq, target_seq = train_data[:-1], train_data[1:]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.