import torch

def one_hot_encode(encoded, vocab_size):
    result = torch.zeros((len(encoded), vocab_size))
    for i, idx in enumerate(encoded):
        result[i, idx] = 1.0
    return result

# One-hot encode our encoded characters
batch_size = 2
seq_length = 3
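To see what `one_hot_encode` produces, here is the same logic as a pure-Python sketch (a hypothetical helper, shown without torch so the shapes are easy to read): each input index becomes a row with a single 1.0 at that index.

```python
def one_hot_encode_list(encoded, vocab_size):
    # Pure-Python mirror of the torch one_hot_encode above, for illustration
    result = [[0.0] * vocab_size for _ in encoded]
    for i, idx in enumerate(encoded):
        result[i][idx] = 1.0
    return result
```

For example, `one_hot_encode_list([0, 2], 4)` gives two rows of length 4, with ones in columns 0 and 2.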
torch.manual_seed(1)  # reproducibility

#### Define the network parameters:
hiddenSize = 2  # network size; this can be any number (depending on your task)
numClass = 4    # this is the same as our vocab_size

#### Weight matrices for our inputs
Wz = Variable(torch.randn(vocab_size, hiddenSize), requires_grad=True)
Wr = Variable(torch.randn(vocab_size, hiddenSize), requires_grad=True)
Wh = Variable(torch.randn(vocab_size, hiddenSize), requires_grad=True)
def gru(x, h):
    outputs = []
    for i, sequence in enumerate(x):  # iterates over the sequences in each batch
        z = torch.sigmoid(torch.matmul(sequence, Wz) + torch.matmul(h, Uz) + bz)  # update gate
        r = torch.sigmoid(torch.matmul(sequence, Wr) + torch.matmul(h, Ur) + br)  # reset gate
        h_tilde = torch.tanh(torch.matmul(sequence, Wh) + torch.matmul(r * h, Uh) + bh)  # candidate state
        h = z * h + (1 - z) * h_tilde
        # Linear layer
        y_linear = torch.matmul(h, Wy) + by
        outputs.append(y_linear)
    return outputs, h
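The gate arithmetic inside the loop is easier to check on a single scalar unit. Here is one GRU update step as a pure-Python sketch (a hypothetical helper mirroring the tensor code above, not part of the gist):

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def gru_step(x, h, Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh):
    # One GRU update for a single scalar unit
    z = sigmoid(x * Wz + h * Uz + bz)                # update gate
    r = sigmoid(x * Wr + h * Ur + br)                # reset gate
    h_tilde = math.tanh(x * Wh + (r * h) * Uh + bh)  # candidate state
    return z * h + (1 - z) * h_tilde                 # blend old state and candidate
```

With all weights and biases zero, both gates sit at sigmoid(0) = 0.5 and the candidate is tanh(0) = 0, so the new state is simply half the old one: the update gate `z` really does interpolate between keeping `h` and adopting `h_tilde`.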
def sample(primer, length_chars_predict):
    word = primer
    primer_dictionary = [character_dictionary[char] for char in word]
    test_input = one_hot_encode(primer_dictionary, vocab_size)
    h = torch.zeros(1, hiddenSize)
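The snippet above is cut off before the generation loop, but the selection step it builds toward is simple: given a probability row from the output layer, pick the most likely index and map it back to a character. A minimal sketch of that step (hypothetical helper names, using plain lists):

```python
def pick_next_char(probs, index_to_char):
    # Choose the index with the highest probability, then map it back to its character
    best = max(range(len(probs)), key=lambda k: probs[k])
    return index_to_char[best]
```

In a full sampler this would run once per predicted character, feeding each picked character back in as the next input.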
character_list = list(set(text))  # get all of the unique letters in our text variable
vocabulary_size = len(character_list)  # count the number of unique elements
character_dictionary = {char: e for e, char in enumerate(character_list)}  # create a dictionary mapping each unique char to a number
encoded_chars = [character_dictionary[char] for char in text]  # integer representation of our vocabulary
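A quick worked example of this encoding. Note that `set()` iteration order is not guaranteed, so the sketch below uses `sorted()` to make the mapping deterministic for illustration (the `demo_` names are hypothetical, chosen to avoid clobbering the variables above):

```python
demo_text = 'Math'
demo_chars = sorted(set(demo_text))  # sorted() fixes the ordering; a bare set is unordered
demo_dictionary = {char: e for e, char in enumerate(demo_chars)}
demo_encoded = [demo_dictionary[char] for char in demo_text]
```

With ASCII ordering ('M' < 'a' < 'h' < 't'), the mapping is {'M': 0, 'a': 1, 'h': 2, 't': 3} and 'Math' encodes to [0, 1, 3, 2].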
# Input text
# This will be our input ---> x
text = 'MathMathMathMathMath'

# Training loop
max_epochs = 5  # passes through the data
for e in range(max_epochs):
    h = torch.zeros(batch_size, hiddenSize)
    for i in range(num_batches):
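The loop above uses `num_batches` and batched inputs `X[i]`, which are not defined in this excerpt. One plausible way to derive them, sketched here with plain lists (a hypothetical helper, not the gist's own code): split the encoded text into chunks of `batch_size * seq_length` characters and drop any leftover tail.

```python
def make_batches(encoded, batch_size, seq_length):
    # Split a flat list of character indices into num_batches chunks,
    # each shaped (batch_size, seq_length); leftover characters are dropped
    chars_per_batch = batch_size * seq_length
    num_batches = len(encoded) // chars_per_batch
    batches = []
    for b in range(num_batches):
        chunk = encoded[b * chars_per_batch:(b + 1) * chars_per_batch]
        batches.append([chunk[s * seq_length:(s + 1) * seq_length]
                        for s in range(batch_size)])
    return batches
```

For the 20-character text above with batch_size = 2 and seq_length = 3, each batch holds 6 characters, giving 20 // 6 = 3 batches (2 characters are discarded). Each row of a batch would then be one-hot encoded before being fed to `gru()`.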
ht_2 = []  # stores the calculated h for each input x
outputs = []
h = torch.zeros(batch_size, hiddenSize)  # initializes the hidden state
for i in range(num_batches):  # this loops over the batches
    x = X[i]
    for t, sequence in enumerate(x):  # iterates over the sequences in each batch
        z = torch.sigmoid(torch.matmul(sequence, Wz) + torch.matmul(h, Uz) + bz)
        r = torch.sigmoid(torch.matmul(sequence, Wr) + torch.matmul(h, Ur) + br)
        h_tilde = torch.tanh(torch.matmul(sequence, Wh) + torch.matmul(r * h, Uh) + bh)
        h = z * h + (1 - z) * h_tilde
        ht_2.append(h)  # record the hidden state at this step
        outputs.append(torch.softmax(torch.matmul(h, Wy) + by, dim=1))  # class probabilities
hidden_batch_1 = ht_2[:3]
outputs_batch_1 = outputs[:3]
print(f' Predictions for the first batch: \n\n{outputs_batch_1}, \
\n \n Hidden states for the first batch: \n{hidden_batch_1}')
'''
Predictions for the first batch:
tensor([[[0.4342, 0.1669, 0.1735, 0.2254],
        [0.2207, 0.2352, 0.3322, 0.2119]],