Skip to content

Instantly share code, notes, and snippets.

@himanshurawlani
Created December 26, 2019 18:59
Show Gist options
  • Save himanshurawlani/9c3a98daf52569f571bfd5f0331f666b to your computer and use it in GitHub Desktop.
Save himanshurawlani/9c3a98daf52569f571bfd5f0331f666b to your computer and use it in GitHub Desktop.
def train(model, train_generator, val_generator, epochs=50, *,
          learning_rate=0.0001, checkpoint_path='./snapshots'):
    """Compile ``model``, train it on the given generators, and checkpoint the best epoch.

    The model is compiled with Adam, categorical cross-entropy loss and an
    accuracy metric. It is then fitted for ``epochs`` epochs. After each
    epoch a ``ModelCheckpoint`` callback writes the model to an ``.h5`` file
    under ``checkpoint_path``. It writes only when ``val_loss`` improves on
    the best value seen so far (``save_best_only=True``).

    Args:
        model: A ``tf.keras`` model; it is (re)compiled by this function.
        train_generator: Training data source passed to ``Model.fit``. It must
            support ``len()`` (e.g. a ``tf.keras.utils.Sequence``); that
            length is used as ``steps_per_epoch``.
        val_generator: Validation data source. It must support ``len()``; that
            length is used as ``validation_steps``.
        epochs: Number of training epochs.
        learning_rate: Adam learning rate (default matches the original 1e-4).
        checkpoint_path: Directory for checkpoint files; created if missing.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # `learning_rate` replaces the deprecated `lr` alias, which newer Keras
    # releases no longer accept.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    os.makedirs(checkpoint_path, exist_ok=True)
    # In TF 2.x, metrics=['accuracy'] is logged under 'accuracy' and
    # 'val_accuracy' (not 'acc'/'val_acc'). ModelCheckpoint fills this template
    # from the epoch logs, so a placeholder naming a missing key raises
    # KeyError at the end of the first epoch.
    model_path = os.path.join(
        checkpoint_path,
        'model_epoch_{epoch:02d}_loss_{loss:.2f}_acc_{accuracy:.2f}'
        '_val_loss_{val_loss:.2f}_val_acc_{val_accuracy:.2f}.h5')
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        model_path, monitor='val_loss', save_best_only=True, verbose=1)
    # Model.fit consumes generators/Sequences directly; fit_generator was
    # deprecated in TF 2.1 and later removed.
    history = model.fit(train_generator,
                        steps_per_epoch=len(train_generator),
                        epochs=epochs,
                        callbacks=[checkpoint],
                        validation_data=val_generator,
                        validation_steps=len(val_generator))
    return history
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment