Skip to content

Instantly share code, notes, and snippets.

@tsuchm
Created October 12, 2016 01:06
Show Gist options
  • Save tsuchm/82c598750fa687c0d5c21c28fa9dc00b to your computer and use it in GitHub Desktop.
Save tsuchm/82c598750fa687c0d5c21c28fa9dc00b to your computer and use it in GitHub Desktop.
Example of an LSTM encoder that ignores padding labels
import chainer
from chainer import Link, Chain, Function, Variable, cuda
import chainer.functions as F
import chainer.links as L
import numpy as np
class LSTMEncoder(Chain):
    """Single-step LSTM encoder that skips positions holding a padding label.

    Each call consumes one batch of token IDs together with the previous
    cell/hidden states. For rows whose token ID equals ``ignore_label`` the
    previous states are returned unchanged, so padded positions do not
    advance the recurrent state.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, ignore_label=-1):
        """Create the embedding and the two linear maps feeding the LSTM gates.

        Args:
            vocab_size: Number of rows in the embedding table.
            embed_size: Dimensionality of each embedding vector.
            hidden_size: Dimensionality of the cell and hidden states; the
                linear layers emit ``4 * hidden_size`` units because
                ``F.lstm`` expects the four gate pre-activations stacked.
            ignore_label: Token ID marking padding. Defaults to ``-1``.
        """
        # The embedding layer must be given the same ignore_label that
        # __call__ uses for masking; otherwise a non-default padding ID would
        # be looked up as an ordinary vocabulary index.
        super(LSTMEncoder, self).__init__(
            xe=L.EmbedID(vocab_size, embed_size, ignore_label=ignore_label),
            eh=L.Linear(embed_size, 4 * hidden_size),
            hh=L.Linear(hidden_size, 4 * hidden_size),
        )
        self.ignore_label = ignore_label

    def __call__(self, x, c_prev, h_prev):
        """Advance the LSTM by one step, keeping old state on padded rows.

        Args:
            x: Variable holding one token ID per batch row (it is reshaped to
                ``(x.shape[0], 1)`` below, so it must be one-dimensional).
                NOTE(review): presumably an int32 array as EmbedID expects;
                confirm against the caller.
            c_prev: Previous cell state.
            h_prev: Previous hidden state.

        Returns:
            Tuple ``(c_next, h_next)``. Rows where ``x`` equals
            ``self.ignore_label`` carry ``c_prev`` / ``h_prev`` through;
            all other rows hold the freshly computed LSTM states.
        """
        # Use numpy or cupy depending on where the input array lives.
        xp = cuda.get_array_module(x.data)

        e = F.tanh(self.xe(x))
        c, h = F.lstm(c_prev, self.eh(e) + self.hh(h_prev))

        # Per-row padding flag as a column, expanded to the state shape so
        # F.where can select element-wise between old and new states.
        is_pad = xp.reshape(x.data == self.ignore_label, (x.shape[0], 1))
        ignore = xp.broadcast_to(is_pad, c.shape)

        c_next = F.where(ignore, c_prev, c)
        h_next = F.where(ignore, h_prev, h)
        return c_next, h_next
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment