Keras Attention Layer (tf 2.x)
import tensorflow as tf
from tensorflow.keras.layers import Layer


class Attention(Layer):
    """
    An attention layer compatible with Keras that uses TensorFlow 2.x operations.

    @example
    > from tensorflow.keras.models import Model
    > from tensorflow.keras.layers import Dense, Input
    > sentence_size = 19  # number of words in each sentence
    > vec_size = 64       # embedding size of each word, e.g. 64
    > text_input = Input(shape=(sentence_size, vec_size), dtype='float32', name="text_input")
    > attention = Attention()(text_input)
    > output = Dense(64, activation='softmax')(attention)
    > model = Model(inputs=text_input, outputs=output)
    """

    def build(self, input_shape):
        # One weight per feature and one bias per timestep, used to score each word.
        self.W = self.add_weight(name="att_weight", shape=(input_shape[-1], 1),
                                 initializer="random_normal")
        self.b = self.add_weight(name="att_bias", shape=(input_shape[1], 1),
                                 initializer="zeros")
        super(Attention, self).build(input_shape)

    def call(self, x):
        # Score each timestep: (batch, timesteps, features) -> (batch, timesteps)
        et = tf.squeeze(tf.tanh(tf.matmul(x, self.W) + self.b), axis=-1)
        # Normalize the scores into attention weights over the timesteps.
        at = tf.nn.softmax(et, axis=-1)
        at = tf.expand_dims(at, axis=-1)
        # Weighted sum of the timesteps: (batch, features)
        return tf.reduce_sum(x * at, axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        return super(Attention, self).get_config()
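
A minimal usage sketch, building the model from the docstring example and running a forward pass on random dummy data to check the output shape. The batch size, dummy data, optimizer, and loss below are illustrative assumptions, not part of the original gist.

import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input

sentence_size = 19   # number of words in each sentence
vec_size = 64        # embedding size of each word

text_input = Input(shape=(sentence_size, vec_size), dtype='float32', name="text_input")
attention = Attention()(text_input)                   # -> (batch, vec_size)
output = Dense(64, activation='softmax')(attention)   # -> (batch, 64)
model = Model(inputs=text_input, outputs=output)
model.compile(optimizer='adam', loss='categorical_crossentropy')

# Random dummy batch of 8 sentences, purely to verify shapes.
dummy_x = np.random.rand(8, sentence_size, vec_size).astype('float32')
print(model.predict(dummy_x).shape)                   # (8, 64)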