[Medium] How to Choose the Best Keras Pre-Trained Model for Image Classification

All the code snippets for the Medium article (link).

# Imports
import tensorflow as tf
import tensorflow_datasets as tfds
import pandas as pd
import matplotlib.pyplot as plt
import inspect
from tqdm import tqdm
# Set batch size for training and validation
batch_size = 32

# Map every function exposed by tf.keras.applications to its name
# (name -> model constructor)
model_dictionary = dict(inspect.getmembers(tf.keras.applications, inspect.isfunction))

# Download the data: the first 70% of the 'train' split is used for training
# and the remaining 30% for validation
(train, validation), metadata = tfds.load(
    'cats_vs_dogs',
    split=['train[:70%]', 'train[70%:]'],
    with_info=True,
    as_supervised=True,
)


def _count_examples(dataset):
    """Return the number of elements in *dataset*.

    Uses the cardinality TensorFlow already knows about when available, which
    avoids decoding every image. Falls back to a streaming count (one element
    in memory at a time) when the cardinality is reported as unknown.
    """
    cardinality = int(tf.data.experimental.cardinality(dataset).numpy())
    if cardinality >= 0:
        return cardinality
    # Negative values are TensorFlow's "unknown"/"infinite" sentinels
    return sum(1 for _ in dataset)


# Number of training examples and labels
num_train = _count_examples(train)
num_validation = _count_examples(validation)
num_classes = len(metadata.features['label'].names)
# Whole batches per epoch; the last partial batch is not counted
num_iterations = num_train // batch_size

# Print important info
print(f'Num train images: {num_train} \
\nNum validation images: {num_validation} \
\nNum classes: {num_classes} \
\nNum iterations per epoch: {num_iterations}')
def normalize_img(image, label, img_size):
    """Resize and rescale one image and one-hot encode its label.

    Args:
        image: Image tensor handed to ``tf.image.resize``.
        label: Class index handed to ``tf.one_hot``.
        img_size: Target size handed to ``tf.image.resize``.

    Returns:
        A ``(image, label)`` tuple: the resized image cast to float32 and
        divided by 255, and the label one-hot encoded with depth
        ``num_classes`` (module-level value).
    """
    resized = tf.image.resize(image, img_size)
    # Presumably the raw pixels are in 0..255, so this maps them to 0..1
    scaled = tf.cast(resized, tf.float32) / 255.
    one_hot_label = tf.one_hot(label, depth=num_classes)
    return scaled, one_hot_label
def preprocess_data(train, validation, batch_size, img_size):
    """Apply normalize_img to both datasets and batch them.

    Args:
        train: Dataset of (image, label) pairs used for training.
        validation: Dataset of (image, label) pairs used for validation.
        batch_size: Number of examples per batch.
        img_size: Target image size forwarded to normalize_img.

    Returns:
        A ``(train_processed, validation_processed)`` tuple. The training
        dataset is batched and repeated indefinitely, so the consumer must
        bound each epoch itself; the validation dataset is batched only.
    """
    # Apply the normalize_img function on all train and validation data and create batches
    train_processed = train.map(lambda image, label: normalize_img(image, label, img_size))
    # If your data is already batched (eg, when using the image_dataset_from_directory function), remove .batch(batch_size)
    train_processed = train_processed.batch(batch_size).repeat()

    validation_processed = validation.map(lambda image, label: normalize_img(image, label, img_size))
    # If your data is already batched (eg, when using the image_dataset_from_directory function), remove .batch(batch_size)
    validation_processed = validation_processed.batch(batch_size)
    return train_processed, validation_processed
# Build the input pipelines once per image resolution:
# a 224x224 pair and a 331x331 pair
train_processed_224, validation_processed_224 = preprocess_data(
    train, validation, batch_size, img_size=[224, 224]
)
train_processed_331, validation_processed_331 = preprocess_data(
    train, validation, batch_size, img_size=[331, 331]
)
# Loop over each model available in Keras and record, per model, its name,
# parameter count and validation accuracy after the last epoch
model_benchmarks = {'model_name': [], 'num_model_params': [], 'validation_accuracy': []}
for model_name, model in tqdm(model_dictionary.items()):
    # Special handling for "NASNetLarge" since it requires input images with size (331,331)
    if 'NASNetLarge' in model_name:
        input_shape = (331, 331, 3)
        train_processed = train_processed_331
        validation_processed = validation_processed_331
    else:
        input_shape = (224, 224, 3)
        train_processed = train_processed_224
        validation_processed = validation_processed_224

    # load the pre-trained model with global average pooling as the last layer and freeze the model weights
    pre_trained_model = model(include_top=False, pooling='avg', input_shape=input_shape)
    pre_trained_model.trainable = False

    # custom modifications on top of pre-trained model and fit:
    # frozen feature extractor followed by a single softmax classification layer
    clf_model = tf.keras.models.Sequential()
    clf_model.add(pre_trained_model)
    clf_model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))
    clf_model.compile(loss='categorical_crossentropy', metrics=['accuracy'])
    # The training dataset repeats forever, so steps_per_epoch bounds each epoch
    history = clf_model.fit(
        train_processed,
        epochs=3,
        validation_data=validation_processed,
        steps_per_epoch=num_iterations,
    )

    # Calculate all relevant metrics
    model_benchmarks['model_name'].append(model_name)
    model_benchmarks['num_model_params'].append(pre_trained_model.count_params())
    # Validation accuracy of the final epoch
    model_benchmarks['validation_accuracy'].append(history.history['val_accuracy'][-1])
# Tabulate the collected benchmarks, ordered from the smallest
# num_model_params value to the largest
benchmark_df = pd.DataFrame(model_benchmarks).sort_values('num_model_params')
# Persist the table to disk without the index column
benchmark_df.to_csv('benchmark_df.csv', index=False)
# Marker styles used to tell the models apart in the scatter plot
markers = [
    'o', 'v', '^', '<', '>', '8', 's', 'p', 'P', '*', 'h', 'H',
    'X', 'D', 'd', '+', 'x', '1', '2', '3', '4', '|', '_',
]

# Loop over each row and plot the num_model_params vs validation_accuracy
for row in benchmark_df.itertuples():
    plt.scatter(
        row.num_model_params,
        row.validation_accuracy,
        label=row.model_name,
        # Wrap around so having more models than markers cannot raise IndexError
        marker=markers[row.Index % len(markers)],
        s=150,
        linewidths=2,
    )

plt.xlabel('Number of Parameters in Model')
plt.ylabel('Validation Accuracy after 3 Epochs')
plt.title('Accuracy vs Model Size')
# Move legend out of the plot
plt.legend(bbox_to_anchor=(1, 1), loc='upper left')
