@rmdort
Forked from luthfianto/Attention.py
Created April 24, 2017 03:23
Keras Layer that implements an Attention mechanism for temporal data. Supports Masking. Follows the work of Raffel et al. [https://arxiv.org/abs/1512.08756]
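Concretely (notation mirrors the code below: W is the kernel, b the per-timestep bias, x_t the RNN output at step t, T the number of steps), each timestep gets a scalar score that is softmax-normalized into an attention weight, and the layer returns the attention-weighted sum of the timestep features:

e_t = \tanh(x_t^\top W + b_t), \qquad
\alpha_t = \frac{\exp(e_t)}{\sum_{k=1}^{T} \exp(e_k)}, \qquad
c = \sum_{t=1}^{T} \alpha_t \, x_t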
from keras.layers.core import Layer
from keras import initializers, regularizers, constraints
from keras import backend as K


class Attention(Layer):
    def __init__(self,
                 kernel_regularizer=None, bias_regularizer=None,
                 kernel_constraint=None, bias_constraint=None,
                 use_bias=True, **kwargs):
        """
        Keras Layer that implements an Attention mechanism for temporal data.
        Supports Masking.
        Follows the work of Raffel et al. [https://arxiv.org/abs/1512.08756]

        # Input shape
            3D tensor with shape: `(samples, steps, features)`.
        # Output shape
            2D tensor with shape: `(samples, features)`.
        :param kwargs:

        Just put it on top of an RNN layer (GRU/LSTM/SimpleRNN) with return_sequences=True.
        The dimensions are inferred from the output shape of the RNN.

        Example:
            model.add(LSTM(64, return_sequences=True))
            model.add(Attention())
        """
        self.supports_masking = True
        self.kernel_initializer = initializers.get('glorot_uniform')
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.use_bias = use_bias
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 3
        # One weight per feature, projecting each timestep's feature vector to a scalar score.
        self.kernel = self.add_weight(shape=(input_shape[-1], 1),
                                      initializer=self.kernel_initializer,
                                      name='{}_W'.format(self.name),
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            # One bias per timestep.
            self.bias = self.add_weight(shape=(input_shape[1],),
                                        initializer='zeros',
                                        name='{}_b'.format(self.name),
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.built = True

    def compute_mask(self, x, input_mask=None):
        # The output is a single vector per sample, so do not pass the mask to the next layer.
        return None

    def call(self, x, mask=None):
        # Score each timestep: eij has shape (samples, steps).
        eij = K.dot(x, self.kernel)
        eij = K.squeeze(eij, -1)
        if self.use_bias:
            eij += self.bias
        eij = K.tanh(eij)
        # Attention weights over the timesteps.
        a = K.softmax(eij)
        if mask is not None:
            # Zero out padded timesteps and renormalize so the weights still sum to one.
            a *= K.cast(mask, K.floatx())
            a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        # Weighted sum of the timestep features: shape (samples, features).
        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])
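For reference, a minimal usage sketch, assuming Keras 2; the vocabulary size, sequence length, layer widths, and binary classification head below are illustrative placeholders, not part of the original gist. With mask_zero=True on the Embedding, the padding mask reaches the Attention layer, so padded timesteps receive zero weight:

from keras.models import Sequential
from keras.layers import Embedding, LSTM, Dense

model = Sequential()
# Illustrative sizes: 20,000-word vocabulary, sequences padded to length 100.
model.add(Embedding(input_dim=20000, output_dim=128, input_length=100, mask_zero=True))
model.add(LSTM(64, return_sequences=True))   # return_sequences=True is required
model.add(Attention())                       # collapses (samples, 100, 64) to (samples, 64)
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()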