Skip to content

Instantly share code, notes, and snippets.

@ilblackdragon
Last active September 6, 2017 23:16
Show Gist options
  • Save ilblackdragon/e737cc98dce4865c1e16a42565ef70e9 to your computer and use it in GitHub Desktop.
Save ilblackdragon/e737cc98dce4865c1e16a42565ef70e9 to your computer and use it in GitHub Desktop.
TreeLSTM regular model
class TreeLSTM(nn.Module):
    """Binary tree LSTM cell that merges a left and a right child state.

    Each child state is an indexable pair whose element 0 is the hidden
    state and whose element 1 is the memory cell. The hidden states are
    projected by separate linear layers, summed, and split into five
    gate pre-activations: candidate, input gate, one forget gate per
    child, and output gate.
    """

    def __init__(self, num_units):
        """Create the two child projections.

        Args:
            num_units: Width of the hidden state and of the memory cell.
        """
        super(TreeLSTM, self).__init__()
        self.num_units = num_units
        # One projection per child; each emits all five gate slices at once.
        self.left = nn.Linear(num_units, 5 * num_units)
        self.right = nn.Linear(num_units, 5 * num_units)

    def forward(self, left_in, right_in):
        """Combine two child states into the parent state.

        Args:
            left_in: Pair (hidden, cell) of the left child.
            right_in: Pair (hidden, cell) of the right child.

        Returns:
            Tuple (h, c) with the parent's hidden state and memory cell.
        """
        gates = self.left(left_in[0]) + self.right(right_in[0])
        # Gates are split along dim 1, so inputs are presumably
        # (batch, num_units) -- confirm against callers.
        candidate, in_gate, forget_left, forget_right, out_gate = gates.chunk(5, 1)

        fresh = candidate.tanh() * in_gate.sigmoid()
        kept_left = forget_left.sigmoid() * left_in[1]
        kept_right = forget_right.sigmoid() * right_in[1]
        c = fresh + kept_left + kept_right

        h = out_gate.sigmoid() * c.tanh()
        return h, c
class SPINN(nn.Module):
    """Tree-structured encoder with a linear classifier on top.

    Holds the word embedding table, the TreeLSTM cell used to merge child
    states, and the output projection that maps an encoding to class
    logits.
    """

    def __init__(self, n_classes, size, n_words):
        """Build the sub-modules.

        Args:
            n_classes: Number of output classes of the final projection.
            size: Width of embeddings, hidden states and memory cells.
            n_words: Number of rows in the embedding table.
        """
        super(SPINN, self).__init__()
        self.size = size
        self.tree_lstm = TreeLSTM(size)
        self.embeddings = nn.Embedding(n_words, size)
        self.out = nn.Linear(size, n_classes)

    def leaf(self, word_id):
        """Return the (hidden, cell) state of a leaf node.

        Args:
            word_id: Integer tensor of word indices; its first dimension
                sets the number of rows of the returned cell state.

        Returns:
            Tuple (h, c): `h` is the embedding lookup for `word_id`, `c`
            is an all-zero memory cell with `size` columns.
        """
        h = self.embeddings(word_id)
        # The memory cell must start at zero. `torch.FloatTensor(n, m)` only
        # allocates storage and leaves its contents uninitialized, which
        # would leak arbitrary values (possibly inf/NaN) into every parent.
        # `type_as` keeps the zeros on the same tensor type as the embedding
        # output, so the cell also works when the model lives on a GPU.
        zeros = torch.zeros(word_id.size()[0], self.size).type_as(h.data)
        return h, Variable(zeros)

    def children(self, left_h, left_c, right_h, right_c):
        """Merge two child states with the TreeLSTM cell; returns (h, c)."""
        return self.tree_lstm((left_h, left_c), (right_h, right_c))

    def logits(self, encoding):
        """Project an encoding to unnormalized class scores."""
        return self.out(encoding)
def encode_tree_regular(model, tree):
    """Encode one tree by recursing over its nodes and return its logits.

    Args:
        model: Object providing `leaf(word_id)`,
            `children(left_h, left_c, right_h, right_c)` and
            `logits(encoding)`.
        tree: Object whose `root` node exposes `is_leaf()`; leaves carry
            an `id`, inner nodes carry `left` and `right`.

    Returns:
        Whatever `model.logits` returns for the root's hidden state.
    """

    def encode_node(node):
        # Leaves: wrap the single id in a length-1 LongTensor for the model.
        if node.is_leaf():
            word = Variable(torch.LongTensor([node.id]))
            return model.leaf(word)

        # Inner nodes: encode the left subtree first, then the right one,
        # and merge the two states.
        h_left, c_left = encode_node(node.left)
        h_right, c_right = encode_node(node.right)
        return model.children(h_left, c_left, h_right, c_right)

    # Only the root's hidden state is classified; its cell is discarded.
    root_h, _ = encode_node(tree.root)
    return model.logits(root_h)
...
# Encode every tree of the batch on its own, collecting per-tree logits and
# gold labels in matching order.
all_logits = []
all_labels = []
for tree in batch:
    tree_logits = encode_tree_regular(model, tree)
    all_logits.append(tree_logits)
    all_labels.append(tree.label)

# Concatenate the per-tree logits along dim 0 and score them against the
# labels in a single criterion call.
stacked_logits = torch.cat(all_logits, 0)
targets = Variable(torch.LongTensor(all_labels))
loss = criterion(stacked_logits, targets)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment