@NMZivkovic
Created November 10, 2019 20:28
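
import tensorflow as tf


class Wrapper(tf.keras.Model):
    """Adds a global-average-pooling + single-unit Dense head on top of a base model."""

    def __init__(self, base_model):
        super(Wrapper, self).__init__()
        self.base_model = base_model  # feature extractor, e.g. a pretrained CNN
        self.average_pooling_layer = tf.keras.layers.GlobalAveragePooling2D()
        self.output_layer = tf.keras.layers.Dense(1)  # one unit, no activation (raw logit)

    def call(self, inputs):
        x = self.base_model(inputs)  # 4D feature map: (batch, height, width, channels)
        x = self.average_pooling_layer(x)  # average over spatial dims -> (batch, channels)
        output = self.output_layer(x)  # (batch, 1)
        return output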
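For comparison, the same head (base model, then global average pooling, then a one-unit Dense layer) can also be built with the Keras Functional API. The sketch below is self-contained and makes a few assumptions not in the gist: the small two-layer convolutional `base_model` is a hypothetical stand-in for a real pretrained backbone, the 32×32×3 input size is arbitrary, and the binary-classification loss is only one plausible use of a single linear output.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for a pretrained convolutional base (for example a
# Keras application built with include_top=False); it outputs a 4D feature map.
base_inputs = tf.keras.Input(shape=(32, 32, 3))
features = tf.keras.layers.Conv2D(8, 3, activation="relu")(base_inputs)
features = tf.keras.layers.Conv2D(16, 3, activation="relu")(features)
base_model = tf.keras.Model(base_inputs, features)

# Functional-API version of the head: base -> global average pooling -> Dense(1).
inputs = tf.keras.Input(shape=(32, 32, 3))
x = base_model(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

# Dense(1) has no activation, so it emits logits; for binary classification
# (an assumed use case) the loss should be told to expect logits.
model.compile(optimizer="adam",
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))

# Random images, only to check that a batch flows through the model.
images = np.random.rand(4, 32, 32, 3).astype("float32")
predictions = model.predict(images, verbose=0)
print(predictions.shape)  # (4, 1)
```

The subclassed `Wrapper` is more flexible if `call` needs custom logic, while the functional form gives a static graph that `model.summary()` can describe before the model is called on data.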