Skip to content

Instantly share code, notes, and snippets.

@ucalyptus
Created February 6, 2020 05:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ucalyptus/a455f1101dc3d5b95d0632c2907af549 to your computer and use it in GitHub Desktop.
Save ucalyptus/a455f1101dc3d5b95d0632c2907af549 to your computer and use it in GitHub Desktop.
class PAM(Layer):
    """Position attention module: spatial self-attention with a learnable residual scale.

    The input is a 4-D, channels-last feature map ``(batch, h, w, filters)``.
    Three bias-free 1x1 convolutions (``he_normal`` kernels) project it into:

    * B and C, each with ``max(1, filters // 8)`` channels;
    * D, with ``filters`` channels.

    The spatial axes are flattened to ``n = h * w`` positions. The layer then
    computes the ``(n, n)`` attention map ``softmax(B @ C^T)``, with the softmax
    taken over the last axis, and applies it to D. The result is reshaped back
    to the input shape, and the layer returns::

        gamma * attention_output + input

    where ``gamma`` is a single learnable scalar weight.

    Args:
        gamma_initializer: initializer for ``gamma``. With the default zeros
            initializer the layer returns its input unchanged until ``gamma``
            is trained away from 0.
        gamma_regularizer: optional regularizer applied to ``gamma``.
        gamma_constraint: optional constraint applied to ``gamma``.
        **kwargs: forwarded to the base ``Layer`` constructor.
    """

    def __init__(self,
                 gamma_initializer=tf.zeros_initializer(),
                 gamma_regularizer=None,
                 gamma_constraint=None,
                 **kwargs):
        super(PAM, self).__init__(**kwargs)
        # Normalise identifiers (e.g. 'zeros' or a serialised config dict) into
        # objects so get_config() can serialise them. Callables and None pass
        # through unchanged.
        self.gamma_initializer = tf.keras.initializers.get(gamma_initializer)
        self.gamma_regularizer = tf.keras.regularizers.get(gamma_regularizer)
        self.gamma_constraint = tf.keras.constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Create ``gamma`` and the three 1x1 projection convolutions.

        Raises:
            ValueError: if the channel (last) dimension of ``input_shape`` is unknown.
        """
        filters = tf.TensorShape(input_shape).as_list()[-1]
        if filters is None:
            raise ValueError('PAM requires a known channel (last) dimension, '
                             'got input_shape=%s' % (input_shape,))
        # Keep at least one query/key channel, so that inputs with fewer than
        # 8 channels do not get zero-width projections.
        reduced = max(1, filters // 8)

        self.gamma = self.add_weight(shape=(1, ),
                                     initializer=self.gamma_initializer,
                                     name='gamma',
                                     regularizer=self.gamma_regularizer,
                                     constraint=self.gamma_constraint)

        # Create the projections once, here, rather than inside call(). That
        # way their kernels persist, are shared across calls, and are owned by
        # this layer so the optimizer trains them. Ownership relies on Keras
        # auto-tracking of sublayers assigned as attributes (tf.keras / Keras 3).
        # NOTE(review): standalone Keras 2.x does not auto-track sublayers;
        # confirm which Keras package this file imports.
        self.conv_b = Conv2D(reduced, 1, use_bias=False,
                             kernel_initializer='he_normal')
        self.conv_c = Conv2D(reduced, 1, use_bias=False,
                             kernel_initializer='he_normal')
        self.conv_d = Conv2D(filters, 1, use_bias=False,
                             kernel_initializer='he_normal')
        # Build the projections eagerly so all weights exist right after build().
        for conv in (self.conv_b, self.conv_c, self.conv_d):
            conv.build(input_shape)

        super(PAM, self).build(input_shape)

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def call(self, input):
        """Apply position attention to ``input`` and add the gamma-scaled residual."""
        shape = tf.shape(input)
        # Number of spatial positions n = h * w. It is computed dynamically so
        # that inputs with unknown spatial dimensions are supported.
        n = shape[1] * shape[2]

        vec_b = tf.reshape(self.conv_b(input), [-1, n, self.conv_b.filters])
        vec_c = tf.reshape(self.conv_c(input), [-1, n, self.conv_c.filters])
        vec_d = tf.reshape(self.conv_d(input), [-1, n, self.conv_d.filters])

        # (batch, n, n) affinities between every pair of spatial positions,
        # normalised with a softmax over the last axis.
        energy = tf.matmul(vec_b, vec_c, transpose_b=True)
        attention = tf.nn.softmax(energy)

        # (batch, n, filters) attention-weighted values, folded back to the
        # input shape. conv_d has exactly `filters` channels, so the reshape is exact.
        attended = tf.reshape(tf.matmul(attention, vec_d), shape)
        return self.gamma * attended + input

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created from config."""
        config = super(PAM, self).get_config()
        config.update({
            'gamma_initializer': tf.keras.initializers.serialize(self.gamma_initializer),
            'gamma_regularizer': tf.keras.regularizers.serialize(self.gamma_regularizer),
            'gamma_constraint': tf.keras.constraints.serialize(self.gamma_constraint),
        })
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment