Skip to content

Instantly share code, notes, and snippets.

@shartoo
Created January 12, 2023 05:14
Show Gist options
  • Save shartoo/4ed69814f686ef2c07120d739cda4cb5 to your computer and use it in GitHub Desktop.
Save shartoo/4ed69814f686ef2c07120d739cda4cb5 to your computer and use it in GitHub Desktop.
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
#---data pipeline: augmented training images read from ./Fruits---
# The VGG16 backbone's pretrained weights expect inputs processed by
# vgg16.preprocess_input (RGB->BGR plus ImageNet mean subtraction);
# without it the backbone is fed raw 0-255 RGB pixels. Any image fed to
# the trained model later must go through the same preprocessing.
train_datagen = ImageDataGenerator(
    horizontal_flip=True,
    vertical_flip=True,
    rotation_range=50,
    preprocessing_function=preprocess_input,
)
batch_size = 8
# Labels are one-hot ('categorical') to match the softmax head and the
# categorical_crossentropy loss used when compiling.
train_generator = train_datagen.flow_from_directory(
    './Fruits',
    target_size=(224, 224),  # matches the VGG16 input_shape
    color_mode='rgb',
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
)
#---number of fruits (one class per subdirectory of ./Fruits)---
# class_indices maps each class name to an index 0..N-1, so its size is N.
NO_CLASSES = len(train_generator.class_indices)
if NO_CLASSES == 0:
    # Fail early with a readable message instead of building a 0-unit head.
    raise ValueError("no class subdirectories found under './Fruits'")
#---backbone: VGG16 without its classifier head (Keras default weights: ImageNet)---
base_model = VGG16(include_top=False, input_shape=(224, 224, 3))

#---new classification head---
# Pool the backbone's feature map down to one vector per image, then
# stack ReLU dense layers so the head can learn a more complex decision
# function before the final softmax over NO_CLASSES fruits.
x = GlobalAveragePooling2D()(base_model.output)
for units in (1024, 1024, 512):
    x = Dense(units, activation='relu')(x)
preds = Dense(NO_CLASSES, activation='softmax')(x)

#---full model: the backbone's original input mapped to the new softmax output---
model = Model(inputs=base_model.input, outputs=preds)
#---freeze the VGG16 backbone; train only the new head---
# The split point is taken from the backbone rather than hard-coded as
# 19, which is the layer count of VGG16 with include_top=False, input
# layer included. This assumes model.layers lists the backbone's layers
# first and in order, as for a functional model built on
# base_model.output. Loop bodies must be indented, or the script fails
# with IndentationError.
n_backbone_layers = len(base_model.layers)
for layer in model.layers[:n_backbone_layers]:
    layer.trainable = False
#---train the rest of the layers (pooling + dense head)---
for layer in model.layers[n_backbone_layers:]:
    layer.trainable = True
#---compile the model---
model.compile(optimizer='Adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

#---train the model---
# One epoch = one full pass over the images: len(train_generator) is
# ceil(n / batch_size). Plain n // batch_size would drop the last
# partial batch every epoch, and would give 0 steps whenever the dataset
# holds fewer images than one batch.
step_size_train = len(train_generator)
model.fit(train_generator,
          steps_per_epoch=step_size_train,
          epochs=15)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment