@mbollmann
Created June 18, 2016 08:59
StateTransferLSTM for Keras 1.x
# Source:
# https://github.com/farizrahman4u/seq2seq/blob/master/seq2seq/layers/state_transfer_lstm.py

from keras import backend as K
from keras.layers.recurrent import LSTM


class StateTransferLSTM(LSTM):
    """LSTM with the ability to transfer its hidden state.

    This layer behaves just like an LSTM, except that it can transfer (or
    *broadcast*) its hidden state to another LSTM layer. This requires that
    the shapes and hidden dimensions of the layers match, and that the batch
    size of the input is known in advance (e.g. by specifying
    `batch_input_shape`).

    Use :py:meth:`broadcast_state` to enable the hidden state transfer to
    another layer.

    While this layer works fine in a normal Sequential model, it is only
    saved and restored properly when used with
    :py:class:`mblearn.models.StateTransferSequential`.

    See Also:
        https://github.com/farizrahman4u/seq2seq/blob/master/seq2seq/layers/state_transfer_lstm.py
    """

    def __init__(self, *args, state_input=True, **kwargs):
        self.state_outputs = []
        self.state_input = state_input
        super(StateTransferLSTM, self).__init__(*args, **kwargs)

    def get_config(self):
        # Report the layer as stateful while it has broadcast targets, since
        # it then keeps persistent state tensors just like a stateful LSTM.
        stateful = self.stateful
        self.stateful = stateful or len(self.state_outputs) > 0
        config = super(StateTransferLSTM, self).get_config()
        self.stateful = stateful
        config['state_input'] = self.state_input \
            if isinstance(self.state_input, bool) else repr(self.state_input)
        config['state_outputs'] = [repr(x) for x in self.state_outputs]
        return config

    @classmethod
    def from_config(cls, config):
        import copy
        config = copy.deepcopy(config)
        # Broadcast links were serialized as repr() strings and cannot be
        # reconstructed here; they must be re-established after loading.
        del config['state_outputs']
        return cls(**config)

    def build(self, input_shape):
        # If we're broadcasting state, pretend to be stateful while calling
        # LSTM.build() so that persistent state tensors are created.
        stateful = self.stateful
        self.stateful = stateful or self.state_input or len(self.state_outputs) > 0
        if hasattr(self, 'states'):
            del self.states
        super(StateTransferLSTM, self).build(input_shape)
        self.stateful = stateful

    def broadcast_state(self, rnns):
        """Make the LSTM broadcast its hidden state to another layer.

        Args:
            rnns: One or several layers to broadcast the hidden state to.
        """
        if type(rnns) not in [list, tuple]:
            rnns = [rnns]
        self.state_outputs += rnns
        for rnn in rnns:
            rnn.state_input = self

    def call(self, x, mask=None):
        # input shape: (nb_samples, time (padded with zeros), input_dim)
        # note that the .build() method of subclasses MUST define
        # self.input_spec with a complete input shape.
        input_shape = self.input_spec[0].shape
        if K._BACKEND == 'tensorflow':
            if not input_shape[1]:
                raise Exception('When using TensorFlow, you should define '
                                'explicitly the number of timesteps of '
                                'your sequences.\n'
                                'If your first layer is an Embedding, '
                                'make sure to pass it an "input_length" '
                                'argument. Otherwise, make sure '
                                'the first layer has '
                                'an "input_shape" or "batch_input_shape" '
                                'argument, including the time axis. '
                                'Found input shape at layer ' + self.name +
                                ': ' + str(input_shape))
        if self.stateful or self.state_input or len(self.state_outputs) > 0:
            # Start from the persistent state tensors when state is
            # transferred into or out of this layer.
            initial_states = self.states
        else:
            initial_states = self.get_initial_states(x)
        constants = self.get_constants(x)
        preprocessed_input = self.preprocess_input(x)
        last_output, outputs, states = K.rnn(self.step, preprocessed_input,
                                             initial_states,
                                             go_backwards=self.go_backwards,
                                             mask=mask,
                                             constants=constants,
                                             unroll=self.unroll,
                                             input_length=input_shape[1])
        if self.stateful and not self.state_input:
            self.updates = []
            for i in range(len(states)):
                self.updates.append((self.states[i], states[i]))
        # Propagate this layer's final states to all broadcast targets.
        for o in self.state_outputs:
            o.updates = []
            for i in range(len(states)):
                o.updates.append((o.states[i], states[i]))
        if self.return_sequences:
            return outputs
        else:
            return last_output
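
A minimal usage sketch for an encoder-decoder setup, assuming Keras 1.x: the
decoder is seeded with the encoder's final state via broadcast_state(). All
names, dimensions, the state_input=False flag on the encoder, and the
RepeatVector-based decoder input are illustrative assumptions, not taken from
the gist itself.

from keras.models import Sequential
from keras.layers import Dense, RepeatVector, TimeDistributed

batch_size, timesteps, input_dim = 32, 15, 50
hidden_dim, output_len = 128, 10

# The batch size must be known in advance, hence batch_input_shape.
encoder = StateTransferLSTM(hidden_dim, state_input=False,
                            batch_input_shape=(batch_size, timesteps, input_dim))
decoder = StateTransferLSTM(hidden_dim, return_sequences=True)

# After this call, the encoder writes its final states into the decoder's
# persistent state variables, so the decoder no longer starts from zeros.
encoder.broadcast_state(decoder)

model = Sequential()
model.add(encoder)
model.add(RepeatVector(output_len))  # repeat the encoder summary per output step
model.add(decoder)
model.add(TimeDistributed(Dense(input_dim, activation='softmax')))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

Note that, per the docstring, a plain Sequential model like this trains and
predicts fine, but saving and restoring the broadcast links correctly requires
mblearn.models.StateTransferSequential.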
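
The get_config()/from_config() pair only partially survives a round-trip: the
stateful flag is reported as True while broadcast targets exist, and the
targets themselves are serialized as repr() strings that are dropped again on
restore. A small illustrative sketch, reusing the hypothetical encoder/decoder
from above:

config = encoder.get_config()
print(config['stateful'])       # True while broadcast targets exist
print(config['state_outputs'])  # repr() strings; informational only

restored = StateTransferLSTM.from_config(config)
# 'state_outputs' was deleted in from_config(), so the broadcast link is
# gone; it must be re-established, e.g. restored.broadcast_state(decoder).
# Presumably this is what StateTransferSequential handles on model loading.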