@mpariente
Last active December 19, 2017 22:12
Creating a subclass of Bidirectional to handle return_state
# In answer to this issue
# https://github.com/keras-team/keras/issues/8823
from keras import backend as K
from keras.layers import Bidirectional, Input, LSTM
from keras.utils.generic_utils import has_arg
class MyBidirectional(Bidirectional):
    def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
        # Forward merge_mode and weights to the parent; otherwise they are
        # silently dropped and the default 'concat' is always used.
        super(MyBidirectional, self).__init__(layer, merge_mode=merge_mode,
                                              weights=weights, **kwargs)
        self.return_state = layer.return_state
    def compute_output_shape(self, input_shape):
        if hasattr(self.layer.cell.state_size, '__len__'):
            output_dim = self.layer.cell.state_size[0]
        else:
            output_dim = self.layer.cell.state_size
        forward_shape = self.forward_layer.compute_output_shape(input_shape)
        if self.return_state:
            # The inner layer also reports its state shapes; keep only the
            # actual output shape for the merge below.
            forward_shape = forward_shape[0]
        if self.merge_mode in ['sum', 'ave', 'mul']:
            output_shape = forward_shape
        elif self.merge_mode == 'concat':
            shape = list(forward_shape)
            shape[-1] *= 2
            output_shape = tuple(shape)
        elif self.merge_mode is None:
            output_shape = [forward_shape] * 2
        if self.return_state:
            # Four states: hidden and cell state for each direction.
            state_shape = [(input_shape[0], output_dim) for _ in range(4)]
            return [output_shape] + state_shape
        else:
            return output_shape
    def call(self, inputs, training=None, mask=None):
        kwargs = {}
        if has_arg(self.layer.call, 'training'):
            kwargs['training'] = training
        if has_arg(self.layer.call, 'mask'):
            kwargs['mask'] = mask

        # Run the wrapped layer in both directions; the backward copy was
        # created by the parent Bidirectional with go_backwards flipped.
        y = self.forward_layer.call(inputs, **kwargs)
        y_rev = self.backward_layer.call(inputs, **kwargs)
        if self.return_state:
            # With return_state=True, each direction returns [output, h, c].
            y, states_h, states_c = y
            y_rev, states_h_rev, states_c_rev = y_rev
        if self.return_sequences:
            y_rev = K.reverse(y_rev, 1)
        if self.merge_mode == 'concat':
            output = K.concatenate([y, y_rev])
        elif self.merge_mode == 'sum':
            output = y + y_rev
        elif self.merge_mode == 'ave':
            output = (y + y_rev) / 2
        elif self.merge_mode == 'mul':
            output = y * y_rev
        elif self.merge_mode is None:
            output = [y, y_rev]

        # Properly set learning phase
        if (getattr(y, '_uses_learning_phase', False) or
                getattr(y_rev, '_uses_learning_phase', False)):
            if self.merge_mode is None:
                for out in output:
                    out._uses_learning_phase = True
            else:
                output._uses_learning_phase = True

        if self.return_state:
            states = [states_h, states_h_rev, states_c, states_c_rev]
            return [output] + states
        else:
            return output
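
# When the wrapped layer has return_state=True and merge_mode is not None,
# calling this wrapper returns a list of five tensors:
# [merged_output, forward_h, backward_h, forward_c, backward_c].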
# Just random values to fit with your code
num_encoder_tokens = 10
latent_dim = 30

# Plain unidirectional encoder, for comparison with the original code.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)

# Bidirectional encoder built with the wrapper above.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = MyBidirectional(LSTM(latent_dim, return_state=True), merge_mode="mul")
outputs_and_states = encoder(encoder_inputs)
# encoder_outputs, states_h, states_h_rev, states_c, states_c_rev = encoder(encoder_inputs)
outputs = outputs_and_states[0]
states = outputs_and_states[1:]
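
# Illustrative sanity check (an added sketch, not part of the original gist):
# with merge_mode="mul" and return_state=True the call above yields five
# tensors, each of shape (batch, latent_dim) here.
print(len(outputs_and_states))                       # expected: 5
print([K.int_shape(t) for t in outputs_and_states])  # e.g. [(None, 30), (None, 30), ...]

# If needed, the tensors can be wrapped into a standalone encoder model.
from keras.models import Model
encoder_model = Model(encoder_inputs, outputs_and_states)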