@KerryHalupka
Last active August 16, 2020 09:48
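import tensorflow as tf

learning_rate = 0.0001
# input image side length in pixels
# NOTE: assumed value; im_size is not defined in the original snippet
im_size = 224
# get just the feature extraction layers of mobilenet
base_model = tf.keras.applications.MobileNetV2(
    include_top=False, weights='imagenet', input_shape=(im_size, im_size, 3))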
# freeze the feature extractor convolutional layers
base_model.trainable = False
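# define a classification head on top: pool the feature maps to one vector,
# then a single-unit dense layer (no activation, so it outputs a raw logit)
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
prediction_layer = tf.keras.layers.Dense(1)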
# fit it all together with some dropout
inputs = tf.keras.Input(shape=(im_size, im_size, 3))
# training=False keeps the base model's BatchNormalization layers in inference mode
x = base_model(inputs, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
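
Because `Dense(1)` is created without an activation, the model returns a raw logit rather than a probability. Assuming this head is meant for binary classification (the snippet does not say so), a matching loss would be something like `tf.keras.losses.BinaryCrossentropy(from_logits=True)`. The sketch below uses only the standard library to show what that implies numerically; the helper names `sigmoid` and `bce_from_logit` are hypothetical and not part of the gist.

```python
import math

def sigmoid(logit: float) -> float:
    """Map a raw logit to a probability in (0, 1) without overflowing."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    e = math.exp(logit)
    return e / (1.0 + e)

def bce_from_logit(logit: float, label: int) -> float:
    """Binary cross-entropy computed directly from a logit.

    Uses the stable form max(z, 0) - z*y + log(1 + exp(-|z|)),
    which equals -[y*log(p) + (1-y)*log(1-p)] with p = sigmoid(z).
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

# A few example logits, as the model might emit for single images
for z in (-2.0, 0.0, 3.0):
    print(f"logit={z:+.1f}  p(class 1)={sigmoid(z):.3f}  "
          f"loss if label=1: {bce_from_logit(z, 1):.3f}")
```

A logit of 0 corresponds to a probability of 0.5, so thresholding the model output at 0 is the same as thresholding the probability at 0.5.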