Last active
May 12, 2018 00:11
-
-
Save crestonbunch/0398cd5d0c76e8343ad35bbed6beb4eb to your computer and use it in GitHub Desktop.
Multiplicative LSTM (mLSTM) Wrapper for TensorFlow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A multiplicative LSTM (mLSTM) cell wrapper, as described in
'Multiplicative LSTM for Sequence Modelling' by Krause et al.
(arXiv:1609.07959)."""
import tensorflow as tf | |
import math | |
from collections import namedtuple | |
from tensorflow.contrib.rnn import RNNCell | |
from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple | |
class MultiplicativeLSTMWrapper(RNNCell):
    """Wraps an LSTM cell with a multiplicative hidden state.

    Before each step of the wrapped cell, the previous hidden state
    ``h_{t-1}`` is replaced by the intermediate state

        m_t = (W_mx x_t) * (W_mh h_{t-1})      (element-wise product)

    and the wrapped cell is then run on ``inputs`` with state
    ``(c_{t-1}, m_t)``.

    Args:
        cell: The LSTM cell to wrap. Its state must expose ``.c`` and ``.h``
            attributes, such as an ``LSTMStateTuple``.
    """

    def __init__(self, cell):
        super().__init__()
        self._cell = cell
        # Build both projections once so the same weights are used at every
        # time step and on every call of this wrapper. Calling the functional
        # ``tf.layers.dense`` inside ``__call__`` creates a new set of
        # variables on each invocation (e.g. once per step of ``static_rnn``),
        # so the recurrent weights would not be shared.
        # NOTE(review): Krause et al. define m_t without bias terms; these
        # layers keep the default ``use_bias=True`` to match the original code.
        self._input_projection = tf.layers.Dense(cell.output_size)
        self._hidden_projection = tf.layers.Dense(cell.output_size)

    @property
    def output_size(self):
        """Output size of the wrapped cell."""
        return self._cell.output_size

    @property
    def state_size(self):
        """State size of the wrapped cell (the wrapper adds no state)."""
        return self._cell.state_size

    def zero_state(self, batch_size, dtype):
        """Delegate to the wrapped cell's ``zero_state``."""
        return self._cell.zero_state(batch_size, dtype)

    def __call__(self, inputs, state, scope=None):
        """Run one mLSTM step.

        Args:
            inputs: Input tensor for this time step.
            state: Previous state of the wrapped cell; must have ``.c`` and
                ``.h`` attributes.
            scope: Optional variable scope, forwarded to the wrapped cell.

        Returns:
            Whatever the wrapped cell returns for the modified state
            (typically ``(output, new_state)``).
        """
        wx = self._input_projection(inputs)
        wh = self._hidden_projection(state.h)
        # Substitute m_t for h_{t-1}; the cell state c_{t-1} passes through.
        state = LSTMStateTuple(c=state.c, h=tf.multiply(wx, wh))
        return self._cell(inputs, state, scope)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from mlstm import MultiplicativeLSTMWrapper | |
from tensorflow.contrib.rnn import BasicLSTMCell | |
# Put this in the rest of your model.
# Wherever you want an LSTM cell, wrap it:
# ...
num_units = 128  # LSTM hidden size (number of units); choose for your model
# The wrapper reads state.c / state.h, so keep BasicLSTMCell's tuple state.
cell = MultiplicativeLSTMWrapper(BasicLSTMCell(num_units))
# ...
# and you're done!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment