Skip to content

Instantly share code, notes, and snippets.

@MSWon
Last active April 16, 2019 12:55
Show Gist options
Save MSWon/7d39c4c544af0446a79185d00adece19 to your computer and use it in GitHub Desktop.
Channelwise_SelfAttention
# -*- coding: utf-8 -*-
"""
Created on Sat Apr 6 10:47:30 2019
@author: jbk48
"""
from keras.layers import Layer
import tensorflow as tf
class SelfAttention(Layer):
    """Self-attention over the spatial positions of a 4-D feature map.

    The (size1, size2) grid is flattened into size1*size2 positions, each a
    vector of length ``channel``. Every position is passed through a learned
    affine projection (kernel + bias). The projection is dot-product scored
    against every original position vector, the scores are softmax-normalised
    over the last axis, and the result is used to take a weighted sum of the
    original position vectors. The output has the same shape as the input.

    Note: the score matrix is (size1*size2) x (size1*size2), so the attention
    weights are over spatial positions, not over channels.

    Args:
        initializer: Initializer for the (channel, channel) projection kernel.
            ``None`` (the default) uses
            ``tf.contrib.layers.xavier_initializer()``.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, initializer=None, **kwargs):
        if initializer is None:
            # Resolved here instead of in the signature so the initializer is
            # not constructed at import time; same default as before.
            initializer = tf.contrib.layers.xavier_initializer()
        self.initializer = initializer
        super(SelfAttention, self).__init__(**kwargs)

    @staticmethod
    def _static_dim(dim):
        """Return ``dim`` as a Python int, or None if it is unknown.

        Accepts plain ints/None as well as objects exposing ``.value``
        (e.g. TF1 ``Dimension``).
        """
        value = getattr(dim, 'value', dim)
        return None if value is None else int(value)

    def build(self, input_shape):
        """Record the input geometry and create the projection weights.

        Args:
            input_shape: (batch_size, size1, size2, channel). ``channel`` must
                be known because it sizes the kernel; ``size1``/``size2`` may
                be None, in which case they are read from the tensor in call().

        Raises:
            ValueError: if the channel dimension is unknown.
        """
        self.channel = self._static_dim(input_shape[-1])
        self.size1 = self._static_dim(input_shape[1])
        self.size2 = self._static_dim(input_shape[2])
        if self.channel is None:
            raise ValueError(
                'SelfAttention requires a known channel dimension, '
                'got input_shape=%s' % (input_shape,))

        user_initializer = self.initializer

        def kernel_initializer(shape, dtype=None, **kwargs):
            # Keras may pass dtype as a string such as 'float32' (or extra
            # keyword arguments); TF-style initializers like
            # xavier_initializer expect a tf.DType.
            dtype = tf.float32 if dtype is None else tf.as_dtype(dtype)
            return user_initializer(shape, dtype=dtype)

        # Registered with add_weight so Keras trains and saves them; the
        # previous in-call tf.layers.dense created untracked variables.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(self.channel, self.channel),
                                      initializer=kernel_initializer,
                                      trainable=True)
        # Zero-initialised bias, matching tf.layers.dense's default.
        self.bias = self.add_weight(name='bias',
                                    shape=(self.channel,),
                                    initializer='zeros',
                                    trainable=True)
        super(SelfAttention, self).build(input_shape)

    def call(self, inputs):
        """Apply position-wise self-attention.

        Args:
            inputs: tensor of shape (batch_size, size1, size2, channel).

        Returns:
            Tensor of shape (batch_size, size1, size2, channel).
        """
        # Prefer static sizes (keeps static shape info); fall back to the
        # runtime shape when a spatial dimension was unknown at build time.
        dynamic_shape = tf.shape(inputs)
        size1 = self.size1 if self.size1 is not None else dynamic_shape[1]
        size2 = self.size2 if self.size2 is not None else dynamic_shape[2]

        # (batch_size, size1*size2, channel)
        positions = tf.reshape(inputs, [-1, size1 * size2, self.channel])
        # Affine projection of every position: (batch_size, size1*size2, channel)
        projected = tf.nn.bias_add(
            tf.tensordot(positions, self.kernel, axes=1), self.bias)
        # Pairwise scores: (batch_size, size1*size2, size1*size2)
        scores = tf.matmul(projected, positions, transpose_b=True)
        attention = tf.nn.softmax(scores, axis=-1)
        # Weighted sum of positions: (batch_size, size1*size2, channel)
        attended = tf.matmul(attention, positions)
        return tf.reshape(attended, [-1, size1, size2, self.channel])

    def compute_output_shape(self, input_shape):
        """Output shape equals input shape."""
        return input_shape
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment