from __future__ import print_function

import numpy as np

from keras import backend as K
from keras.models import Model
from keras.layers import Input, LSTM, Dense, Bidirectional
from keras.utils.generic_utils import has_arg
class MyBidirectional(Bidirectional):
    """Bidirectional wrapper that also merges and returns the RNN states.

    The stock `Bidirectional` wrapper in this Keras version does not
    propagate `return_state` from the wrapped layer; this subclass does,
    merging the forward and backward states the same way as the outputs.
    """

    def __init__(self, layer, **kwargs):
        super(MyBidirectional, self).__init__(layer, **kwargs)
        # Mirror the wrapped layer's return_state flag on the wrapper.
        self.return_state = layer.return_state
    def compute_output_shape(self, input_shape):
        if hasattr(self.layer.cell.state_size, '__len__'):
            output_dim = self.layer.cell.state_size[0]
        else:
            output_dim = self.layer.cell.state_size
        if self.merge_mode in ['sum', 'ave', 'mul']:
            output_shape = self.forward_layer.compute_output_shape(input_shape)
        elif self.merge_mode == 'concat':
            shape = list(self.forward_layer.compute_output_shape(input_shape))
            shape[-1] *= 2
            output_shape = tuple(shape)
        elif self.merge_mode is None:
            output_shape = [self.forward_layer.compute_output_shape(input_shape)] * 2
        if self.return_state:
            # The states are merged like the outputs, so concatenation
            # doubles their width as well.
            state_dim = 2 * output_dim if self.merge_mode == 'concat' else output_dim
            state_shape = [(input_shape[0], state_dim) for _ in range(2)]
            return [output_shape] + state_shape
        return output_shape
    def call(self, inputs, training=None, mask=None):
        kwargs = {}
        if has_arg(self.layer.call, 'training'):
            kwargs['training'] = training
        if has_arg(self.layer.call, 'mask'):
            kwargs['mask'] = mask

        y = self.forward_layer.call(inputs, **kwargs)
        y_rev = self.backward_layer.call(inputs, **kwargs)
        if self.return_state:
            y, states_h, states_c = y
            y_rev, states_h_rev, states_c_rev = y_rev
        if self.return_sequences:
            y_rev = K.reverse(y_rev, 1)

        def merge(fwd, bwd):
            # Combine a forward/backward tensor pair according to merge_mode.
            if self.merge_mode == 'concat':
                return K.concatenate([fwd, bwd])
            elif self.merge_mode == 'sum':
                return fwd + bwd
            elif self.merge_mode == 'ave':
                return (fwd + bwd) / 2
            # 'mul', and the element-wise reduction used for the states
            # when merge_mode is None:
            return fwd * bwd

        if self.merge_mode is None:
            output = [y, y_rev]
        else:
            output = merge(y, y_rev)
        if self.return_state:
            # Merge the states the same way as the outputs so a single
            # (h, c) pair can seed a unidirectional decoder.
            states_final_h = merge(states_h, states_h_rev)
            states_final_c = merge(states_c, states_c_rev)

        # Propagate the learning-phase flag (set e.g. when dropout is used).
        if (getattr(y, '_uses_learning_phase', False) or
                getattr(y_rev, '_uses_learning_phase', False)):
            if self.merge_mode is None:
                for out in output:
                    out._uses_learning_phase = True
            else:
                output._uses_learning_phase = True

        if self.return_state:
            return [output, states_final_h, states_final_c]
        return output
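
# A minimal sanity check (not in the original gist): build a tiny instance to
# confirm the wrapper returns the merged output plus one merged (h, c) state
# pair. With merge_mode='sum' every tensor keeps the wrapped LSTM's width.
# The `_check_*` names are illustrative only.
_check_inputs = Input(shape=(None, 8))
_check_out, _check_h, _check_c = MyBidirectional(
    LSTM(4, return_state=True), merge_mode='sum')(_check_inputs)
# Each of _check_out, _check_h, _check_c has shape (batch, 4).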
batch_size = 64    # Batch size for training.
epochs = 10        # Number of epochs to train for.
latent_dim = 256   # Latent dimensionality of the encoding space.
num_samples = 10   # Number of samples to train on.
num_encoder_tokens = 30
num_decoder_tokens = 32
max_encoder_seq_length = 10
max_decoder_seq_length = 12

# Random placeholder data standing in for real (one-hot encoded) sequences.
encoder_input_data = np.random.rand(num_samples, max_encoder_seq_length, num_encoder_tokens)
decoder_input_data = np.random.rand(num_samples, max_decoder_seq_length, num_decoder_tokens)
decoder_target_data = np.random.rand(num_samples, max_decoder_seq_length, num_decoder_tokens)
# Define an input sequence and process it with the bidirectional encoder.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = MyBidirectional(LSTM(latent_dim, return_state=True), merge_mode='sum')
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the merged states.
encoder_states = [state_h, state_c]
# Set up the decoder, using `encoder_states` as its initial state.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
# The decoder returns full output sequences as well as its internal states.
# The returned states are not used in the training model, but they are
# needed for inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that turns
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`.
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
# Run training.
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
print('training')
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2,
          verbose=1)
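
# Inference sketch (an assumption, not part of the original training script;
# it follows the standard Keras seq2seq recipe that this gist builds on).
# The encoder model maps an input sequence to the merged states; the decoder
# model then consumes those states one step at a time. All `*_inf` names
# below are illustrative.
encoder_model = Model(encoder_inputs, encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs_inf, state_h_inf, state_c_inf = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_outputs_inf = decoder_dense(decoder_outputs_inf)
decoder_model = Model([decoder_inputs] + decoder_states_inputs,
                      [decoder_outputs_inf, state_h_inf, state_c_inf])

# Example: encode one sample, then run a single decoder step from the
# encoder's merged states.
states_value = encoder_model.predict(encoder_input_data[0:1])
step_input = np.zeros((1, 1, num_decoder_tokens))
step_output, h, c = decoder_model.predict([step_input] + states_value)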