Skip to content

Instantly share code, notes, and snippets.

@ozgurshn
Created June 21, 2018 22:50
Show Gist options
  • Save ozgurshn/deef535cf47d8641d365a50b485e0467 to your computer and use it in GitHub Desktop.
Save ozgurshn/deef535cf47d8641d365a50b485e0467 to your computer and use it in GitHub Desktop.
def modelFitGenerator(fitModel):
    """Train ``fitModel`` on images streamed from the train/validation folders.

    Depends on names defined outside this function: ``ImageDataGenerator``,
    ``train_data_dir``, ``validation_data_dir``, ``image_size``,
    ``batch_size`` and ``nb_epoch``.

    Training images are augmented (rotation, horizontal/vertical flips,
    zoom); validation images are fed through an unconfigured generator.

    Args:
        fitModel: Object exposing ``fit_generator`` (presumably a compiled
            Keras model -- confirm against the caller).

    Returns:
        The value returned by ``fitModel.fit_generator`` (for Keras models,
        the training ``History``).
    """
    train_datagen = ImageDataGenerator(
        rotation_range=90,
        horizontal_flip=True,
        vertical_flip=True,
        zoom_range=0.4,
    )
    test_datagen = ImageDataGenerator()

    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=True,
    )
    validation_generator = test_datagen.flow_from_directory(
        validation_data_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=True,
    )

    # Take the sample counts from the generators themselves: they only count
    # the image files that will actually be yielded, whereas walking the
    # directory tree would also count unrelated files.
    # Ceiling division (in pure integer arithmetic) makes one epoch cover
    # every sample, including the final partial batch, and avoids a step
    # count of zero when a set is smaller than one batch.
    num_train_steps = -(-train_generator.samples // batch_size)
    num_valid_steps = -(-validation_generator.samples // batch_size)

    print("start history model")
    history = fitModel.fit_generator(
        train_generator,
        steps_per_epoch=num_train_steps,
        epochs=nb_epoch,
        validation_data=validation_generator,
        validation_steps=num_valid_steps,
    )
    # Hand the result back so callers can inspect or plot the training run.
    return history
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment