@saurabhpal97
Created May 30, 2019 10:21
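# Assumes tf.keras; standalone Keras 2.x exposes the same names under `keras`
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Permute, Reshape, multiply


def squeeze_excite_block(input_tensor, ratio=16):
    """Squeeze-and-excitation block: reweights the channels of a 4D feature map."""
    init = input_tensor
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    # Channel count; `_keras_shape` no longer exists in tf.keras, so read the static shape
    filters = init.shape[channel_axis]
    se_shape = (1, 1, filters)
    # Squeeze: one descriptor per channel via global average pooling
    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    # Excitation: bottleneck of filters // ratio units, then a sigmoid gate per channel
    se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)
    # For channels_first inputs, move the gate's channel axis to position 1
    if K.image_data_format() == 'channels_first':
        se = Permute((3, 1, 2))(se)
    # Scale: multiply each input channel by its gate
    x = multiply([init, se])
    return x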
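To make the three stages (squeeze, excitation, scale) concrete without needing TensorFlow, here is a minimal NumPy sketch of the same forward pass for a channels_last batch. It is an illustration under stated assumptions, not the gist's code: the helpers `se_forward` and `init_weights` are hypothetical names, the weights are random rather than trained, and Keras' truncated He-normal initializer is approximated with a plain normal draw. Applying a bias-free `Dense` layer to the reshaped `(1, 1, C)` descriptor is equivalent to a matrix product on the `(N, C)` pooled vector, which is what the sketch uses.

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def se_forward(x, w1, w2):
    """Hypothetical NumPy SE forward pass for channels_last input.

    x:  (N, H, W, C) feature maps
    w1: (C, C // ratio) bottleneck weights (no bias, like use_bias=False)
    w2: (C // ratio, C) expansion weights (no bias)
    """
    # Squeeze: global average pooling over the spatial axes -> (N, C)
    z = x.mean(axis=(1, 2))
    # Excitation: ReLU bottleneck, then a sigmoid gate per channel -> (N, C)
    s = sigmoid(np.maximum(z @ w1, 0.0) @ w2)
    # Scale: broadcast the gate over H and W and rescale each channel
    return x * s[:, None, None, :], s


def init_weights(channels, ratio=16, seed=0):
    """Hypothetical helper: plain normal draws with std sqrt(2 / fan_in),
    an approximation of Keras' truncated 'he_normal'."""
    rng = np.random.default_rng(seed)
    hidden = channels // ratio
    w1 = rng.normal(0.0, np.sqrt(2.0 / channels), size=(channels, hidden))
    w2 = rng.normal(0.0, np.sqrt(2.0 / hidden), size=(hidden, channels))
    return w1, w2


if __name__ == "__main__":
    x = np.random.default_rng(1).standard_normal((2, 8, 8, 64))
    w1, w2 = init_weights(64, ratio=16)
    y, gate = se_forward(x, w1, w2)
    print(y.shape, gate.shape)  # (2, 8, 8, 64) (2, 64)
    print(w1.size + w2.size)  # 512
```

The output has the same shape as the input, so the block can be dropped after any convolution. Because both `Dense` layers are bias-free, the block adds 2 × C × (C // ratio) weights; for C = 64 and ratio = 16 that is 2 × 64 × 4 = 512.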