@stoney95
Created November 23, 2018 09:17
Attentive Convolution with custom Attention layer
from keras.layers import Lambda, Reshape, RepeatVector, Concatenate, Conv1D, Activation
from keras.layers import Layer
from keras import activations
from keras import backend as K
import tensorflow as tf


class Attention(Layer):
    def __init__(self, kernel_activation='hard_sigmoid', before=False, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.kernel_activation = activations.get(kernel_activation)
        K.set_floatx('float32')
        # If True, the layer returns the re-weighted sequence; otherwise it
        # returns a single attention-weighted vector per sample.
        self.before = before

    def build(self, input_shape):
        # input_shape is a list: [text_shape, context_shape]
        self.num_words = input_shape[0][1]
        super(Attention, self).build(input_shape)

    def get_output_shape_for(self, input_shape):
        if self.before:
            return input_shape
        return (input_shape[0], input_shape[2])

    def compute_output_shape(self, input_shape):
        input_shape = input_shape[0]
        if self.before:
            return input_shape
        return (input_shape[0], input_shape[2])

    def call(self, x, mask=None):
        text = x[0]
        context = x[1]

        # Flatten the context to one vector per sample.
        length = 1
        for i in range(1, len(context.shape)):
            length *= int(context.shape[i])
        context = Lambda(lambda c: Reshape((length,))(c))(context)

        # Repeat the context for every word, concatenate it with the text and
        # score each position with a 1x1 convolution followed by a softmax.
        context_repeated = RepeatVector(self.num_words)(context)
        merged = Concatenate(axis=2)([context_repeated, text])
        scores = Conv1D(1, 1)(merged)
        weights = Activation(activation='softmax')(scores)

        if not self.before:
            # Weighted sum over the word axis -> shape (batch, embedding_dim)
            weighted = K.batch_dot(K.permute_dimensions(text, (0, 2, 1)), weights)
            return K.squeeze(weighted, 2)
        # Re-weight every word vector, keeping the sequence structure intact.
        weighted = tf.multiply(text, weights)
        return weighted

    def get_config(self):
        config = {'kernel_activation': activations.serialize(self.kernel_activation),
                  'before': self.before}
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
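A minimal wiring sketch (not part of the gist; shapes are illustrative) showing how the layer is meant to be called on a pair of tensors, mirroring its use in _build_stage below:

from keras.layers import Input

text_input = Input(shape=(100, 300))     # hypothetical: 100 words, 300-dim features
query_input = Input(shape=(300,))        # hypothetical external context

attended_sequence = Attention(before=True)([text_input, text_input])   # self-attention, keeps the sequence
attended_vector = Attention(before=False)([text_input, query_input])   # query attention, one vector per sample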
from keras.layers import Flatten, Dropout, Dense, Lambda, ZeroPadding2D, MaxPooling2D, Concatenate, Conv2D
from keras import backend as K
from Core.NeuralNetwork.CustomLayers.Attention import Attention


def build_cnn(input, dropout, kernel_sizes, num_stages, num_filters, pool_sizes, attention=False, context='same', fully_connected_dimension=1000):
    '''
    This method builds a CNN with the given parameters.
    :param input: output of the preceding layer
    :param dropout: dropout rate
    :param kernel_sizes: list of lists; the inner list defines the kernel sizes used within a stage, the outer list defines the stages
    :param num_stages: number of stages
    :param num_filters: list; a different number of filters can be used per stage
    :param pool_sizes: list; a different pool size can be used per stage
    :param attention: defines whether attention should be applied (only in the first stage)
    :param context: defines the context. 'same' means self-attention; otherwise a layer output can be passed to apply query attention
    :param fully_connected_dimension: number of units of the fully connected layer on top of the convolutional stages
    :return: input & output of the encoder
    '''
    input_cpy = input
    for i in range(num_stages):
        if i > 0:
            # Attention is only applied in the first stage.
            attention = False
        input_cpy = _build_stage(kernel_sizes[i], input_cpy, num_filters[i], context, pool_sizes[i], attention)
    flatten = Flatten()(input_cpy)
    dropout = Dropout(dropout)(flatten)
    fully_connected = Dense(units=fully_connected_dimension)(dropout)
    return input, fully_connected


def _build_stage(kernel_sizes, pre_layer, num_filters, context, pool_size, attention):
    convs = []
    for size in kernel_sizes:
        # Add a channel axis and zero-pad so every kernel size yields the same output length.
        reshape = Lambda(lambda x: K.expand_dims(x, 3))(pre_layer)
        if size % 2 == 0:
            padded_input = ZeroPadding2D(padding=((int(size / 2), int(size / 2) - 1), (0, 0)))(reshape)
        else:
            padded_input = ZeroPadding2D(padding=((int(size / 2), int(size / 2)), (0, 0)))(reshape)
        conv = Conv2D(num_filters, (size, int(reshape.shape[2])), activation='relu', padding='valid')(padded_input)
        convs.append(conv)

    if len(convs) > 1:
        all_filters = Concatenate(axis=3)(convs)
    else:
        all_filters = convs[0]

    if attention:
        all_filters = Lambda(lambda x: K.squeeze(x, 2))(all_filters)
        if context == 'same':
            # Self-attention: the feature maps attend to themselves.
            attentive_context = Attention(before=True)([all_filters, all_filters])
        else:
            # Query attention: an external layer output provides the context.
            attentive_context = Attention(before=True)([all_filters, context])
        attentive_context = Lambda(lambda x: K.expand_dims(x, axis=2))(attentive_context)
        reshape = Lambda(lambda x: K.permute_dimensions(x, (0, 1, 3, 2)))(attentive_context)
    else:
        reshape = Lambda(lambda x: K.permute_dimensions(x, (0, 1, 3, 2)))(all_filters)

    # Max-pool over the word axis; clamp the pool size to the sequence length.
    if pool_size > int(reshape.shape[1]):
        pool_size = int(reshape.shape[1])
    filtered = MaxPooling2D(pool_size=(pool_size, 1))(reshape)
    reshape = Lambda(lambda x: K.squeeze(x, 3))(filtered)
    return reshape
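Below is a minimal usage sketch (hypothetical shapes and hyperparameters, not part of the gist) showing how build_cnn can be wired into a model:

from keras.layers import Input
from keras.models import Model

embedded = Input(shape=(100, 300))                   # hypothetical: 100 words, 300-dim embeddings
_, encoded = build_cnn(embedded,
                       dropout=0.5,
                       kernel_sizes=[[3, 4, 5]],     # one stage with three parallel kernel sizes
                       num_stages=1,
                       num_filters=[64],
                       pool_sizes=[2],
                       attention=True,
                       context='same',
                       fully_connected_dimension=256)
model = Model(inputs=embedded, outputs=encoded)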