Skip to content

Instantly share code, notes, and snippets.

@farizrahman4u
Last active September 3, 2016 13:51
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save farizrahman4u/02092825699d4d6f5bb216ac86ebb038 to your computer and use it in GitHub Desktop.
from keras.layers import Recurrent
from keras.models import Sequential
from keras import backend as K
def _isRNN(layer):
    """Return True when ``layer`` is an instance of a Keras ``Recurrent`` subclass."""
    return isinstance(layer, Recurrent)
def _zeros(shape):
    """Build a zero tensor from a static shape that may contain unknown dims.

    Any falsy dimension (``None`` for an unknown size, or ``0``) is replaced
    by 2, so a concrete tensor can be created from a layer's declared
    ``input_shape``.
    """
    concrete_shape = [dim or 2 for dim in shape]
    return K.zeros(concrete_shape)
class DepthFirstRecurrentContainer(Recurrent, Sequential):
    """Recurrent layer that runs a whole stack of layers at every timestep.

    Layers are added as in a ``Sequential`` model.  ``step`` feeds a single
    timestep through every layer in order, instead of running each layer
    over the full sequence before the next one.

    The container's flat ``states`` list packs the per-layer data of all
    recurrent sub-layers.  It holds every sub-layer's states in layer order,
    followed by every sub-layer's constants in layer order::

        [rnn0 states..., rnn1 states..., rnn0 consts..., rnn1 consts...]
    """

    def __init__(self, *args, **kwargs):
        # Both bases are initialised explicitly.  The constructor arguments
        # go to Recurrent (e.g. ``return_sequences``, read below).
        # Sequential is initialised without arguments.
        Recurrent.__init__(self, *args, **kwargs)
        Sequential.__init__(self)

    @property
    def input_shape(self):
        """Input shape as reported by the Sequential base."""
        # ``input_shape`` is a property on the base class.  Calling the
        # property object raises TypeError, so its getter is invoked instead.
        return Sequential.input_shape.fget(self)

    @property
    def output_shape(self):
        """Output shape of the stack.

        When ``return_sequences`` is set, the time axis (taken from
        ``input_shape[1]``) is inserted after the batch axis.
        """
        shape = Sequential.output_shape.fget(self)
        if self.return_sequences:
            shape = (shape[0], self.input_shape[1]) + shape[1:]
        return shape

    def get_output_shape_for(self, input_shape):
        """Output shape of the stack for a given ``input_shape``.

        When ``return_sequences`` is set, the time axis of the given
        ``input_shape`` is inserted after the batch axis.
        """
        shape = Sequential.get_output_shape_for(self, input_shape)
        if self.return_sequences:
            # Use the time length of the shape we were asked about, not the
            # one recorded on this layer.
            shape = (shape[0], input_shape[1]) + shape[1:]
        return shape

    def add(self, layer):
        """Append ``layer`` to the stack.

        Recurrent sub-layers are forced to ``return_sequences = False`` and
        ``consume_less = 'mem'``, because ``step`` drives them one timestep
        at a time through their own ``step`` methods.

        A recurrent sub-layer that follows another layer is preceded by a
        shape-only ``Lambda`` marked ``dummy``.  The previous layer's output
        carries no time axis, so this Lambda repeats it along a new axis 1,
        which lets the recurrent layer be built.  ``step`` skips dummy layers.
        """
        from keras.layers import Lambda  # not imported at module level

        if _isRNN(layer):
            layer.return_sequences = False
            layer.consume_less = 'mem'
            # Every recurrent sub-layer returns only its last output, so any
            # preceding layer (recurrent or not) needs the time axis re-added.
            if self.layers:
                # Axis 1 is the time axis (axis 0 is the batch).  Fall back
                # to 1 when the length is unknown.
                timesteps = self.input_shape[1] or 1
                dummy_layer = Lambda(
                    lambda x: K.tile(
                        K.expand_dims(x, 1),
                        [1, timesteps] + [1] * (K.ndim(x) - 1),
                    ),
                    output_shape=lambda s: (s[0], timesteps) + s[1:],
                )
                dummy_layer.dummy = True
                Sequential.add(self, dummy_layer)
        Sequential.add(self, layer)

    def step(self, x, states):
        """Run one timestep ``x`` through every layer of the stack.

        ``states`` is the flat list described on the class: all recurrent
        sub-layers' states, then all their constants.  Each recurrent
        sub-layer receives only its own states followed by its own constants,
        and the new states it returns are written back in place.

        Returns:
            A pair ``(output, new_states)``.  ``output`` is the last layer's
            output, and ``new_states`` holds the updated states without the
            constants.
        """
        nb_states = []
        nb_constants = []
        for layer in self.layers:
            if _isRNN(layer):
                nb_states.append(len(layer.states))
                if not hasattr(layer, 'nb_constants'):
                    # Normally cached by get_constants().  This computes it
                    # on demand only to count the constants.
                    layer.nb_constants = len(layer.get_constants(_zeros(layer.input_shape)))
                nb_constants.append(layer.nb_constants)
        total_states = sum(nb_states)

        # Work on a copy so the caller's list is not mutated.
        states = list(states)
        rnn_index = 0
        for layer in self.layers:
            if hasattr(layer, 'dummy'):
                # Shape-only helper inserted by add().  It has no per-step work.
                continue
            if _isRNN(layer):
                states_start = sum(nb_states[:rnn_index])
                states_end = states_start + nb_states[rnn_index]
                consts_start = total_states + sum(nb_constants[:rnn_index])
                consts_end = consts_start + nb_constants[rnn_index]
                layer_states = states[states_start:states_end] + states[consts_start:consts_end]
                x, new_states = layer.step(x, layer_states)
                states[states_start:states_end] = new_states
                rnn_index += 1
            else:
                x = layer.call(x)
        return x, states[:total_states]

    def get_initial_states(self, x):
        """Concatenate the initial states of all recurrent sub-layers, in layer order."""
        initial_states = []
        for layer in self.layers:
            if _isRNN(layer):
                # Pass the container's real input rather than a ``_zeros``
                # placeholder.  ``_zeros`` replaces the unknown batch
                # dimension with 2, which would fix the states' batch size.
                # NOTE(review): relies on the sub-layer's get_initial_states
                # using ``x`` only for its batch axis (as the stock Keras 1
                # Recurrent implementation does).  Confirm for custom layers.
                initial_states += layer.get_initial_states(x)
        return initial_states

    def get_constants(self, x):
        """Concatenate the constants of all recurrent sub-layers, in layer order.

        Also caches each sub-layer's constant count as ``layer.nb_constants``
        so ``step`` can slice the flat states list.
        """
        constants = []
        for layer in self.layers:
            if _isRNN(layer):
                # NOTE(review): the ``_zeros`` placeholder has batch size 2.
                # Any constant whose shape depends on the batch (e.g. dropout
                # masks) would inherit that size.  Confirm before enabling
                # dropout on sub-layers.
                consts = layer.get_constants(_zeros(layer.input_shape))
                layer.nb_constants = len(consts)
                constants += consts
        return constants
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment