Skip to content

Instantly share code, notes, and snippets.

@thomwolf
Last active March 22, 2018 22:04
Show Gist options
  • Save thomwolf/76ab2b5e59a5a24b467c784b476557c4 to your computer and use it in GitHub Desktop.
Save thomwolf/76ab2b5e59a5a24b467c784b476557c4 to your computer and use it in GitHub Desktop.
The NeuralCoref PyTorch model
class Model(nn.Module):
    """Single-item and pair scoring network sharing one word-embedding table.

    Two feed-forward scorers are built with identical hidden-layer shapes:

    * ``single`` scores one item from its span vector, its embedded words
      and its extra features;
    * ``pairs`` scores every pair from both span vectors, both sets of
      embedded words and the pair features.

    Args:
        vocab_size: number of rows in the word-embedding table.
        embed_dim: size of each word-embedding vector.
        H1, H2, H3: hidden-layer widths, shared by both scorers.
        pairs_in: input width of the pair scorer. Must equal the last-dim
            size of the concatenation built in ``forward``:
            ``a_spans | a_embed | m_spans | m_embed | p_features``.
        single_in: input width of the single scorer. Must equal the
            last-dim size of ``spans | embed_words | s_features``.
        drop: dropout probability applied to the word embeddings and after
            every hidden ReLU layer.
    """

    def __init__(self, vocab_size, embed_dim, H1, H2, H3, pairs_in, single_in, drop=0.5):
        super(Model, self).__init__()
        # Bug fix: the original referenced the undefined name ``embedding_dim``.
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.drop = nn.Dropout(drop)
        # Built in the same order as before (pairs, then single), so the
        # state_dict keys and seeded parameter initialisation are unchanged.
        self.pairs = self._scorer(pairs_in, H1, H2, H3, drop)
        self.single = self._scorer(single_in, H1, H2, H3, drop)

    @staticmethod
    def _scorer(in_dim, H1, H2, H3, drop):
        """Build a 3-hidden-layer ReLU MLP mapping ``in_dim`` features to one score."""
        return nn.Sequential(
            nn.Linear(in_dim, H1), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(H1, H2), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(H2, H3), nn.ReLU(), nn.Dropout(drop),
            nn.Linear(H3, 1),
            # Learnable affine rescaling (w * x + b) of the scalar score.
            nn.Linear(1, 1),
        )

    def forward(self, inputs, concat_axis=1):
        """Score single items and, optionally, pairs.

        Args:
            inputs: either a 3-sequence ``(spans, words, s_features)`` for
                single scoring only, or an 8-sequence
                ``(spans, words, s_features, a_spans, a_words, m_spans,
                m_words, p_features)`` to also score pairs. The ``*words``
                tensors index into the embedding table, so they must be
                integer tensors (``nn.Embedding`` requirement). ``a_spans``
                must be 3-D ``(batch, n_pairs, span_dim)`` because its size
                is unpacked into three values. The ``a_*`` / ``m_*`` prefixes
                presumably denote the two members of each pair (antecedent /
                mention); confirm with the caller.
            concat_axis: dimension along which pair scores and single scores
                are concatenated in pair mode.

        Returns:
            Single mode: ``single_scores`` of shape ``(batch, 1)``.
            Pair mode: ``torch.cat([pair_scores, single_scores], concat_axis)``.
            With the default axis this has shape ``(batch, n_pairs + 1)``,
            and the single score is the last column.

        Raises:
            ValueError: if ``inputs`` has neither 3 nor 8 elements.
        """
        n_inputs = len(inputs)
        if n_inputs == 8:
            # Bug fix: the original bound ``a_spans`` twice. That discarded
            # the 4th input and used the 6th in both span slots of the
            # pair input.
            (spans, words, s_features,
             a_spans, a_words, m_spans, m_words, p_features) = inputs
        elif n_inputs == 3:
            spans, words, s_features = inputs
        else:
            raise ValueError(
                "Model.forward expects 3 (single) or 8 (pair) inputs, got %d" % n_inputs)

        # Flatten each row's word embeddings into a single vector.
        embed_words = self.drop(self.embed(words).view(words.size(0), -1))
        single_input = torch.cat([spans, embed_words, s_features], 1)
        single_scores = self.single(single_input)
        if n_inputs == 3:
            return single_scores

        btz, n_pairs, _ = a_spans.size()
        # Embed all words of all pairs at once, then regroup per pair.
        a_embed = self.drop(self.embed(a_words.view(btz, -1)).view(btz, n_pairs, -1))
        m_embed = self.drop(self.embed(m_words.view(btz, -1)).view(btz, n_pairs, -1))
        pair_input = torch.cat([a_spans, a_embed, m_spans, m_embed, p_features], 2)
        # (btz, n_pairs, 1) -> (btz, n_pairs)
        pair_scores = self.pairs(pair_input).squeeze(dim=2)
        return torch.cat([pair_scores, single_scores], concat_axis)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment