Skip to content

Instantly share code, notes, and snippets.

@haimat
Created September 6, 2019 08:31
Show Gist options
  • Save haimat/65e9dddc9cd8004ea8870853f84d7b02 to your computer and use it in GitHub Desktop.
Save haimat/65e9dddc9cd8004ea8870853f84d7b02 to your computer and use it in GitHub Desktop.
A minimal working example for TensorFlow issue #32239
import os

from tensorflow.keras import layers
from tensorflow.keras import metrics
from tensorflow.keras import models
from tensorflow.keras import optimizers
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications import vgg16
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Build the Keras model: an ImageNet-pretrained VGG16 base without its
# fully connected top, followed by a small two-class classification head.
image_size = 150
input_layer = layers.Input(shape=(image_size, image_size, 3), name="model_input")
base_model = VGG16(weights="imagenet", include_top=False, input_tensor=input_layer)

# Stack the head onto the VGG16 output: flatten the feature maps, one hidden
# ReLU layer, then a 2-way softmax (one unit per class folder).
model_head = base_model.output
for head_layer in (
    layers.Flatten(name="model_head_flatten"),
    layers.Dense(256, activation="relu"),
    layers.Dense(2, activation="softmax"),
):
    model_head = head_layer(model_head)

model = models.Model(inputs=input_layer, outputs=model_head)
# Create image data generators.
# The data root needs three sub-folders "train", "validation" and "test".
# flow_from_directory treats each sub-folder inside them as one class, so
# there must be exactly as many class folders as the model has output units.
# Set TF_ISSUE_DATA_DIR to use a different data root; the default is the
# original author's location.
image_dir = os.environ.get("TF_ISSUE_DATA_DIR", "/home/mfb/Development/tf-github/data")
datagen = ImageDataGenerator(preprocessing_function=vgg16.preprocess_input)
# Settings shared by all three splits.
flow_kwargs = dict(target_size=(image_size, image_size), batch_size=20, class_mode="categorical")
training_img_generator = datagen.flow_from_directory(os.path.join(image_dir, 'train'), **flow_kwargs)
validation_img_generator = datagen.flow_from_directory(os.path.join(image_dir, 'validation'), **flow_kwargs)
# Do not shuffle the test split, so that predictions line up with
# test_img_generator.filenames / test_img_generator.classes.
test_img_generator = datagen.flow_from_directory(os.path.join(image_dir, 'test'), shuffle=False, **flow_kwargs)
# Compile Keras model.
# Pass accuracy as an explicitly named metric object so the logged keys are
# always "accuracy" / "val_accuracy". With the plain string "accuracy", the
# key depends on the TensorFlow version ("acc" in TF 1.x, "accuracy" in
# TF 2.x). A ModelCheckpoint that monitors a key which is never logged
# silently never saves. With categorical_crossentropy, the string "accuracy"
# resolves to categorical accuracy, so this is the same metric.
model.compile(loss="categorical_crossentropy", optimizer=optimizers.Adam(),
              metrics=[metrics.CategoricalAccuracy(name="accuracy")])
# Train Keras model.
# Set TF_ISSUE_MODEL_DIR to save checkpoints elsewhere; the default is the
# original author's location.
auto_save_path = os.environ.get("TF_ISSUE_MODEL_DIR", "/home/mfb/Development/tf-github/models")
# Keep only the checkpoint with the highest validation accuracy seen so far.
checkpoint = ModelCheckpoint(auto_save_path, monitor="val_accuracy", mode="max",
                             verbose=0, save_best_only=True)
# NOTE: fit_generator is deprecated from TF 2.1 on in favour of model.fit,
# which accepts generators directly; it is kept here to match the original
# environment.
model.fit_generator(training_img_generator,
                    steps_per_epoch=50, epochs=25, validation_steps=50,
                    validation_data=validation_img_generator,
                    callbacks=[checkpoint], verbose=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment