Skip to content

Instantly share code, notes, and snippets.

@shamilnabiyev
Created April 9, 2023 20:14
Show Gist options
  • Save shamilnabiyev/518d0a5929ed28d63403a710256883aa to your computer and use it in GitHub Desktop.
Transfer learning with VGG and Keras [for image classification]
# Credits:
# Author: Gabriel Cassimiro
# Blog post: https://towardsdatascience.com/transfer-learning-with-vgg16-and-keras-50ea161580b4
# GitHub Repo: https://github.com/gabrielcassimiro17/object-detection
#
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import to_categorical
## Loading images and labels
# Both sets are sliced from the "train" split. The two slices must be disjoint:
# the first 70% of examples for training and the remaining 30% for testing.
# Otherwise `model.evaluate` below would score images the model was fit on.
(train_ds, train_labels), (test_ds, test_labels) = tfds.load(
    "tf_flowers",
    split=["train[:70%]", "train[70%:]"],  # disjoint train/test split
    # -1 returns each whole split as a single batch of tensors.
    # NOTE(review): the tfds docs say variable-length features are 0-padded
    # when batched, so smaller photos may carry black borders into the resize
    # below. Consider resizing per example (tf.data `map`) before batching.
    batch_size=-1,
    as_supervised=True,  # yield (image, label) pairs instead of feature dicts
)
## Resizing images
# Every image must share one fixed size: the VGG16 base below takes its
# input_shape from a single resized image. tf.image.resize takes
# (height, width).
IMAGE_SIZE = (150, 150)
# Number of label classes. Must match num_classes here and the Dense(5)
# softmax output of the model.
NUM_CLASSES = 5

train_ds = tf.image.resize(train_ds, IMAGE_SIZE)
test_ds = tf.image.resize(test_ds, IMAGE_SIZE)
## Transforming labels to correct format
# One-hot encode the integer labels. The model is compiled with
# 'categorical_crossentropy', which expects one-hot targets.
train_labels = to_categorical(train_labels, num_classes=NUM_CLASSES)
test_labels = to_categorical(test_labels, num_classes=NUM_CLASSES)
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input

## Loading VGG16 model
# ImageNet-pretrained convolutional base only: include_top=False drops
# VGG16's original classifier. The input shape is taken from one resized image.
base_model = VGG16(
    include_top=False,
    weights="imagenet",
    input_shape=train_ds[0].shape,
)
# Freeze the pretrained weights so training only updates the new head.
base_model.trainable = False
## Preprocessing input
# Apply the VGG16 module's own input preprocessing to both splits.
train_ds, test_ds = preprocess_input(train_ds), preprocess_input(test_ds)
# Classification head on top of the frozen VGG16 base. It flattens the
# convolutional features, applies two small ReLU layers, then a 5-way softmax.
flatten_layer = layers.Flatten()
dense_layer_1 = layers.Dense(50, activation="relu")
dense_layer_2 = layers.Dense(20, activation="relu")
prediction_layer = layers.Dense(5, activation="softmax")

model = models.Sequential()
model.add(base_model)
model.add(flatten_layer)
model.add(dense_layer_1)
model.add(dense_layer_2)
model.add(prediction_layer)

# Categorical cross-entropy matches the one-hot labels from to_categorical.
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
# Stop once validation accuracy has not improved for 5 epochs, and roll back
# to the weights from the best epoch.
es = EarlyStopping(
    monitor="val_accuracy",
    mode="max",
    patience=5,
    restore_best_weights=True,
)
# Hold out 20% of the training tensors for validation during fitting.
model.fit(
    train_ds,
    train_labels,
    batch_size=32,
    epochs=50,
    validation_split=0.2,
    callbacks=[es],
)
# Final score on the test split.
model.evaluate(test_ds, test_labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment