Residual LSTM in Keras
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_1 (InputLayer)             (None, 32, 10)        0
____________________________________________________________________________________________________
lstm_1 (LSTM)                    (None, 32, 10)        840         input_1[0][0]
____________________________________________________________________________________________________
add_1 (Add)                      (None, 32, 10)        0           input_1[0][0]
                                                                   lstm_1[0][0]
____________________________________________________________________________________________________
lstm_2 (LSTM)                    (None, 32, 10)        840         add_1[0][0]
____________________________________________________________________________________________________
add_2 (Add)                      (None, 32, 10)        0           add_1[0][0]
                                                                   lstm_2[0][0]
____________________________________________________________________________________________________
lstm_3 (LSTM)                    (None, 32, 10)        840         add_2[0][0]
____________________________________________________________________________________________________
add_3 (Add)                      (None, 32, 10)        0           add_2[0][0]
                                                                   lstm_3[0][0]
____________________________________________________________________________________________________
lstm_4 (LSTM)                    (None, 32, 10)        840         add_3[0][0]
____________________________________________________________________________________________________
add_4 (Add)                      (None, 32, 10)        0           add_3[0][0]
                                                                   lstm_4[0][0]
____________________________________________________________________________________________________
lstm_5 (LSTM)                    (None, 32, 10)        840         add_4[0][0]
____________________________________________________________________________________________________
add_5 (Add)                      (None, 32, 10)        0           add_4[0][0]
                                                                   lstm_5[0][0]
____________________________________________________________________________________________________
lstm_6 (LSTM)                    (None, 32, 10)        840         add_5[0][0]
____________________________________________________________________________________________________
add_6 (Add)                      (None, 32, 10)        0           add_5[0][0]
                                                                   lstm_6[0][0]
____________________________________________________________________________________________________
lstm_7 (LSTM)                    (None, 32, 10)        840         add_6[0][0]
____________________________________________________________________________________________________
add_7 (Add)                      (None, 32, 10)        0           add_6[0][0]
                                                                   lstm_7[0][0]
____________________________________________________________________________________________________
lambda_1 (Lambda)                (None, 10)            0           add_7[0][0]
____________________________________________________________________________________________________
lstm_8 (LSTM)                    (None, 10)            840         add_7[0][0]
____________________________________________________________________________________________________
add_8 (Add)                      (None, 10)            0           lambda_1[0][0]
                                                                   lstm_8[0][0]
====================================================================================================
Total params: 6,720
Trainable params: 6,720
Non-trainable params: 0
____________________________________________________________________________________________________
# Stacked LSTM with residual connections in the depth direction.
#
# An LSTM naturally has something like residual connections in the time direction.
# Here we add residual connections in the depth direction as well.
#
# Inspired by Google's Neural Machine Translation System (https://arxiv.org/abs/1609.08144).
# They observed that residual connections allow much deeper stacked RNNs to be trained.
# Without residual connections they were limited to around 4 layers of depth.
#
# Uses the Keras 2 API.
from keras.layers import LSTM, Lambda
from keras.layers.merge import add
def make_residual_lstm_layers(input, rnn_width, rnn_depth, rnn_dropout):
    """
    The intermediate LSTM layers return sequences, while the last returns a single element.
    The input is also a sequence. Since summing requires the LSTM's input and output shapes
    to match, the residual connection over full sequences can only be applied to all layers
    but the last; the last layer sums with just the final element of the previous sequence.
    """
    x = input
    for i in range(rnn_depth):
        return_sequences = i < rnn_depth - 1
        x_rnn = LSTM(rnn_width, recurrent_dropout=rnn_dropout, dropout=rnn_dropout, return_sequences=return_sequences)(x)
        if return_sequences:
            # Intermediate layers return sequences; the input is also a sequence.
            if i > 0 or input.shape[-1] == rnn_width:
                x = add([x, x_rnn])
            else:
                # Note that the input size and the RNN output size have to match, due to the sum operation.
                # If we want a different rnn_width, we'd have to perform the sum from layer 2 on.
                x = x_rnn
        else:
            # The last layer does not return sequences, just the last element,
            # so we select only the last element of the previous output.
            def slice_last(x):
                return x[..., -1, :]
            x = add([Lambda(slice_last)(x), x_rnn])
    return x

if __name__ == '__main__':
    # Example usage
    from keras.layers import Input
    from keras.models import Model

    input = Input(shape=(32, 10))
    output = make_residual_lstm_layers(input, rnn_width=10, rnn_depth=8, rnn_dropout=0.2)
    model = Model(inputs=input, outputs=output)
    model.summary()
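
A minimal smoke test can confirm that the residual stack compiles and trains end to end. The sketch below is not part of the original gist; the random dummy data and the mean-squared-error setup are arbitrary assumptions, used only to check that the model builds and the shapes line up:

import numpy as np
from keras.layers import Input
from keras.models import Model

# Same architecture as in the example above: 32 time steps, 10 features.
inputs = Input(shape=(32, 10))
outputs = make_residual_lstm_layers(inputs, rnn_width=10, rnn_depth=8, rnn_dropout=0.2)
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='mse')  # arbitrary loss, just for the smoke test

x_dummy = np.random.rand(64, 32, 10).astype('float32')  # 64 random input sequences
y_dummy = np.random.rand(64, 10).astype('float32')      # model output shape is (None, 10)
model.fit(x_dummy, y_dummy, batch_size=16, epochs=1)
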
@thingumajig commented Jan 27, 2017

Maybe so:

    x = input
    for i in range(rnn_depth):
        return_sequences = i < rnn_depth - 1
        x_rnn = LSTM(rnn_width, dropout_W=rnn_dropout, dropout_U=rnn_dropout, return_sequences=return_sequences)(x)
        .....

but a Flatten/Reshape(?) may be needed there before the merge

@Seanny123 commented Jun 26, 2017

@bzamecnik (owner) commented Jul 25, 2017

@thingumajig - aah, thanks.

@bzamecnik (owner) commented Jul 25, 2017

I fixed several bugs (the code had not been properly tested...), upgraded to the Keras 2 API, added support for a residual connection at the last layer (by selecting the last element of the previous output sequence), and made the residual connection at the input optional, applied only when the input size matches the RNN output size.
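
To show the last-layer trick from that fix in isolation: a sketch along the following lines (the tensor names and shapes here are illustrative assumptions, not code from the gist) slices the last time step of the previous sequence output so that its shape matches the non-sequence output of the final LSTM before the two are added:

    from keras.layers import Input, LSTM, Lambda
    from keras.layers.merge import add

    seq = Input(shape=(32, 10))                        # output of the previous (sequence-returning) layer
    last_lstm = LSTM(10, return_sequences=False)(seq)  # shape (None, 10): only the last time step
    last_step = Lambda(lambda x: x[:, -1, :])(seq)     # also (None, 10): last element of the input sequence
    residual = add([last_step, last_lstm])             # shapes match, so the residual sum is valid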

@bzamecnik (owner) commented Jul 26, 2017

@Seanny123 Thanks for the tip. This was inspired by Google's Neural Machine Translation System (https://arxiv.org/abs/1609.08144). What is different in their architecture?

@NoviScl commented Sep 20, 2018

Do you mean 'and' instead of 'or' in line 27?
