Skip to content

Instantly share code, notes, and snippets.

@himanshurawlani
Last active September 6, 2020 19:14
Show Gist options
  • Save himanshurawlani/41c1ab8ff4e60ea3fa0f83616e72a2f8 to your computer and use it in GitHub Desktop.
Save himanshurawlani/41c1ab8ff4e60ea3fa0f83616e72a2f8 to your computer and use it in GitHub Desktop.
An example script that initializes a Ray Tune trainable and starts hyperparameter tuning.
class Trainable:
    """Hold the data/snapshot locations and train one model per Ray Tune trial.

    In this script a single instance is created and its bound ``train``
    method is passed to ``tune.run``, so every trial shares the same
    directories and receives its own sampled ``config``.
    """

    def __init__(self, train_dir, val_dir, snapshot_dir, final_run=False, epochs=100):
        """Store the state shared by all trials of the run.

        Args:
            train_dir: Training data location, forwarded to ``Generator``.
            val_dir: Validation data location, forwarded to ``Generator``.
            snapshot_dir: Forwarded to ``create_callbacks``.
            final_run: Flag forwarded to ``create_callbacks``; presumably it
                distinguishes the final training run from tuning trials --
                confirm against ``create_callbacks``.
            epochs: Number of epochs passed to ``model.fit``. Defaults to 100,
                the value that was previously hard-coded.
        """
        self.train_dir = train_dir
        self.val_dir = val_dir
        self.final_run = final_run
        self.snapshot_dir = snapshot_dir
        self.epochs = epochs

    def train(self, config, reporter=None):
        """Build, compile and fit one FCN model for a single trial.

        Args:
            config: Trial hyperparameters. This method reads
                ``config['batch_size']`` and ``config['lr']``; the whole dict
                is also passed to ``FCN_model``.
            reporter: Accepted to match Ray Tune's function-trainable
                signature; not used in this method.
                NOTE(review): the tuning script ranks trials by ``val_loss``;
                presumably a callback from ``create_callbacks`` reports it to
                Tune -- confirm.

        Returns:
            The object returned by ``model.fit`` (its training history).
        """
        # If training runs out of memory, lower the batch sizes in the search space.
        train_generator = Generator(self.train_dir, config['batch_size'])
        val_generator = Generator(self.val_dir, config['batch_size'])

        # The number of output classes is taken from the training generator.
        model = FCN_model(config, len_classes=len(train_generator.classes))

        # `learning_rate` is the supported keyword in TF2; the legacy `lr`
        # alias is deprecated and rejected by newer Keras optimizers.
        model.compile(optimizer=tf.keras.optimizers.Nadam(learning_rate=config['lr']),
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

        callbacks = create_callbacks(self.final_run, self.snapshot_dir)

        logger.info("Starting model training")
        history = model.fit(train_generator,
                            steps_per_epoch=len(train_generator),
                            epochs=self.epochs,
                            callbacks=callbacks,
                            validation_data=val_generator,
                            validation_steps=len(val_generator))
        return history
# NOTE(review): `args`, `num_samples`, `search_alg` and `scheduler` are
# defined outside this excerpt.
logger.info("Initializing ray Trainable")
# One Trainable is shared by all trials; its bound `train` method is the
# function trainable handed to Ray Tune.
trainer = Trainable(args.train_dir, args.val_dir, args.snapshot_dir, final_run=False)

logger.info("Starting hyperparameter tuning")
# raise_on_failed_trial=False keeps tune.run from raising when some trials
# fail, so the surviving trials can still be analysed.
# Each trial reserves 16 CPUs and 2 GPUs.
analysis = tune.run(trainer.train,
                    verbose=1,
                    num_samples=num_samples,
                    search_alg=search_alg,
                    scheduler=scheduler,
                    raise_on_failed_trial=False,
                    resources_per_trial={"cpu": 16, "gpu": 2})

# The best configuration is the one with the lowest reported validation loss.
best_config = analysis.get_best_config(metric="val_loss", mode='min')
# Lazy %-style args instead of an f-string; the logged text is unchanged.
logger.info('Best config: %s', best_config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment