Skip to content

Instantly share code, notes, and snippets.

@cschell
Created June 7, 2019 07:42
Show Gist options
  • Save cschell/07aaa25b9cee2db5ff8181ec9ad4f898 to your computer and use it in GitHub Desktop.
Save cschell/07aaa25b9cee2db5ff8181ec9ad4f898 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class CBOW(nn.Module):
    """Context-window language model that predicts a word from its neighbours.

    The embeddings of the ``context_size`` context words are concatenated
    (not summed or averaged), passed through one ReLU hidden layer, and
    projected to log-probabilities over the vocabulary.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int = 100,
        context_size: int = 4,
        hidden_dim: int = 128,
    ):
        """Build the model.

        Args:
            vocab_size: Number of distinct token ids; also the output size.
            embedding_dim: Width of each token embedding.
            context_size: Number of context tokens per sample.
            hidden_dim: Width of the hidden layer (previously fixed at 128).
        """
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Input width is context_size * embedding_dim because the context
        # embeddings are concatenated in forward().
        self.linear1 = nn.Linear(context_size * embedding_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return log-probabilities over the vocabulary.

        Args:
            inputs: Integer token ids. Either a single context of shape
                ``(context_size,)`` or a batch of shape
                ``(batch, context_size)``.

        Returns:
            Tensor of shape ``(batch, vocab_size)``. A single unbatched
            context yields ``batch == 1``, matching the original behavior.
        """
        # Promote an unbatched context to a batch of one so that callers
        # passing a single 1-D context keep getting a (1, vocab_size) result.
        if inputs.dim() < 2:
            inputs = inputs.reshape(1, -1)
        # Flatten each sample's context embeddings into one row. The old
        # view((1, -1)) merged the whole batch into one row and broke batching.
        embeds = self.embeddings(inputs).reshape(inputs.shape[0], -1)
        hidden = F.relu(self.linear1(embeds))
        logits = self.linear2(hidden)
        return F.log_softmax(logits, dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment