Skip to content

Instantly share code, notes, and snippets.

@unnonouno
Last active July 13, 2016 02:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save unnonouno/b46137bf9291475adb425d20265f1bc0 to your computer and use it in GitHub Desktop.
Save unnonouno/b46137bf9291475adb425d20265f1bc0 to your computer and use it in GitHub Desktop.
class StatefullRNNUnit(chainer.Chain):
    """Wrap a stateless recurrent cell so that it carries its own states.

    The wrapped ``rnn`` link must expose an ``n_state`` attribute (the number
    of recurrent states it uses). It is called as ``rnn(*states, x)`` and must
    return a sequence ``(*new_states, y)``. This unit feeds the stored states
    in, keeps the returned states for the next call, and returns only ``y``.
    """

    def __init__(self, rnn):
        # Pass ``rnn`` to Chain's constructor so it is registered as a child
        # link (its parameters are then tracked by this chain); this also
        # makes it available as ``self.rnn``.
        super(StatefullRNNUnit, self).__init__(rnn=rnn)
        self.reset_state()

    def reset_state(self):
        """Forget the carried states; the next call starts from ``None``."""
        self.state = (None,) * self.rnn.n_state

    def __call__(self, x):
        """Advance the wrapped cell one step on ``x`` and return its output."""
        args = self.state + (x,)
        ret = self.rnn(*args)
        # Store as a tuple so ``self.state + (x,)`` works even if the cell
        # returns a list.
        self.state = tuple(ret[:-1])
        return ret[-1]
class GRU(chainer.Chain):
    """GRU cell sketch that tolerates a hidden state larger than the input batch.

    Follows a stateless-cell protocol: ``n_state`` is the number of recurrent
    states, and ``__call__(h, x)`` returns ``(h_new, y)``. The ``...`` lines
    were elided in the original sketch and still need a real implementation.
    """

    n_state = 1

    def __call__(self, h, x):
        if h is None:
            ...  # TODO: build the initial hidden state for this batch.
        # Bound up front so the re-attach check below never hits an unbound
        # name when no split happens.
        h_rest = None
        if len(h.data) > len(x.data):  # for variable length input
            # Presumably the caller feeds length-sorted, time-transposed
            # batches, so only the leading len(x) rows of h are active at this
            # step; the remaining rows are passed through unchanged.
            h, h_rest = split_axis(h, [len(x.data)], axis=0)
        ...  # TODO: compute h_new (and the output y) from h and x.
        if h_rest is not None:
            # Re-attach the untouched rows along the batch axis. concat takes a
            # sequence of variables, and its default axis is 1, so axis=0 must
            # be given explicitly.
            h_new = concat((h_new, h_rest), axis=0)
        return h_new, y
class NStepRNN(chainer.Chain):
    """Run a one-step recurrent cell over a batch of variable-length sequences.

    ``rnn`` must expose ``n_state`` and be callable as ``rnn(*states, x)``,
    returning ``(*new_states, y)``.
    """

    def __init__(self, rnn):
        # Register ``rnn`` as a child link; ``NStepRNN(rnn=cell)`` keeps
        # working as it did with Chain's keyword-link constructor.
        super(NStepRNN, self).__init__(rnn=rnn)

    def __call__(self, xs):
        """Apply the cell along time to every sequence in ``xs``.

        Args:
            xs: list of per-sequence inputs, each indexed by time on axis 0;
                lengths may differ.

        Returns:
            list of per-sequence outputs, in the same order as ``xs``.

        Raises:
            TypeError: if ``xs`` is not a list.
        """
        if not isinstance(xs, list):
            raise TypeError('xs must be a list, got %s' % type(xs).__name__)
        if not xs:
            return []
        # transpose_sequence requires the sequences ordered longest first.
        # A stable sort keeps equal-length sequences in their input order.
        inds = numpy.argsort([-len(x) for x in xs], kind='mergesort')
        sorted_xs = [xs[i] for i in inds]
        txs = transpose_sequence(sorted_xs)
        states = (None,) * self.rnn.n_state
        tys = []
        for tx in txs:
            # The per-step batch shrinks as shorter sequences end; the cell
            # is responsible for coping with states larger than tx.
            args = tuple(states) + (tx,)
            ret = self.rnn(*args)
            states = ret[:-1]
            tys.append(ret[-1])
        sorted_ys = transpose_sequence(tys)
        # Undo the sort: the k-th sorted output belongs at position inds[k].
        ys = [None] * len(xs)
        for orig_index, y in zip(inds, sorted_ys):
            ys[orig_index] = y
        return ys
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment