Skip to content

Instantly share code, notes, and snippets.

@crestonbunch
Last active May 12, 2018 00:11
Show Gist options
  • Save crestonbunch/0398cd5d0c76e8343ad35bbed6beb4eb to your computer and use it in GitHub Desktop.
Save crestonbunch/0398cd5d0c76e8343ad35bbed6beb4eb to your computer and use it in GitHub Desktop.
Multiplicative LSTM (mLSTM) Wrapper for TensorFlow
"""A multiplicative cell wrapper as described in
'Multiplicative LSTM for Sequence Modeling' by Krause et al."""
import tensorflow as tf
import math
from collections import namedtuple
from tensorflow.contrib.rnn import RNNCell
from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple
class MultiplicativeLSTMWrapper(RNNCell):
    """Wraps an LSTM cell with a multiplicative hidden state.

    On every step the previous hidden state ``h`` is replaced by the
    element-wise product of two dense projections -- one of the current
    inputs and one of ``h`` itself -- and the wrapped cell is then run on
    the unmodified inputs with that substituted state.

    The wrapped cell must use a state that exposes ``.c`` and ``.h``
    (e.g. an ``LSTMStateTuple``), because ``__call__`` reads both fields.
    """

    def __init__(self, cell):
        """Store the wrapped cell and create the two projection layers.

        Args:
            cell: The LSTM cell to wrap. Its ``output_size`` is used as the
                width of both projections.
        """
        super().__init__()
        self._cell = cell
        # The projections are layer *objects* created once here instead of
        # functional ``tf.layers.dense`` calls inside ``__call__``. An
        # anonymous functional call builds a brand-new layer (with new
        # variables) each time it runs, so invoking the cell once per
        # timestep would silently give every timestep its own weights.
        # Layer objects build their variables on first use and reuse them
        # on every later call.
        self._input_projection = tf.layers.Dense(cell.output_size)
        self._hidden_projection = tf.layers.Dense(cell.output_size)

    @property
    def output_size(self):
        """Output size of the wrapped cell."""
        return self._cell.output_size

    @property
    def state_size(self):
        """State size of the wrapped cell."""
        return self._cell.state_size

    def zero_state(self, batch_size, dtype):
        """Return the wrapped cell's zero state for the given batch/dtype."""
        return self._cell.zero_state(batch_size, dtype)

    def __call__(self, inputs, state, scope=None):
        """Run one step of the wrapped cell on a multiplicative state.

        Args:
            inputs: Input tensor for this step.
            state: Previous state; must provide ``.c`` and ``.h``.
            scope: Passed through unchanged to the wrapped cell.

        Returns:
            Whatever the wrapped cell returns for the given inputs and the
            substituted state.
        """
        input_term = self._input_projection(inputs)
        hidden_term = self._hidden_projection(state.h)
        # Only ``h`` is replaced; the cell state ``c`` is passed through.
        multiplicative_state = LSTMStateTuple(
            c=state.c,
            h=tf.multiply(input_term, hidden_term),
        )
        return self._cell(inputs, multiplicative_state, scope)
from mlstm import MultiplicativeLSTMWrapper
from tensorflow.contrib.rnn import BasicLSTMCell
# Usage example: build the rest of your model as usual.
# Wherever you need an LSTM cell, wrap a regular one like this:
# ...
# NOTE: `x` is a placeholder that is not defined in this snippet --
# presumably the number of units for BasicLSTMCell; substitute your own value.
cell = MultiplicativeLSTMWrapper(BasicLSTMCell(x))
# ...
# `cell` is an RNNCell subclass instance, so use it wherever an RNNCell is expected.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment