from __future__ import print_function

import copy

import numpy as np

from keras import backend as K
from keras.models import Model
from keras.layers import Dense, Input, LSTM, Wrapper
from keras.utils.generic_utils import has_arg


class Bidirectional(Wrapper):
    """Bidirectional wrapper for RNNs.

    This variant also merges the forward and backward states when
    `return_state=True`, using the same `merge_mode` as the outputs.

    # Arguments
        layer: `Recurrent` instance.
        merge_mode: Mode by which outputs of the
            forward and backward RNNs will be combined.
            One of {'sum', 'mul', 'concat', 'ave', None}.
            If None, the outputs will not be combined,
            they will be returned as a list.

    # Raises
        ValueError: In case of invalid `merge_mode` argument.

    # Examples

    ```python
    model = Sequential()
    model.add(Bidirectional(LSTM(10, return_sequences=True),
                            input_shape=(5, 10)))
    model.add(Bidirectional(LSTM(10)))
    model.add(Dense(5, activation='softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    ```
    """

    def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
        super(Bidirectional, self).__init__(layer, **kwargs)
        if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]:
            raise ValueError('Invalid merge mode. '
                             'Merge mode should be one of '
                             '{"sum", "mul", "ave", "concat", None}')
        # Build a forward copy and a backward copy of the wrapped layer.
        self.forward_layer = copy.copy(layer)
        config = layer.get_config()
        config['go_backwards'] = not config['go_backwards']
        self.backward_layer = layer.__class__.from_config(config)
        self.forward_layer.name = 'forward_' + self.forward_layer.name
        self.backward_layer.name = 'backward_' + self.backward_layer.name
        self.merge_mode = merge_mode
        if weights:
            nw = len(weights)
            self.forward_layer.initial_weights = weights[:nw // 2]
            self.backward_layer.initial_weights = weights[nw // 2:]
        self.stateful = layer.stateful
        self.return_sequences = layer.return_sequences
        self.return_state = layer.return_state
        self._states = None
        self.cell = layer.cell

    @property
    def states(self):
        if self._states is None:
            if isinstance(self.cell.state_size, int):
                num_states = 1
            else:
                num_states = len(self.cell.state_size)
            return [None for _ in range(num_states)]
        return self._states

    @states.setter
    def states(self, states):
        self._states = states

    def get_weights(self):
        return (self.forward_layer.get_weights() +
                self.backward_layer.get_weights())

    def set_weights(self, weights):
        nw = len(weights)
        self.forward_layer.set_weights(weights[:nw // 2])
        self.backward_layer.set_weights(weights[nw // 2:])

    def compute_output_shape(self, input_shape):
        if hasattr(self.cell.state_size, '__len__'):
            output_dim = self.cell.state_size[0]
        else:
            output_dim = self.cell.state_size

        if self.merge_mode in ['sum', 'ave', 'mul']:
            output_shape = self.forward_layer.compute_output_shape(input_shape)
        elif self.merge_mode == 'concat':
            shape = list(self.forward_layer.compute_output_shape(input_shape))
            shape[-1] *= 2
            output_shape = tuple(shape)
        elif self.merge_mode is None:
            output_shape = [self.forward_layer.compute_output_shape(input_shape)] * 2

        if self.return_state:
            if self.merge_mode == 'concat':
                # States are concatenated in `call`, so they are twice as wide.
                state_shape = [(input_shape[0], output_dim * 2) for _ in self.states]
            else:
                state_shape = [(input_shape[0], output_dim) for _ in self.states]
            return [output_shape] + state_shape
        return output_shape

    def call(self, inputs, training=None, mask=None):
        kwargs = {}
        if has_arg(self.layer.call, 'training'):
            kwargs['training'] = training
        if has_arg(self.layer.call, 'mask'):
            kwargs['mask'] = mask

        if not self.return_state:
            y = self.forward_layer.call(inputs, **kwargs)
            y_rev = self.backward_layer.call(inputs, **kwargs)
        else:
            y, states_h, states_c = self.forward_layer.call(inputs, **kwargs)
            y_rev, states_h_rev, states_c_rev = self.backward_layer.call(
                inputs, **kwargs)
        if self.return_sequences:
            y_rev = K.reverse(y_rev, 1)

        # Merge the forward and backward outputs (and states, if requested)
        # according to `merge_mode`.
        if self.merge_mode == 'concat':
            output = K.concatenate([y, y_rev])
            if self.return_state:
                states_final_h = K.concatenate([states_h, states_h_rev])
                states_final_c = K.concatenate([states_c, states_c_rev])
        elif self.merge_mode == 'sum':
            output = y + y_rev
            if self.return_state:
                states_final_h = states_h + states_h_rev
                states_final_c = states_c + states_c_rev
        elif self.merge_mode == 'ave':
            output = (y + y_rev) / 2
            if self.return_state:
                states_final_h = (states_h + states_h_rev) / 2
                states_final_c = (states_c + states_c_rev) / 2
        elif self.merge_mode == 'mul':
            output = y * y_rev
            if self.return_state:
                states_final_h = states_h * states_h_rev
                states_final_c = states_c * states_c_rev
        elif self.merge_mode is None:
            output = [y, y_rev]
            if self.return_state:
                states_final_h = [states_h, states_h_rev]
                states_final_c = [states_c, states_c_rev]

        # Properly set learning phase
        if (getattr(y, '_uses_learning_phase', False) or
                getattr(y_rev, '_uses_learning_phase', False)):
            if self.merge_mode is None:
                for out in output:
                    out._uses_learning_phase = True
            else:
                output._uses_learning_phase = True

        if self.return_state:
            return [output, states_final_h, states_final_c]
        return output

    def reset_states(self):
        self.forward_layer.reset_states()
        self.backward_layer.reset_states()

    def build(self, input_shape):
        with K.name_scope(self.forward_layer.name):
            self.forward_layer.build(input_shape)
        with K.name_scope(self.backward_layer.name):
            self.backward_layer.build(input_shape)
        self.built = True

    def compute_mask(self, inputs, mask):
        if isinstance(mask, list):
            mask = mask[0]
        output_mask = mask if self.return_sequences else None
        if self.return_state:
            # States carry no timestep mask.
            state_mask = [None for _ in self.states]
            return [output_mask] + state_mask
        return output_mask

    @property
    def trainable_weights(self):
        if hasattr(self.forward_layer, 'trainable_weights'):
            return (self.forward_layer.trainable_weights +
                    self.backward_layer.trainable_weights)
        return []

    @property
    def non_trainable_weights(self):
        if hasattr(self.forward_layer, 'non_trainable_weights'):
            return (self.forward_layer.non_trainable_weights +
                    self.backward_layer.non_trainable_weights)
        return []

    @property
    def updates(self):
        if hasattr(self.forward_layer, 'updates'):
            return self.forward_layer.updates + self.backward_layer.updates
        return []

    @property
    def losses(self):
        if hasattr(self.forward_layer, 'losses'):
            return self.forward_layer.losses + self.backward_layer.losses
        return []

    @property
    def constraints(self):
        constraints = {}
        if hasattr(self.forward_layer, 'constraints'):
            constraints.update(self.forward_layer.constraints)
            constraints.update(self.backward_layer.constraints)
        return constraints

    def get_config(self):
        config = {'merge_mode': self.merge_mode}
        base_config = super(Bidirectional, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
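

# Illustrative sketch (not part of the original gist): with merge_mode='sum'
# and return_state=True, the wrapper above returns [output, state_h, state_c],
# where the forward and backward states have already been summed element-wise,
# so every tensor keeps the underlying cell's dimensionality. The toy sizes
# here are arbitrary.
_demo_inputs = Input(shape=(None, 8))
_demo_out, _demo_h, _demo_c = Bidirectional(
    LSTM(4, return_state=True), merge_mode='sum')(_demo_inputs)
# _demo_out, _demo_h and _demo_c all have shape (None, 4).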


batch_size = 64    # Batch size for training.
epochs = 200       # Number of epochs to train for.
latent_dim = 256   # Latent dimensionality of the encoding space.
num_samples = 100  # Number of samples to train on.
num_encoder_tokens = 30
num_decoder_tokens = 32
max_encoder_seq_length = 10
max_decoder_seq_length = 12

# Random tensors stand in for real (e.g. one-hot encoded) sequence data.
encoder_input_data = np.random.rand(num_samples, max_encoder_seq_length, num_encoder_tokens)
decoder_input_data = np.random.rand(num_samples, max_decoder_seq_length, num_decoder_tokens)
decoder_target_data = np.random.rand(num_samples, max_decoder_seq_length, num_decoder_tokens)
# Define an input sequence and process it with a bidirectional encoder,
# replacing the plain `LSTM(latent_dim, return_state=True)` encoder of the
# original seq2seq example. With merge_mode='sum' the merged states keep
# shape (batch, latent_dim), so they can seed the unidirectional decoder.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = Bidirectional(LSTM(latent_dim, return_state=True), merge_mode='sum')
outputs_and_states = encoder(encoder_inputs)
# We discard the encoder outputs and only keep the (already merged) states.
encoder_states = outputs_and_states[1:]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
# We set up our decoder to return full output sequences, and to return
# internal states as well. We don't use the return states in the training
# model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`.
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Run training.
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
print('training')
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
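

# Inference sketch (not in the original gist): the comment above notes that
# the decoder states are kept for inference. Assuming the layers defined
# above, the standard seq2seq sampling models would be wired up like this.
encoder_model = Model(encoder_inputs, encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs_inf, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_outputs_inf = decoder_dense(decoder_outputs_inf)
decoder_model = Model([decoder_inputs] + decoder_states_inputs,
                      [decoder_outputs_inf, state_h, state_c])

# At decode time, one would call `encoder_model.predict` once to get the
# initial states, then loop `decoder_model.predict` one timestep at a time,
# feeding back the sampled token and the returned states.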