CBOW
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)

CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
EMBEDDING_DIM = 10

raw_text = """We are about to study the idea of a computational process. | |
Computational processes are abstract beings that inhabit computers. | |
As they evolve, processes manipulate other abstract things called data. | |
The evolution of a process is directed by a pattern of rules | |
called a program. People create programs to direct processes. In effect, | |
we conjure the spirits of the computer with our spells.""".split() | |
# By deriving a set from `raw_text`, we deduplicate the tokens
vocab = set(raw_text)
vocab_size = len(vocab)
word_to_ix = {word: i for i, word in enumerate(vocab)}

data = []
for i in range(2, len(raw_text) - 2):
    context = [raw_text[i - 2], raw_text[i - 1],
               raw_text[i + 1], raw_text[i + 2]]
    target = raw_text[i]
    data.append((context, target))
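
# Illustrative check (not part of the original gist): with a window of 2 words
# on each side, the first training pair should be
# (['We', 'are', 'to', 'study'], 'about')
print(data[0])
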
class CBOW(nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, vocab_size)

    def forward(self, inputs):
        # Sum the embeddings of the context words into a single vector
        embeds = self.embeddings(inputs).sum(dim=0).view((1, -1))
        out = self.linear1(embeds)
        log_probs = F.log_softmax(out, dim=1)
        return log_probs

def make_context_vector(context, word_to_ix):
    idxs = [word_to_ix[w] for w in context]
    tensor = torch.LongTensor(idxs)
    return autograd.Variable(tensor)

context = make_context_vector(data[0][0], word_to_ix)  # example

# Create the model and train it. make_context_vector (above) turns raw
# context words into the index tensors the model consumes.
losses = []
loss_function = nn.NLLLoss()
model = CBOW(vocab_size, EMBEDDING_DIM, CONTEXT_SIZE)
optimizer = optim.SGD(model.parameters(), lr=0.001)
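
# Quick shape check before training (illustrative addition): the untrained model
# maps one context vector to a single row of log-probabilities over the vocabulary,
# i.e. a tensor of size (1, vocab_size).
print(model(context).size())
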
for epoch in range(20000):
    total_loss = torch.Tensor([0])
    for context, target in data:
        # Step 1. Prepare the inputs to be passed to the model (i.e., turn the
        # words into integer indices and wrap them in Variables)
        context_idxs = [word_to_ix[w] for w in context]
        context_var = autograd.Variable(torch.LongTensor(context_idxs))
        # Step 2. Recall that torch *accumulates* gradients. Before passing in a
        # new instance, you need to zero out the gradients from the old instance.
        model.zero_grad()
        # Step 3. Run the forward pass, getting log probabilities over the
        # center word.
        log_probs = model(context_var)
        # Step 4. Compute the loss. (Again, Torch wants the target word
        # wrapped in a Variable.)
        loss = loss_function(log_probs, autograd.Variable(
            torch.LongTensor([word_to_ix[target]])))
        # Step 5. Do the backward pass and update the parameters.
        loss.backward()
        optimizer.step()
        total_loss += loss.data
    losses.append(total_loss)
print(losses)  # The loss should decrease every epoch over the training data!

# For a sanity check, let's see what the model predicts on the training data.
acc = 0
for context, target in data:
    # Turn the context words into integer indices and wrap them in a Variable
    context_idxs = [word_to_ix[w] for w in context]
    context_var = autograd.Variable(torch.LongTensor(context_idxs))
    # Run the forward pass, getting log probabilities over the center word
    log_probs = model(context_var)
    # The predicted word is the index with the *highest* log probability
    _, idx = torch.max(log_probs, -1)
    acc += int(idx.data[0] == word_to_ix[target])
    print(context, word_to_ix[target], idx.data[0])
print('training accuracy:', acc / len(data))
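
# Illustrative follow-up (not in the original gist): the learned word vectors
# live in the embedding table, so a single word's vector can be read out by its
# index. 'process' is just an example token from the corpus above.
print(model.embeddings.weight.data[word_to_ix['process']])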