""" | |
A keras attention layer that wraps RNN layers. | |
Based on tensorflows [attention_decoder](https://github.com/tensorflow/tensorflow/blob/c8a45a8e236776bed1d14fd71f3b6755bd63cc58/tensorflow/python/ops/seq2seq.py#L506) | |
and [Grammar as a Foreign Language](https://arxiv.org/abs/1412.7449). | |
date: 20161101 | |
author: wassname | |
url: https://gist.github.com/wassname/5292f95000e409e239b9dc973295327a | |
""" | |
from keras import backend as K | |
from keras.engine import InputSpec | |
from keras.layers import LSTM, activations, Wrapper, Recurrent | |
class Attention(Wrapper):
    """
    This wrapper adds an attention mechanism to a recurrent layer.

    # Arguments:
        layer: `Recurrent` instance with consume_less='gpu' or 'mem'

    # Examples:

    ```python
    model = Sequential()
    model.add(LSTM(10, return_sequences=True, batch_input_shape=(4, 5, 10)))
    model.add(Attention(LSTM(10, return_sequences=True, consume_less='gpu')))
    model.add(Dense(5))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    ```

    # References
    - [Grammar as a Foreign Language](https://arxiv.org/abs/1412.7449)
    """
    def __init__(self, layer, **kwargs):
        assert isinstance(layer, Recurrent)
        if layer.get_config()['consume_less'] == 'cpu':
            raise Exception("Attention doesn't support RNNs with consume_less='cpu'")
        self.supports_masking = True
        super(Attention, self).__init__(layer, **kwargs)

    def build(self, input_shape):
        assert len(input_shape) >= 3
        self.input_spec = [InputSpec(shape=input_shape)]
        nb_samples, nb_time, input_dim = input_shape

        if not self.layer.built:
            self.layer.build(input_shape)
            self.layer.built = True

        super(Attention, self).build()

        # Attention parameters:
        #   W1 scores the input sequence (applied as a 1x1 convolution in get_constants),
        #   W2 (+ b2) scores the previous hidden state,
        #   V  reduces the combined scores to one scalar per timestep,
        #   W3 (+ b3) mixes the raw input with the attention-weighted context.
        self.W1 = self.layer.init((input_dim, input_dim, 1, 1), name='{}_W1'.format(self.name))
        self.W2 = self.layer.init((self.layer.output_dim, input_dim), name='{}_W2'.format(self.name))
        self.b2 = K.zeros((input_dim,), name='{}_b2'.format(self.name))
        self.W3 = self.layer.init((input_dim * 2, input_dim), name='{}_W3'.format(self.name))
        self.b3 = K.zeros((input_dim,), name='{}_b3'.format(self.name))
        self.V = self.layer.init((input_dim,), name='{}_V'.format(self.name))

        self.trainable_weights = [self.W1, self.W2, self.W3, self.V, self.b2, self.b3]

    def get_output_shape_for(self, input_shape):
        return self.layer.get_output_shape_for(input_shape)

    def step(self, x, states):
        # Based on TensorFlow's attention_decoder:
        # https://github.com/tensorflow/tensorflow/blob/c8a45a8e236776bed1d14fd71f3b6755bd63cc58/tensorflow/python/ops/seq2seq.py#L506
        #
        # First, compute attention weights over the timesteps:
        #     attn = softmax(V^T * tanh(W1 * X + W2 * h + b2))
        # then build a new input from the current input and the attention-weighted
        # sequence, projected back to input_dim with W3:
        #     x = W3 * [x, sum(attn * X)] + b3
        # and finally run the wrapped cell on this new input:
        #     h, states = cell(x, states)
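        # Tensor shapes below (batch b, timesteps t, features d = input_dim):
        #   X   : (b, t, d)        the full input sequence, passed in as a constant
        #   xW1 : (b, t, 1, d)     W1 applied to X, precomputed in get_constants()
        #   h   : (b, output_dim)  previous hidden state of the wrapped RNN
        #   a   : (b, t)           attention weights (softmax over the timesteps)
        #   Xa  : (b, d)           attention-weighted sum of the inputs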
        nb_samples, nb_time, input_dim = self.input_spec[0].shape
        h = states[0]
        X = states[-1]    # the full input sequence (appended last in get_constants)
        xW1 = states[-2]  # W1 * X, precomputed once per sequence

        Xr = K.reshape(X, (-1, nb_time, 1, input_dim))
        hW2 = K.dot(h, self.W2) + self.b2
        hW2 = K.reshape(hW2, (-1, 1, 1, input_dim))
        u = K.tanh(xW1 + hW2)
        a = K.sum(self.V * u, [2, 3])
        a = K.softmax(a)
        a = K.reshape(a, (-1, nb_time, 1, 1))

        # Weight the inputs by the attention and sum over the timesteps
        Xa = K.sum(a * Xr, [1, 2])
        Xa = K.reshape(Xa, (-1, input_dim))

        # Merge the current input and the attention-weighted context back to input_dim
        x = K.dot(K.concatenate([x, Xa], 1), self.W3) + self.b3

        h, new_states = self.layer.step(x, states)
        return h, new_states

    def get_constants(self, x):
        constants = self.layer.get_constants(x)

        # Compute W1 * X only once per sequence by passing it to step() as a constant.
        # With a 1x1 kernel the convolution below is the same linear map applied
        # independently at every timestep.
        nb_samples, nb_time, input_dim = self.input_spec[0].shape
        Xr = K.reshape(x, (-1, nb_time, input_dim, 1))
        Xrt = K.permute_dimensions(Xr, (0, 2, 1, 3))
        xW1t = K.conv2d(Xrt, self.W1, border_mode='same')
        xW1 = K.permute_dimensions(xW1t, (0, 2, 3, 1))
        constants.append(xW1)

        # step() also needs the full input sequence as the attention "memory"
        constants.append(x)
        return constants

    def call(self, x, mask=None):
        # input shape: (nb_samples, time (padded with zeros), input_dim)
        input_shape = self.input_spec[0].shape
        if K._BACKEND == 'tensorflow':
            if not input_shape[1]:
                raise Exception('When using TensorFlow, you should define '
                                'explicitly the number of timesteps of '
                                'your sequences.\n'
                                'If your first layer is an Embedding, '
                                'make sure to pass it an "input_length" '
                                'argument. Otherwise, make sure '
                                'the first layer has '
                                'an "input_shape" or "batch_input_shape" '
                                'argument, including the time axis. '
                                'Found input shape at layer ' + self.name +
                                ': ' + str(input_shape))

        if self.layer.stateful:
            initial_states = self.layer.states
        else:
            initial_states = self.layer.get_initial_states(x)
        constants = self.get_constants(x)
        preprocessed_input = self.layer.preprocess_input(x)

        last_output, outputs, states = K.rnn(self.step, preprocessed_input,
                                             initial_states,
                                             go_backwards=self.layer.go_backwards,
                                             mask=mask,
                                             constants=constants,
                                             unroll=self.layer.unroll,
                                             input_length=input_shape[1])
        if self.layer.stateful:
            self.updates = []
            for i in range(len(states)):
                self.updates.append((self.layer.states[i], states[i]))

        if self.layer.return_sequences:
            return outputs
        else:
            return last_output

# Tests like those in https://github.com/fchollet/keras/blob/master/tests/keras/layers/test_wrappers.py
import pytest
import numpy as np
from numpy.testing import assert_allclose
from keras.utils.test_utils import keras_test
from keras.layers import wrappers, Input, recurrent, InputLayer
from keras.layers import core, convolutional
from keras.models import Sequential, Model, model_from_json

nb_samples, timesteps, embedding_dim, output_dim = 2, 5, 3, 4
embedding_num = 12

x = np.random.random((nb_samples, timesteps, embedding_dim))
y = np.random.random((nb_samples, timesteps, output_dim))
# Baseline test with LSTM
model = Sequential()
model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=True, consume_less='mem')))
model.add(core.Activation('relu'))
model.compile(optimizer='rmsprop', loss='mse')
model.fit(x, y, nb_epoch=1, batch_size=nb_samples)

# Test stacked with all RNN layers and the supported consume_less options
model = Sequential()
model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
# consume_less='cpu' is not supported:
# model.add(Attention(recurrent.LSTM(embedding_dim, input_dim=embedding_dim, consume_less='cpu', return_sequences=True)))
model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, consume_less='gpu', return_sequences=True)))
model.add(Attention(recurrent.LSTM(embedding_dim, input_dim=embedding_dim, consume_less='mem', return_sequences=True)))
# Test the other RNN types
model.add(Attention(recurrent.GRU(embedding_dim, input_dim=embedding_dim, consume_less='mem', return_sequences=True)))
model.add(Attention(recurrent.SimpleRNN(embedding_dim, input_dim=embedding_dim, consume_less='mem', return_sequences=True)))
model.add(core.Activation('relu'))
model.compile(optimizer='rmsprop', loss='mse')
model.fit(x, y, nb_epoch=1, batch_size=nb_samples)

# Test with return_sequences=False
model = Sequential()
model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=False, consume_less='mem')))
model.add(core.Activation('relu'))
model.compile(optimizer='rmsprop', loss='mse')
model.fit(x, y[:, -1, :], nb_epoch=1, batch_size=nb_samples)

# With a bidirectional encoder
model = Sequential()
model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
model.add(wrappers.Bidirectional(recurrent.LSTM(embedding_dim, input_dim=embedding_dim, return_sequences=True)))
model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=True, consume_less='mem')))
model.add(core.Activation('relu'))
model.compile(optimizer='rmsprop', loss='mse')
model.fit(x, y, nb_epoch=1, batch_size=nb_samples)

# Test config
model.get_config()

# Test to and from json
model = model_from_json(model.to_json(), custom_objects=dict(Attention=Attention))
model.summary()

# Test with the functional API
input = Input(batch_shape=(nb_samples, timesteps, embedding_dim))
output = Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=True, consume_less='mem'))(input)
model = Model(input, output)
model.compile(optimizer='rmsprop', loss='mse')
model.fit(x, y, nb_epoch=1, batch_size=nb_samples)
@wassname @billhsia @thomasjungblut
It's because of the version of Keras in use. You can update the source code by referring to https://github.com/fchollet/keras/wiki/Keras-2.0-release-notes
However, I then get another error at this line:
x = K.dot(K.concatenate([x,Xa],1),self.W3)+self.b3
with the error:
"ValueError: Dimensions must be equal, but are 1800 and 1200 for 'attention_1/MatMul_5' (op: 'MatMul') with input shapes: [?,1800], [1200,600]."
The dimensions don't match. Could anyone help?
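For reference, the error message itself pins down the shapes: `W3` is built as `(input_dim*2, input_dim)`, reported here as `[1200, 600]`, so `input_dim` was 600 when the layer was built; but `K.concatenate([x, Xa], 1)` has 1800 columns, which means `x` reaches `step()` with 1200 features instead of 600. Below is a minimal numpy sketch of that shape bookkeeping; the concrete numbers are taken from the error message, and the reason `x` is twice as wide (perhaps the preceding layer outputs more features than the shape the wrapper was built with, or the Keras 2 port changed what `preprocess_input` returns) is only a guess.

```python
import numpy as np

input_dim = 600                            # implied by W3's reported shape (1200, 600)
x = np.zeros((4, input_dim))               # per-timestep input that step() expects
Xa = np.zeros((4, input_dim))              # attention-weighted context
W3 = np.zeros((input_dim * 2, input_dim))

out = np.concatenate([x, Xa], 1).dot(W3)   # fine: (4, 1200) x (1200, 600)

x_wide = np.zeros((4, 2 * input_dim))      # but if x actually arrives with 1200 features...
# np.concatenate([x_wide, Xa], 1).dot(W3)  # ...this is (4, 1800) x (1200, 600) and raises
#                                          # exactly the ValueError reported above
```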
Has anyone actually gotten this code running? I ran into exactly the same issues as the others.
Thanks for this implementation!
If somebody wants a much simpler and more compact implementation of the attention mechanism for RNNs, have a look at:
@billhsia sorry I didn't see this until now; the Keras version is in this requirements.txt,
Everyone, the latest version is here
@philipperemy that's a nice implementation. Is it really that simple, and what's the performance like? If you added some tests to the repo using example data instead of random data, I would definitely use it.
@wassname I think the pop index error still persists for TensorFlow.
@wassname @billhsia running into the same issue now, any hints?