@mrdrozdov
Created February 11, 2019 01:32
nary-treelstm.py
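import torch


def forward(self, left_h, left_c, right_h, right_c, constant=1.0):
    # Binary TreeLSTM composition: merge two child (h, c) states into one parent state.
    U, B = self.U, self.B
    W = U.t()
    # `width` is the number of rows of W, i.e. the combined input size of both children.
    width, height = W.shape
    # The top half of W applies to the left child, the bottom half to the right child.
    Wl = W[:width//2]
    Wr = W[width//2:]
    # Project each child's hidden state, then split into five gate pre-activations.
    al = torch.matmul(left_h, Wl)
    al_lst = torch.chunk(al, 5, dim=1)
    ar = torch.matmul(right_h, Wr)
    ar_lst = torch.chunk(ar, 5, dim=1)
    B_lst = torch.chunk(B, 5)
    # Gates: input, left forget, right forget, output, and the candidate update.
    i = torch.sigmoid(al_lst[0] + ar_lst[0] + B_lst[0])
    # `constant` is added to the forget-gate pre-activations, biasing them toward 1.
    fl = torch.sigmoid(al_lst[1] + ar_lst[1] + B_lst[1] + constant)
    fr = torch.sigmoid(al_lst[2] + ar_lst[2] + B_lst[2] + constant)
    o = torch.sigmoid(al_lst[3] + ar_lst[3] + B_lst[3])
    u = torch.tanh(al_lst[4] + ar_lst[4] + B_lst[4])
    # Each child's cell state is scaled by its own forget gate before the update is added.
    c = fl * left_c + fr * right_c + i * u
    h = o * torch.tanh(c)
    return h, c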
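
The snippet shows only the `forward` method, so the shapes of `self.U` and `self.B` are not stated. A minimal sketch of how it might be wrapped and called follows, assuming `U` has shape `(5 * hidden_size, 2 * hidden_size)` and `B` has shape `(5 * hidden_size,)`. The class name `BinaryTreeLSTMCell`, the constructor, and the initialization are hypothetical and not part of the original gist.

```python
import torch
import torch.nn as nn


class BinaryTreeLSTMCell(nn.Module):
    """Hypothetical wrapper around the forward method; shapes are assumptions."""

    def __init__(self, hidden_size):
        super().__init__()
        # Assumed: U maps the concatenated [left_h; right_h] to five gate blocks.
        self.U = nn.Parameter(torch.randn(5 * hidden_size, 2 * hidden_size) * 0.1)
        self.B = nn.Parameter(torch.zeros(5 * hidden_size))

    def forward(self, left_h, left_c, right_h, right_c, constant=1.0):
        W = self.U.t()
        half = W.shape[0] // 2
        # Separate projections for the left and right child hidden states.
        al = torch.chunk(left_h @ W[:half], 5, dim=1)
        ar = torch.chunk(right_h @ W[half:], 5, dim=1)
        b = torch.chunk(self.B, 5)
        i = torch.sigmoid(al[0] + ar[0] + b[0])
        fl = torch.sigmoid(al[1] + ar[1] + b[1] + constant)
        fr = torch.sigmoid(al[2] + ar[2] + b[2] + constant)
        o = torch.sigmoid(al[3] + ar[3] + b[3])
        u = torch.tanh(al[4] + ar[4] + b[4])
        c = fl * left_c + fr * right_c + i * u
        h = o * torch.tanh(c)
        return h, c


torch.manual_seed(0)
hidden_size, batch = 8, 3
cell = BinaryTreeLSTMCell(hidden_size)
left_h, left_c, right_h, right_c = (torch.randn(batch, hidden_size) for _ in range(4))
h, c = cell(left_h, left_c, right_h, right_c)
print(h.shape, c.shape)
```

Under these assumed shapes, splitting `W` by rows and summing the two projections is equivalent to multiplying the concatenated child states `[left_h, right_h]` by `W` in a single matmul.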