@Cospel
Last active April 19, 2020 14:03
finetune.py
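import tensorflow as tf
from tensorflow import keras

CLASSES = 10  # hypothetical placeholder: set to the number of classes in your dataset
IMG_SHAPE = (224, 224, 3)

# Load MobileNetV2 pretrained on ImageNet, without its classification head
base_model = tf.keras.applications.MobileNetV2(
    weights="imagenet", input_shape=IMG_SHAPE, include_top=False
)
# Freeze the base model so its weights are not updated in the first phase
base_model.trainable = False

inputs = keras.Input(shape=IMG_SHAPE)
# IMPORTANT: call the base model with training=False so its BatchNormalization
# layers run in inference mode and their moving statistics are not updated,
# even after the base model is unfrozen below
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(128, activation="relu")(x)  # new head, trained in both phases
outputs = keras.layers.Dense(CLASSES)(x)  # raw logits, no softmax
model = keras.Model(inputs, outputs)

# Phase 1: train only the new head on top of the frozen backbone.
# The loss assumes integer labels; use CategoricalCrossentropy for one-hot labels.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=keras.optimizers.Adam(1e-3), loss=loss)
model.fit(...)

# Phase 2: unfreeze the backbone and fine-tune the whole model with a much lower
# learning rate. Recompiling is required for the trainable change to take effect.
base_model.trainable = True
model.compile(optimizer=keras.optimizers.Adam(1e-5), loss=loss)
model.fit(...)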
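The `training=False` call is what keeps the backbone's BatchNormalization statistics frozen. As a rough illustration of why that matters, here is a minimal pure-Python sketch of a BatchNormalization layer's moving mean, assuming the update rule given in the Keras `BatchNormalization` documentation (`moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`, default `momentum=0.99`). `ToyBatchNorm` and the sample batch are hypothetical; variance, scale and offset are left out for brevity.

```python
from statistics import fmean


class ToyBatchNorm:
    """Hypothetical toy layer modeling only the moving mean (not the Keras implementation)."""

    def __init__(self, momentum: float = 0.99) -> None:
        self.momentum = momentum
        self.moving_mean = 0.0  # assumption: starts at zero, like Keras' default initializer

    def __call__(self, batch: list[float], training: bool) -> list[float]:
        if training:
            # Training mode: center with the batch mean and update the moving mean
            mean = fmean(batch)
            self.moving_mean = (
                self.moving_mean * self.momentum + mean * (1 - self.momentum)
            )
        else:
            # Inference mode: center with the stored moving mean; nothing is updated
            mean = self.moving_mean
        return [x - mean for x in batch]


bn = ToyBatchNorm()
new_domain_batch = [5.0, 7.0, 9.0]  # hypothetical data whose mean (7.0) differs from the stored one

bn(new_domain_batch, training=False)
print(bn.moving_mean)  # 0.0: the stored statistic is untouched

bn(new_domain_batch, training=True)
print(round(bn.moving_mean, 4))  # 0.07: one training-mode batch already shifts it
```

In training mode every batch from the new dataset nudges the stored statistics away from the ImageNet values the pretrained weights were tuned against; over many fine-tuning steps this drift can undo much of what the backbone learned, which is why the backbone is kept in inference mode in both phases.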