@naiborhujosua
Created August 4, 2022 12:58
EfficientNetB0
import tensorflow as tf

# 1. Create the base model with tf.keras.applications
base_model = tf.keras.applications.EfficientNetB0(include_top=False)

# 2. Freeze the base model (so the pre-learned patterns remain)
base_model.trainable = False

# 3. Create inputs into the base model
inputs = tf.keras.layers.Input(shape=(224, 224, 3), name="input_layer")

# 4. If using ResNet50V2, add this to speed up convergence; remove for EfficientNet
# x = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)(inputs)

# 5. Pass the inputs to the base_model (note: with tf.keras.applications,
# EfficientNet inputs don't have to be normalized)
x = base_model(inputs)

# Check the data shape after passing it through the base_model
print(f"Shape after base_model: {x.shape}")

# 6. Average pool the outputs of the base model (aggregate the most important
# information, reduce the number of computations)
x = tf.keras.layers.GlobalAveragePooling2D(name="global_average_pooling_layer")(x)
print(f"After GlobalAveragePooling2D(): {x.shape}")

# 7. Create the output activation layer (6 classes, softmax for multi-class probabilities)
outputs = tf.keras.layers.Dense(6, activation="softmax", name="output_layer")(x)

# 8. Combine the inputs and outputs into a model
model_2 = tf.keras.Model(inputs, outputs)

# 9. Compile the model (categorical_crossentropy expects one-hot encoded labels)
model_2.compile(loss="categorical_crossentropy",
                optimizer=tf.keras.optimizers.Adam(),
                metrics=["accuracy"])

# 10. Fit the model
history_2 = model_2.fit(train_data,
                        epochs=10,
                        steps_per_epoch=len(train_data),
                        validation_data=val_data,
                        # Track model training logs
                        callbacks=[create_tensorboard_callback("CNN_learning",
                                                               "transfer_learning_efficientNet_0")])
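The snippet above assumes `train_data`, `val_data`, and `create_tensorboard_callback` are defined elsewhere in the notebook. A minimal sketch of what they might look like (the helper's exact behavior, directory paths, and batch size are assumptions, not taken from this gist):

```python
import datetime
import tensorflow as tf

def create_tensorboard_callback(dir_name, experiment_name):
    # Log to a timestamped subdirectory so repeated experiments don't overwrite each other
    log_dir = f"{dir_name}/{experiment_name}/" + \
              datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    print(f"Saving TensorBoard log files to: {log_dir}")
    return tf.keras.callbacks.TensorBoard(log_dir=log_dir)

# Hypothetical data pipeline (paths are assumptions; expects one subfolder per class).
# label_mode="categorical" yields one-hot labels, matching categorical_crossentropy,
# and image_size=(224, 224) matches the model's input layer:
# train_data = tf.keras.preprocessing.image_dataset_from_directory(
#     "train/", image_size=(224, 224), label_mode="categorical", batch_size=32)
# val_data = tf.keras.preprocessing.image_dataset_from_directory(
#     "test/", image_size=(224, 224), label_mode="categorical", batch_size=32)
```

If the labels are integer-encoded instead of one-hot, swap the loss for `sparse_categorical_crossentropy` (or set `label_mode="categorical"` as above).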