Skip to content

Instantly share code, notes, and snippets.

Created December 2, 2017 16:31
Show Gist options
  • Save aaronstevenwhite/62666159081cb1c7bf377ff1bc05b259 to your computer and use it in GitHub Desktop.
Save aaronstevenwhite/62666159081cb1c7bf377ff1bc05b259 to your computer and use it in GitHub Desktop.
A bidirectional extension of Tai et al.'s (2015) child-sum tree LSTM (for dependency trees), implemented as a PyTorch module.
import sys

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.rnn import RNNBase

# lru_cache lives in the standard library on Python 3; on Python 2 it
# is provided by the functools32 backport.
if sys.version_info.major >= 3:
    from functools import lru_cache
else:
    from functools32 import lru_cache
class ChildSumTreeLSTM(RNNBase):
    """A bidirectional extension of child-sum tree LSTMs.

    This module is constructed so as to be a drop-in replacement for
    the stock LSTM implemented in pytorch.nn.modules.rnn. It
    implements both bidirectional and unidirectional child-sum tree
    LSTMs for dependency trees. As such, it aims to minimally change
    that implementation's interface to allow for nontrivial tree
    topologies, and it exposes the parameters of the LSTM in the same
    way - i.e. the attribute names for the LSTM weights and biases are
    exactly the same as for the linear chain LSTM. The main difference
    between the linear chain version and this version is that
    forward() requires an nltk dependency graph representing a
    dependency tree for the input embeddings and does not require
    initial values for the hidden and cell states.

    The 'up' direction composes each node from its children (leaves
    to root); the 'down' direction, used only when the module is
    bidirectional, composes each node from its parents (root to
    leaves) and is stored in the ``*_reverse`` parameters.
    """

    def __init__(self, *args, **kwargs):
        """Build the module; arguments are forwarded to RNNBase in 'LSTM' mode."""
        super(ChildSumTreeLSTM, self).__init__('LSTM', *args, **kwargs)

        # lru_cache is normally used as a decorator, but that usage
        # leads to a global cache, where we need an instance specific
        # cache
        self._get_parameters = lru_cache()(self._get_parameters)

    @staticmethod
    def nonlinearity(x):
        """Activation applied to the candidate cell and to the cell state."""
        return F.relu(x)

    def forward(self, inputs, tree):
        """Compute hidden states for every node of a dependency tree.

        Parameters
        ----------
        inputs : torch.autograd.Variable
            a 2D (steps x embedding dimension) or a 3D tensor (steps x
            batch dimension x embedding dimension); the batch
            dimension must always have size == 1, since this module
            does not support minibatching. Note that the batch
            dimension of ``inputs`` is dimension 1 regardless of
            ``batch_first``; ``batch_first`` only changes the layout
            of the returned ``hidden_all``.
        tree : nltk.DependencyGraph
            must implement the following instance methods
            - root_idx: all root indices in the tree
            - children_idx: indices of children of a particular index
            - parents_idx: indices of parents of a particular index

        Returns
        -------
        hidden_all : torch.autograd.Variable
            the top-layer hidden state of every node, ordered by node
            index; 'up' and 'down' states are concatenated when the
            module is bidirectional
        hidden_final : torch.autograd.Variable
            the hidden state of the trees root node; if there are two
            or more such nodes, the average of their hidden states is
            returned

        Raises
        ------
        ValueError
            if ``inputs`` is not 2D, or 3D with dimension 1 of size 1
        """
        # sets self._has_batch_dimension, which is read below and in
        # _construct_x_t
        self._validate_inputs(inputs)

        ridx = list(tree.root_idx())

        # per-call memo tables: layer -> direction -> node index -> state
        self.hidden_state = {}
        self.cell_state = {}

        for layer in range(self.num_layers):
            self.hidden_state[layer] = {'up': {}, 'down': {}}
            self.cell_state[layer] = {'up': {}, 'down': {}}

            # layers are filled bottom-up because layer n reads the
            # finished states of layer n-1 as its inputs
            for i in ridx:
                self._upward_downward(layer, 'up', inputs, tree, i)

        top = self.hidden_state[self.num_layers - 1]
        nsteps = inputs.size(0)

        if self.bidirectional:
            hidden_all = [torch.cat([top['up'][i], top['down'][i]])
                          for i in range(nsteps)]
        else:
            hidden_all = [top['up'][i] for i in range(nsteps)]

        hidden_final = [hidden_all[i] for i in ridx]
        hidden_all = torch.stack(hidden_all)
        hidden_final = torch.mean(torch.stack(hidden_final), 0)

        if self._has_batch_dimension:
            if self.batch_first:
                return hidden_all[None, :, :], hidden_final[None, :]
            return hidden_all[:, None, :], hidden_final[None, :]
        return hidden_all, hidden_final

    def _upward_downward(self, layer, direction, inputs, tree, idx):
        """Return (h_t, c_t) for node ``idx``, computing and memoizing it if needed.

        Recurses into the node's children ('up') or parents ('down')
        through _construct_previous. After an 'up' state is stored for
        a bidirectional module, the node's 'down' state is computed as
        well.
        """
        # check to see whether this node has been computed on this
        # layer in this direction, if so short circuit the rest of
        # this function and return that result
        if idx in self.hidden_state[layer][direction]:
            h_t = self.hidden_state[layer][direction][idx]
            c_t = self.cell_state[layer][direction][idx]
            return h_t, c_t

        x_t = self._construct_x_t(layer, inputs, idx)
        oidx, (h_prev, c_prev) = self._construct_previous(layer, direction,
                                                          inputs, tree, idx)

        # The input projection (and bias) is computed once per node. It
        # must enter the input, output and candidate gates exactly
        # once, no matter how many neighbours are summed over;
        # broadcasting it over the neighbour axis before summing would
        # count it once per neighbour.
        if self.bias:
            Wih, Whh, bih, bhh = self._get_parameters(layer, direction)
            x_proj = torch.matmul(Wih, x_t) + bih + bhh
        else:
            Wih, Whh = self._get_parameters(layer, direction)
            x_proj = torch.matmul(Wih, x_t)

        # one column per neighbour: (4 * hidden_size) x len(neighbours)
        h_proj = torch.matmul(Whh, h_prev)

        f_x, c_hat_x, i_x, o_x = torch.split(x_proj, self.hidden_size, dim=0)
        f_h, c_hat_h, i_h, o_h = torch.split(h_proj, self.hidden_size, dim=0)

        # one forget gate per neighbour, each gating that neighbour's cell
        f_t = torch.sigmoid(f_x[:, None] + f_h)
        gated_children = torch.sum(torch.mul(f_t, c_prev), 1, keepdim=False)

        # the remaining gates see the sum of the neighbours' hidden states
        c_hat_t_raw = c_hat_x + torch.sum(c_hat_h, 1, keepdim=False)
        i_t_raw = i_x + torch.sum(i_h, 1, keepdim=False)
        o_t_raw = o_x + torch.sum(o_h, 1, keepdim=False)

        c_hat_t = self.__class__.nonlinearity(c_hat_t_raw)
        i_t = torch.sigmoid(i_t_raw)
        o_t = torch.sigmoid(o_t_raw)

        c_t = gated_children + torch.mul(i_t, c_hat_t)
        h_t = torch.mul(o_t, self.__class__.nonlinearity(c_t))

        if self.dropout:
            # functional dropout honours self.training, so dropout is
            # switched off after .eval()
            h_t = F.dropout(h_t, p=self.dropout, training=self.training)
            c_t = F.dropout(c_t, p=self.dropout, training=self.training)

        self.hidden_state[layer][direction][idx] = h_t
        self.cell_state[layer][direction][idx] = c_t

        if direction == 'up' and self.bidirectional:
            self._upward_downward(layer, 'down', inputs, tree, idx)

        return h_t, c_t

    def _validate_inputs(self, inputs):
        """Check the rank of ``inputs`` and record whether it has a batch axis.

        Sets ``self._has_batch_dimension``. Raises ValueError for
        anything other than a 2D tensor or a 3D tensor whose dimension
        1 has size 1.
        """
        ndims = len(inputs.size())

        if ndims == 3:
            if inputs.size()[1] != 1:
                raise ValueError('ChildSumTreeLSTM assumes that dimension 1 '
                                 'of inputs is a batch dimension and, '
                                 'because it does not support minibatching, '
                                 'this dimension must always have size == 1')
            self._has_batch_dimension = True
        elif ndims == 2:
            self._has_batch_dimension = False
        else:
            raise ValueError('inputs must be 2D or 3D (with dimension 1 '
                             'being a batch dimension)')

    def _get_parameters(self, layer, direction):
        """Look up the weights (and biases, if any) for a layer and direction.

        Returns (Wih, Whh, bih, bhh) when the module has biases and
        (Wih, Whh) otherwise. The 'down' direction reads the
        ``*_reverse`` attributes. __init__ wraps this method in a
        per-instance lru_cache.
        """
        dirtag = '' if direction == 'up' else '_reverse'

        Wihattrname = 'weight_ih_l{}{}'.format(str(layer), dirtag)
        Whhattrname = 'weight_hh_l{}{}'.format(str(layer), dirtag)

        Wih, Whh = getattr(self, Wihattrname), getattr(self, Whhattrname)

        if self.bias:
            bhhattrname = 'bias_hh_l{}{}'.format(str(layer), dirtag)
            bihattrname = 'bias_ih_l{}{}'.format(str(layer), dirtag)

            bih, bhh = getattr(self, bihattrname), getattr(self, bhhattrname)

            return Wih, Whh, bih, bhh

        return Wih, Whh

    def _construct_x_t(self, layer, inputs, idx):
        """Return the input vector of node ``idx`` at ``layer``.

        Layer 0 reads the embedding from ``inputs``; higher layers read
        the previous layer's hidden state ('up' and 'down' concatenated
        when bidirectional).
        """
        if layer > 0 and self.bidirectional:
            below = self.hidden_state[layer - 1]
            x_t = torch.cat([below['up'][idx], below['down'][idx]])
        elif layer > 0:
            x_t = self.hidden_state[layer - 1]['up'][idx]
        elif self._has_batch_dimension:
            x_t = inputs[idx, 0]
        else:
            x_t = inputs[idx]

        return x_t

    def _construct_previous(self, layer, direction, inputs, tree, idx):
        """Collect the states that node ``idx`` is composed from.

        Returns ``(oidx, (h_prev, c_prev))`` where ``oidx`` lists the
        children ('up') or parents ('down') of ``idx`` and ``h_prev`` /
        ``c_prev`` are hidden_size x len(oidx) matrices with one column
        per neighbour. A node without neighbours gets a single all-zero
        column.
        """
        if direction == 'up':
            oidx = list(tree.children_idx(idx))
        else:
            oidx = list(tree.parents_idx(idx))

        if oidx:
            h_prev, c_prev = [], []

            for i in oidx:
                h_prev_i, c_prev_i = self._upward_downward(layer, direction,
                                                           inputs, tree, i)
                h_prev.append(h_prev_i)
                c_prev.append(c_prev_i)

            h_prev = torch.stack(h_prev, 1)
            c_prev = torch.stack(c_prev, 1)
        else:
            # type_as keeps the zero states on the same tensor type
            # (and therefore device) as the inputs
            h_prev = Variable(
                torch.zeros(self.hidden_size, 1).type_as(inputs.data),
                requires_grad=False)
            c_prev = Variable(
                torch.zeros(self.hidden_size, 1).type_as(inputs.data),
                requires_grad=False)

        return oidx, (h_prev, c_prev)
def main(corpus, nembdims=300, nhiddendims=300):
    '''Run ChildSumTreeLSTM on random inputs for all sentences of corpus.

    A two-layer, bidirectional, bias-free ChildSumTreeLSTM is built
    once and applied to every tree, with freshly sampled standard
    normal embeddings standing in for real word vectors.

    Parameters
    ----------
    corpus : iterable of nltk.DependencyGraph
        each tree must expose a ``nodes`` mapping (its length is used
        as the number of input rows) and the tree interface that
        ChildSumTreeLSTM.forward requires
    nembdims : int
        dimension of the random input embeddings
    nhiddendims : int
        hidden size of the tree LSTM

    Returns
    -------
    hidden_all : [torch.autograd.Variable]
        per-tree hidden states of all nodes
    hidden_final : [torch.autograd.Variable]
        per-tree hidden state of the root node(s)
    '''
    rnn = ChildSumTreeLSTM(nembdims, nhiddendims,
                           bidirectional=True, num_layers=2, bias=False)

    hidden_all = []
    hidden_final = []

    for tree in corpus:
        nwords = len(tree.nodes)

        inputs = Variable(torch.randn(nwords, nembdims),
                          requires_grad=False)

        h_all, h_final = rnn(inputs, tree)

        # collect the per-tree results; without this the function
        # would always return two empty lists
        hidden_all.append(h_all)
        hidden_final.append(h_final)

    return hidden_all, hidden_final
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment