VSB Power Line Blog - Raffel's Attention layer implementation
# https://keras.io/layers/writing-your-own-keras-layers/
from keras import backend as K
from keras import initializers, regularizers, constraints
from keras.layers import Layer


class Attention(Layer):
    """
    Performs a basic (feed-forward) attention layer operation.

    References
    ----------
    .. [1] https://arxiv.org/pdf/1512.08756.pdf
    .. [2] https://www.kaggle.com/qqgeogor/keras-lstm-attention-glove840b-lb-0-043
    .. [3] https://www.kaggle.com/tarunpaparaju/vsb-competition-attention-bilstm-with-features/notebook?scriptVersionId=10690570
    """
    def __init__(self, **kwargs):
        """
        Initialize the attention layer.

        The following parameters are initialized:
            * Weight initializer - 'glorot_uniform'
            * Weight regularizer
            * Bias regularizer
            * Weight constraints
            * Bias constraints
            * Initial bias
        """
        self.supports_masking = True
        self.init = initializers.get('glorot_uniform')
        # https://keras.io/regularizers/
        # Define the weight and bias regularizers
        # (regularizers.get(None) returns None, i.e. no regularization)
        self.W_regularizer = regularizers.get(None)
        self.b_regularizer = regularizers.get(None)
        # https://keras.io/constraints/
        # Define the weight and bias constraints
        # Constraints => keep a check on the weight and bias values
        # (constraints.get(None) returns None, i.e. no constraint)
        self.W_constraint = constraints.get(None)
        self.b_constraint = constraints.get(None)
        self.bias = True
        # The number of timesteps (step_dim) is only known once the input
        # shape is available, so both dimensions are set properly in build()
        self.step_dim = 0
        self.features_dim = 0
        super(Attention, self).__init__(**kwargs)
    def build(self, input_shape):
        """
        Build the attention layer weights (W and, optionally, b).
        """
        # Expect a 3D input: (batch_size, timesteps, features)
        assert len(input_shape) == 3
        self.step_dim = input_shape[1]
        self.features_dim = input_shape[-1]
        # add_weight() comes from keras.layers.Layer
        self.W = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight(shape=(input_shape[1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None
        self.built = True
    def compute_mask(self, input, input_mask=None):
        """
        Do not pass the mask to the next layer.
        """
        return None
    def call(self, x, mask=None):
        """
        Performs the attention mechanism.
        """
        features_dim = self.features_dim
        step_dim = self.step_dim
        # https://keras.io/backend/#reshape
        # K.reshape(x, shape)
        #   x     -> tensor or variable to be reshaped
        #   shape -> target shape
        # K.reshape(x, (-1, cols)) reshapes x to the given number of columns;
        # the number of rows is adjusted accordingly.
        # Take the dot product of (x, self.W) and reshape it so that it has
        # step_dim columns (rows adjusted accordingly).
        eij = K.reshape(K.dot(K.reshape(x, (-1, features_dim)),
                              K.reshape(self.W, (features_dim, 1))),
                        (-1, step_dim))
        if self.bias:
            eij += self.b
        eij = K.tanh(eij)
        # https://www.analyticsvidhya.com/blog/2019/11/comprehensive-guide-attention-mechanism-deep-learning/
        a = K.exp(eij)
        if mask is not None:
            # Cast the mask to floats (float32 by default) and zero out the
            # scores of the masked timesteps
            a *= K.cast(mask, K.floatx())
        # Normalize the scores (softmax over the time axis)
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)
    def compute_output_shape(self, input_shape):
        """
        Compute the shape of the output: (batch_size, features_dim).
        """
        return input_shape[0], self.features_dim
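
For context, here is a minimal usage sketch showing how the layer could be plugged into a Bidirectional LSTM model, in the spirit of reference [3]; MAX_LEN, N_FEATURES and the layer sizes below are illustrative placeholders, not values from this gist.

# Hypothetical usage sketch (MAX_LEN, N_FEATURES and layer sizes are placeholders)
from keras.models import Model
from keras.layers import Input, Dense, Bidirectional, LSTM

MAX_LEN, N_FEATURES = 160, 19   # placeholder input dimensions

inp = Input(shape=(MAX_LEN, N_FEATURES))
x = Bidirectional(LSTM(64, return_sequences=True))(inp)  # (batch, MAX_LEN, 128)
x = Attention()(x)                                        # (batch, 128)
out = Dense(1, activation='sigmoid')(x)

model = Model(inputs=inp, outputs=out)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()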