@suhaskv
Created January 6, 2021 08:26
VSB Power Line Blog - Raffel's Attention layer implementation
from keras import backend as K
from keras import initializers, regularizers, constraints
from keras.layers import Layer


# https://keras.io/layers/writing-your-own-keras-layers/
class Attention(Layer):
    """
    Performs basic attention layer operation.

    References
    ----------
    .. [1] https://arxiv.org/pdf/1512.08756.pdf
    .. [2] https://www.kaggle.com/qqgeogor/keras-lstm-attention-glove840b-lb-0-043
    .. [3] https://www.kaggle.com/tarunpaparaju/vsb-competition-attention-bilstm-with-features/notebook?scriptVersionId=10690570
    """
    def __init__(self, step_dim, **kwargs):
        """
        Initialize the attention layer.

        Following parameters are initialized:
            * Weight initializer - 'glorot_uniform'
            * Weight regularizer
            * Bias regularizer
            * Weight constraints
            * Bias constraints
            * Initial bias

        Parameters
        ----------
        step_dim : int
            Number of timesteps in the input sequence; passed in explicitly
            because input_shape is not available inside __init__().
        """
        self.supports_masking = True
        self.init = initializers.get('glorot_uniform')
        # https://keras.io/regularizers/
        # Define weight and bias regularizer
        self.W_regularizer = regularizers.get(None)
        self.b_regularizer = regularizers.get(None)
        # https://keras.io/constraints/
        # Define weight and bias constraints
        # Constraints => to keep check on the weight and bias values
        self.W_constraint = constraints.get(None)
        self.b_constraint = constraints.get(None)
        self.bias = True
        self.step_dim = step_dim
        self.features_dim = 0
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        """
        Build the Attention Layer.
        """
        assert len(input_shape) == 3
        # add_weight() is inherited from the base keras Layer class
        self.W = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.init,
                                 name="{}_W".format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        self.features_dim = input_shape[-1]
        if self.bias:
            self.b = self.add_weight(shape=(input_shape[1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None
        self.built = True

    def compute_mask(self, input, input_mask=None):
        """
        Do not pass the mask to the next layer.
        """
        return None

    def call(self, x, mask=None):
        """
        Performs attention mechanism.
        """
        features_dim = self.features_dim
        step_dim = self.step_dim
        # https://keras.io/backend/#reshape
        # K.reshape(x, shape)
        #   x -> tensor or variable to be reshaped
        #   shape -> target shape
        # K.reshape(x, (-1, cols)) -> reshapes x to the given number of columns,
        # with the number of rows adjusted accordingly
        # Take the dot product of (x, self.W) and reshape it to have step_dim
        # columns; the number of rows is adjusted accordingly
        eij = K.reshape(K.dot(K.reshape(x, (-1, features_dim)),
                              K.reshape(self.W, (features_dim, 1))),
                        (-1, step_dim))
        if self.bias:
            eij += self.b
        eij = K.tanh(eij)
        # https://www.analyticsvidhya.com/blog/2019/11/comprehensive-guide-attention-mechanism-deep-learning/
        a = K.exp(eij)
        if mask is not None:
            # typecast mask to the backend float type (float32 by default)
            a *= K.cast(mask, K.floatx())
        # Perform softmax operation
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        """
        Compute the shape of the output.
        """
        return input_shape[0], self.features_dim
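
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original gist): one plausible way to drop the
# layer into a Keras model, assuming fixed-length sequences. The input shape
# (160 timesteps x 128 features), the LSTM width and the sigmoid output are
# illustrative placeholders, not values taken from the blog.
# The layer computes e_t = tanh(W . h_t + b), a_t = masked softmax(e_t), and
# returns the weighted sum over timesteps, sum_t a_t * h_t.
# ---------------------------------------------------------------------------
from keras.layers import Input, Bidirectional, LSTM, Dense
from keras.models import Model

inp = Input(shape=(160, 128))                              # (timesteps, features)
seq = Bidirectional(LSTM(64, return_sequences=True))(inp)  # keep per-timestep outputs
att = Attention(step_dim=160)(seq)                         # collapse timesteps to a weighted sum
out = Dense(1, activation='sigmoid')(att)

model = Model(inputs=inp, outputs=out)
model.compile(optimizer='adam', loss='binary_crossentropy')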