Skip to content

Instantly share code, notes, and snippets.

@ilblackdragon
Last active September 6, 2017 23:17
Show Gist options
  • Save ilblackdragon/662d623eec37dd0eeeabfcb650f00494 to your computer and use it in GitHub Desktop.
Tree LSTM folded
from pytorch_tools import torchfold
def encode_tree_fold(fold, tree):
    """Register the ops that encode ``tree`` on ``fold``; return the logits op.

    The tree is walked bottom-up:

    * every leaf is added as a ``'leaf'`` op on ``node.id``;
    * every internal node is added as a ``'children'`` op on the
      ``(h, c)`` pairs of its left and right children.

    Each op's result is split in two with ``.split(2)``. ``h`` and ``c``
    are presumably the Tree-LSTM hidden and cell states; what the ops
    actually compute is defined by the model later given to
    ``fold.apply``. Only ``node.is_leaf()``, ``node.id``, ``node.left``
    and ``node.right`` are read, so the tree is assumed to be binary.

    The ``h`` half of the root's pair feeds a final ``'logits'`` op, and
    the root's ``c`` is discarded.

    Ops are added in the same order as a recursive post-order walk: the
    whole left subtree, then the whole right subtree, then the parent. The
    walk uses an explicit stack instead of recursion, so deep or degenerate
    trees do not hit Python's recursion limit.

    Args:
        fold: a ``torchfold.Fold``, or any object with the same ``add`` API.
        tree: object whose ``root`` attribute is the top node.

    Returns:
        Whatever ``fold.add('logits', ...)`` returns. This is presumably a
        deferred node that is evaluated later by ``fold.apply``; confirm
        against torchfold.
    """
    # (h, c) pairs of subtrees that are fully encoded but not yet consumed
    # by their parent.
    finished = []
    # Each entry is (node, children_done). An internal node is visited
    # twice: once to schedule its children, and once to combine them.
    pending = [(tree.root, False)]
    while pending:
        node, children_done = pending.pop()
        if children_done:
            # The right child was finished last, so it is on top.
            right_h, right_c = finished.pop()
            left_h, left_c = finished.pop()
            h, c = fold.add('children', left_h, left_c, right_h, right_c).split(2)
            finished.append((h, c))
        elif node.is_leaf():
            h, c = fold.add('leaf', node.id).split(2)
            finished.append((h, c))
        else:
            # Push the parent first and the left child last, so the whole
            # left subtree is encoded before the right one, as in recursion.
            pending.append((node, True))
            pending.append((node.right, False))
            pending.append((node.left, False))
    encoding, _ = finished.pop()
    return fold.add('logits', encoding)
# Placeholder for setup omitted from this snippet. ``args``, ``batch``,
# ``model`` and ``criterion`` are used below but are not defined here.
...
# One Fold for the whole batch: every tree's ops are registered on it,
# then fold.apply runs them (presumably batched together -- see torchfold).
fold = torchfold.Fold(cuda=args.cuda)
all_logits, all_labels = [], []
for tree in batch:
    all_logits.append(encode_tree_fold(fold, tree))
    all_labels.append(tree.label)
# res[0] and res[1] correspond to the two lists passed in. They are
# presumably the batched logits and labels; confirm against torchfold.
res = fold.apply(model, [all_logits, all_labels])
loss = criterion(res[0], res[1])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment