Skip to content

Instantly share code, notes, and snippets.

@mikaelsouza
Last active February 20, 2020 17:31
Show Gist options
  • Save mikaelsouza/f38723589322b0fd54543aa71f3d5271 to your computer and use it in GitHub Desktop.
Save mikaelsouza/f38723589322b0fd54543aa71f3d5271 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
MAX_SIZE = 3  # fixed sequence length; shorter sentences are right-padded with 0
device = torch.device('cpu')

# generating vocab
sentences = [
    'hello friend',
    'how are you',
]

# Index 0 is reserved for padding, so real words are numbered from 1.
# Words are de-duplicated first (dict.fromkeys keeps first-seen order) so the
# indices stay contiguous; otherwise a repeated word would leave gaps and
# `num_distinct_words` would be smaller than the largest index, which breaks
# the embedding lookup.
vocabulary = dict.fromkeys(
    word for sentence in sentences for word in sentence.split(' ')
)
word2index = {word: index for index, word in enumerate(vocabulary, start=1)}
index2word = {index: word for word, index in word2index.items()}
num_distinct_words = len(word2index) + 1  # +1 accounts for the padding index 0
num_sentences = len(sentences)
encoded_sentences = [
    [word2index[word] for word in sentence.split(' ')]
    for sentence in sentences
]

# generating inputs
# One row per sentence, zero-initialised so unused trailing slots act as padding.
inputs = torch.zeros((num_sentences, MAX_SIZE), dtype=torch.int64, device=device)
for i, sentence in enumerate(encoded_sentences):
    for j, word in enumerate(sentence):
        inputs[i, j] = word
# declaring network
class Net(nn.Module):
    """Embedding -> single-layer RNN -> linear projection applied at every step.

    Args:
        input_size: number of rows in the embedding table (vocabulary size).
        output_size: number of scores produced per time step.
        embedding_size: dimensionality of each embedding vector.
        hidden_size: dimensionality of the RNN hidden state.
    """

    def __init__(self, input_size, output_size, embedding_size=10, hidden_size=100):
        super().__init__()
        self.embedding = nn.Embedding(input_size, embedding_size)
        # batch_first=True: the RNN consumes and emits (batch, seq, feature).
        self.rnn = nn.RNN(embedding_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return unnormalised scores (logits) for every position of `x`.

        `x` holds integer token indices, typically shaped (batch, seq); the
        result then has shape (batch, seq, output_size).
        """
        output = self.embedding(x)
        # The final hidden state is not needed; only per-step outputs are used.
        output, _ = self.rnn(output)
        return self.linear(output)

    def predict(self, x):
        """Return the highest-scoring class index for every position of `x`."""
        with torch.no_grad():
            # softmax is monotonic, so argmax over the raw logits selects the
            # same index. Calling the module (not .forward) keeps hooks active.
            return torch.argmax(self(x), dim=-1)
net = Net(num_distinct_words, num_distinct_words, embedding_size=num_distinct_words).to(device)

# training step
# The network is trained to reproduce its own input tokens at every position.
EPOCHS = 10
optimizer = optim.Adam(net.parameters())
loss_function = nn.CrossEntropyLoss()
# CrossEntropyLoss treats dim 1 as the class dimension, so the
# (batch, seq, classes) logits cannot be paired with (batch, seq) targets
# directly. Flatten both to (batch * seq, classes) and (batch * seq,).
targets = inputs.reshape(-1)
for e in range(EPOCHS):
    optimizer.zero_grad()
    output = net(inputs)
    loss = loss_function(output.reshape(-1, output.size(-1)), targets)
    # A fresh graph is built each epoch, so retain_graph is not needed.
    loss.backward()
    optimizer.step()
    print("Epoch: ", e, "Loss: ", loss.item())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment