@aarestu
Forked from T-B-F/mybidirectional.py
Created July 12, 2018 07:09
from __future__ import print_function

import copy

import numpy as np
from keras import backend as K
from keras.layers import Dense, Input, LSTM, Wrapper
from keras.models import Model
from keras.utils.generic_utils import has_arg

class Bidirectional(Wrapper):
"""Bidirectional wrapper for RNNs.
# Arguments
layer: `Recurrent` instance.
merge_mode: Mode by which outputs of the
forward and backward RNNs will be combined.
One of {'sum', 'mul', 'concat', 'ave', None}.
If None, the outputs will not be combined,
they will be returned as a list.
# Raises
ValueError: In case of invalid `merge_mode` argument.
# Examples
```python
model = Sequential()
model.add(Bidirectional(LSTM(10, return_sequences=True),
input_shape=(5, 10)))
model.add(Bidirectional(LSTM(10)))
model.add(Dense(5))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
```
"""
    def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
        super(Bidirectional, self).__init__(layer, **kwargs)
        if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
            raise ValueError('Invalid merge mode. '
                             'Merge mode should be one of '
                             '{"sum", "mul", "ave", "concat", None}')
        # The forward layer is a shallow copy of the wrapped layer; the
        # backward layer is rebuilt from its config with go_backwards flipped.
        self.forward_layer = copy.copy(layer)
        config = layer.get_config()
        config['go_backwards'] = not config['go_backwards']
        self.backward_layer = layer.__class__.from_config(config)
        self.forward_layer.name = 'forward_' + self.forward_layer.name
        self.backward_layer.name = 'backward_' + self.backward_layer.name
        self.merge_mode = merge_mode
        if weights:
            nw = len(weights)
            self.forward_layer.initial_weights = weights[:nw // 2]
            self.backward_layer.initial_weights = weights[nw // 2:]
        self.stateful = layer.stateful
        self.return_sequences = layer.return_sequences
        self.return_state = layer.return_state
        self._states = None
        self.cell = layer.cell
    @property
    def states(self):
        if self._states is None:
            if isinstance(self.cell.state_size, int):
                num_states = 1
            else:
                num_states = len(self.cell.state_size)
            return [None for _ in range(num_states)]
        return self._states

    @states.setter
    def states(self, states):
        self._states = states

    def get_weights(self):
        return self.forward_layer.get_weights() + self.backward_layer.get_weights()

    def set_weights(self, weights):
        nw = len(weights)
        self.forward_layer.set_weights(weights[:nw // 2])
        self.backward_layer.set_weights(weights[nw // 2:])
    def compute_output_shape(self, input_shape):
        if hasattr(self.cell.state_size, '__len__'):
            output_dim = self.cell.state_size[0]
        else:
            output_dim = self.cell.state_size
        # If the wrapped layer itself returns its states, its
        # compute_output_shape returns a list; only the first entry is the
        # actual output shape.
        forward_shape = self.forward_layer.compute_output_shape(input_shape)
        if isinstance(forward_shape, list):
            forward_shape = forward_shape[0]
        if self.merge_mode in ['sum', 'ave', 'mul']:
            output_shape = forward_shape
        elif self.merge_mode == 'concat':
            shape = list(forward_shape)
            shape[-1] *= 2
            output_shape = tuple(shape)
        elif self.merge_mode is None:
            output_shape = [forward_shape] * 2
        if self.return_state:
            # States are merged the same way as the outputs (see `call`),
            # so 'concat' doubles the state dimension as well.
            state_dim = output_dim * 2 if self.merge_mode == 'concat' else output_dim
            state_shape = [(input_shape[0], state_dim) for _ in self.states]
            return [output_shape] + state_shape
        else:
            return output_shape
    def call(self, inputs, training=None, mask=None):
        kwargs = {}
        if has_arg(self.layer.call, 'training'):
            kwargs['training'] = training
        if has_arg(self.layer.call, 'mask'):
            kwargs['mask'] = mask

        if not self.return_state:
            y = self.forward_layer.call(inputs, **kwargs)
            y_rev = self.backward_layer.call(inputs, **kwargs)
        else:
            y, states_h, states_c = self.forward_layer.call(inputs, **kwargs)
            y_rev, states_h_rev, states_c_rev = self.backward_layer.call(inputs, **kwargs)
        if self.return_sequences:
            y_rev = K.reverse(y_rev, 1)
        # Merge the forward and backward outputs (and, if requested, their
        # final states) according to `merge_mode`.
        if self.merge_mode == 'concat':
            output = K.concatenate([y, y_rev])
            if self.return_state:
                states_final_h = K.concatenate([states_h, states_h_rev])
                states_final_c = K.concatenate([states_c, states_c_rev])
        elif self.merge_mode == 'sum':
            output = y + y_rev
            if self.return_state:
                states_final_h = states_h + states_h_rev
                states_final_c = states_c + states_c_rev
        elif self.merge_mode == 'ave':
            output = (y + y_rev) / 2
            if self.return_state:
                states_final_h = (states_h + states_h_rev) / 2
                states_final_c = (states_c + states_c_rev) / 2
        elif self.merge_mode == 'mul':
            output = y * y_rev
            if self.return_state:
                states_final_h = states_h * states_h_rev
                states_final_c = states_c * states_c_rev
        elif self.merge_mode is None:
            output = [y, y_rev]
            if self.return_state:
                states_final_h = [states_h, states_h_rev]
                states_final_c = [states_c, states_c_rev]

        # Properly set learning phase
        if (getattr(y, '_uses_learning_phase', False) or
                getattr(y_rev, '_uses_learning_phase', False)):
            if self.merge_mode is None:
                for out in output:
                    out._uses_learning_phase = True
            else:
                output._uses_learning_phase = True

        if self.return_state:
            states = [states_final_h, states_final_c]
            if not isinstance(states, (list, tuple)):
                states = [states]
            else:
                states = list(states)
            return [output] + states
        else:
            return output
    def reset_states(self):
        self.forward_layer.reset_states()
        self.backward_layer.reset_states()

    def build(self, input_shape):
        with K.name_scope(self.forward_layer.name):
            self.forward_layer.build(input_shape)
        with K.name_scope(self.backward_layer.name):
            self.backward_layer.build(input_shape)
        self.built = True
    def compute_mask(self, inputs, mask):
        if isinstance(mask, list):
            mask = mask[0]
        output_mask = mask if self.return_sequences else None
        if self.return_state:
            # One mask per output tensor: the (possibly None) sequence mask
            # plus a None mask for each returned state.
            state_mask = [None for _ in self.states]
            return [output_mask] + state_mask
        return output_mask
    @property
    def trainable_weights(self):
        if hasattr(self.forward_layer, 'trainable_weights'):
            return (self.forward_layer.trainable_weights +
                    self.backward_layer.trainable_weights)
        return []

    @property
    def non_trainable_weights(self):
        if hasattr(self.forward_layer, 'non_trainable_weights'):
            return (self.forward_layer.non_trainable_weights +
                    self.backward_layer.non_trainable_weights)
        return []

    @property
    def updates(self):
        if hasattr(self.forward_layer, 'updates'):
            return self.forward_layer.updates + self.backward_layer.updates
        return []

    @property
    def losses(self):
        if hasattr(self.forward_layer, 'losses'):
            return self.forward_layer.losses + self.backward_layer.losses
        return []

    @property
    def constraints(self):
        constraints = {}
        if hasattr(self.forward_layer, 'constraints'):
            constraints.update(self.forward_layer.constraints)
            constraints.update(self.backward_layer.constraints)
        return constraints

    def get_config(self):
        config = {'merge_mode': self.merge_mode}
        base_config = super(Bidirectional, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
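
# A quick shape sanity check for the patched wrapper (an illustrative sketch;
# the `check_*` names are not part of the original gist). With
# return_state=True the wrapper returns [output, state_h, state_c], where the
# forward and backward states are merged according to `merge_mode`.
check_inputs = Input(shape=(None, 8))
check_out, check_h, check_c = Bidirectional(
    LSTM(16, return_state=True), merge_mode='sum')(check_inputs)
# merge_mode='sum' keeps the state size at 16; 'concat' would give 32 instead.
print([K.int_shape(t) for t in (check_out, check_h, check_c)])
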
batch_size = 64 # Batch size for training.
epochs = 200 # Number of epochs to train for.
latent_dim = 256 # Latent dimensionality of the encoding space.
num_samples = 100 # Number of samples to train on.
num_encoder_tokens = 30
num_decoder_tokens = 32
max_encoder_seq_length = 10
max_decoder_seq_length = 12
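# Purely synthetic placeholder data with the right shapes; in a real setup
# these would be (one-hot encoded) source and target sequences.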
encoder_input_data = np.random.rand(num_samples, max_encoder_seq_length, num_encoder_tokens)
decoder_input_data = np.random.rand(num_samples, max_decoder_seq_length, num_decoder_tokens)
decoder_target_data = np.random.rand(num_samples, max_decoder_seq_length, num_decoder_tokens)
# Define an input sequence and process it.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
# encoder = LSTM(latent_dim, return_state=True)
encoder = Bidirectional(LSTM(latent_dim, return_state=True), merge_mode="sum")
outputs_and_states = encoder(encoder_inputs)
# encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
# encoder_states = [state_h, state_c]
encoder_states = outputs_and_states[1:]
# encoder_states = [encoder_states_tmp[0] + encoder_states_tmp[1], encoder_states_tmp[2] + encoder_states_tmp[3]]
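# With merge_mode="sum" the wrapper returns exactly two merged states (h and c),
# each of size latent_dim, so they can be fed directly to the decoder's
# initial_state below.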
# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
# Run training
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
print("training")
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
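
# The comments above mention reusing the encoder states for inference. A
# minimal sketch of the usual sampling setup (assumed here, not part of the
# original gist): an encoder model that emits the merged states, and a
# decoder model that consumes explicit initial states one step at a time.
encoder_model = Model(encoder_inputs, encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs_inf, state_h_inf, state_c_inf = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_outputs_inf = decoder_dense(decoder_outputs_inf)
decoder_model = Model([decoder_inputs] + decoder_states_inputs,
                      [decoder_outputs_inf, state_h_inf, state_c_inf])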