@novasush
Created March 7, 2020 17:47
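import tensorflow as tf


def create_model():
    # 224x224 RGB input, the default resolution for ImageNet-pretrained MobileNetV2
    input_layer = tf.keras.layers.Input(shape=(224, 224, 3))
    # Pretrained convolutional backbone without its ImageNet classification head
    base_model = tf.keras.applications.MobileNetV2(input_tensor=input_layer,
                                                   weights='imagenet',
                                                   include_top=False)
    # Freeze the backbone so only the new head is trained
    base_model.trainable = False
    # Collapse the spatial feature map into one feature vector per image
    x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
    # Two-class softmax classification head
    x = tf.keras.layers.Dense(2, activation='softmax')(x)
    model = tf.keras.models.Model(inputs=input_layer, outputs=x)
    # Labels are integer class indices (0 or 1), hence the sparse loss
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['acc'])
    return model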
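The trainable part of this model is only the small head stacked on the frozen backbone. The following is a minimal NumPy sketch of what that head computes, not the Keras implementation itself. It assumes the 7x7x1280 feature map that MobileNetV2 produces for a 224x224 input with `include_top=False`, and it uses random placeholder arrays in place of real backbone features and trained Dense weights. The helper name `pooled_softmax_head` is hypothetical.

```python
import numpy as np


def pooled_softmax_head(feature_map, kernel, bias):
    """Mimic GlobalAveragePooling2D followed by Dense(2, activation='softmax').

    feature_map: (batch, height, width, channels) backbone output
    kernel:      (channels, num_classes) Dense weights
    bias:        (num_classes,) Dense bias
    """
    # GlobalAveragePooling2D: mean over the two spatial axes -> (batch, channels)
    pooled = feature_map.mean(axis=(1, 2))
    # Dense layer: affine map to one logit per class -> (batch, num_classes)
    logits = pooled @ kernel + bias
    # Numerically stable softmax over the class axis
    logits = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)


# Placeholder inputs: random values standing in for real backbone features
# and trained weights (hypothetical, for shape illustration only).
rng = np.random.default_rng(0)
features = rng.standard_normal((4, 7, 7, 1280))
kernel = rng.standard_normal((1280, 2)) * 0.01
bias = np.zeros(2)

probs = pooled_softmax_head(features, kernel, bias)
print(probs.shape)  # (4, 2): one probability pair per image
```

Each row of `probs` sums to 1, and `np.argmax(probs, axis=1)` gives the integer class index. That integer index is the same label format that `sparse_categorical_crossentropy` expects during training.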