Skip to content

Instantly share code, notes, and snippets.

@pierreelliott
Created August 4, 2020 10:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pierreelliott/64a261ff382ea0b5994615e53c7156eb to your computer and use it in GitHub Desktop.
Save pierreelliott/64a261ff382ea0b5994615e53c7156eb to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import wandb
import random
import numpy as np
import tensorflow as tf
import kerastuner as kt
# In[2]:
def create_datasets(n_samples=10, n_features=5, seed=None):
    """Generate a random toy dataset for the tuning demo.

    Args:
        n_samples: Number of rows to generate (default 10).
        n_features: Number of columns in ``X`` (default 5).
        seed: Optional integer seed. When given, a private
            ``np.random.RandomState(seed)`` is used so the output is
            reproducible. When None (the default), NumPy's global random state
            is used.

    Returns:
        Tuple ``(X, y)``. ``X`` has shape ``(n_samples, n_features)`` and ``y``
        has shape ``(n_samples,)``. All values are drawn uniformly from [0, 1).
    """
    # ``np.random.random_sample`` is the same function as ``np.random.random``,
    # and it also exists on ``RandomState``, so both branches share one call.
    rng = np.random if seed is None else np.random.RandomState(seed)
    X = rng.random_sample((n_samples, n_features))
    y = rng.random_sample((n_samples,))
    return X, y
def create_model(hp, n_features=5):
    """Build the toy hypermodel that the tuner trains for each trial.

    Args:
        hp: keras-tuner ``HyperParameters`` for the current trial. Three integer
            hyperparameters (``param1``, ``param2``, ``param3``, each 0-10 with
            step 1) are registered so the oracle has a search space. The model
            itself does not use their values.
        n_features: Width of each input vector. The default of 5 matches the
            5 columns produced by ``create_datasets``.

    Returns:
        A compiled single-output ``tf.keras.Model``.
    """
    # keras-tuner definition of hyperparameter space
    for param_name in ('param1', 'param2', 'param3'):
        hp.Int(param_name, 0, 10, step=1)

    # Keras documents ``shape`` as a tuple excluding the batch dimension.
    inputs = tf.keras.Input(shape=(n_features,))
    # binary_crossentropy is used without from_logits=True, so it expects
    # probabilities. A sigmoid keeps outputs in (0, 1). A linear output would be
    # clipped by the loss whenever it left that range, zeroing the gradient.
    outputs = tf.keras.layers.Dense(1, activation='sigmoid')(inputs)
    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer='rmsprop', loss='binary_crossentropy')
    return model
# In[3]:
class WandbLogger(kt.Logger):
    """keras-tuner ``Logger`` that starts a Weights & Biases run for each trial.

    Only ``register_trial`` does any work; the other hooks are no-ops.
    """

    def __init__(self, project_name=None):
        """Create the logger.

        Args:
            project_name: W&B project that trial runs are logged to. When None
                (the default), the module-level ``PROJECT_NAME`` is looked up
                each time a trial starts.
        """
        super().__init__()
        self.project_name = project_name

    def register_tuner(self, tuner_state):
        """Called at the beginning of the search; nothing to record."""

    def exit(self):
        """Called when the search exits; nothing to clean up here."""

    def report_trial_state(self, trial_id, trial_state):
        """keras-tuner hook for trial state reports; ignored."""

    def register_trial(self, trial_id, trial_state):
        """Open a new W&B run for the trial that is about to start.

        Called at the beginning of each trial (i.e. with each set of
        hyperparameter values).

        Args:
            trial_id: Identifier of the trial (unused).
            trial_state: keras-tuner trial state dict. Its
                ``['hyperparameters']['values']`` mapping is logged as the W&B
                run config.
        """
        hp_config = trial_state['hyperparameters']['values']
        # Resolve the project lazily so a missing constructor argument still
        # picks up the module-level constant, even though it is defined later.
        project = self.project_name if self.project_name is not None else PROJECT_NAME
        # reinit=True allows a fresh run per trial within the same process;
        # sync_tensorboard=True mirrors TensorBoard callback logs into W&B.
        wandb.init(project=project, config=hp_config,
                   sync_tensorboard=True, reinit=True)
# In[4]:
class CustomTuner(kt.Tuner):
    """keras-tuner ``Tuner`` that logs each trial to Weights & Biases.

    When the tuner is built with ``WandbLogger`` (as in this script), the W&B run
    for a trial is opened before ``run_trial`` is called. ``run_trial`` closes it
    again. After fitting, a post-training ``accuracy`` value is reported to the
    oracle as the trial's final objective.
    """

    def on_epoch_end(self, trial, model, epoch, logs=None):
        """Ensure ``logs`` carries the oracle's objective before reporting it.

        The first model does not produce an ``accuracy`` metric, but the oracle
        (objective ``accuracy``) needs one at every epoch. A placeholder of 0 is
        therefore inserted when the metric is missing. A real ``accuracy`` value,
        if present, is left untouched.
        """
        if logs is None:
            logs = {}
        logs.setdefault('accuracy', 0)
        super().on_epoch_end(trial, model, epoch, logs)

    def run_trial(self, trial, *fit_args, test_ds=None, **fit_kwargs):
        """Fit one trial with W&B logging, then report a final objective value.

        Args:
            trial: The keras-tuner trial being run.
            *fit_args: Positional arguments forwarded to ``model.fit``.
            test_ds: Currently unused placeholder for an evaluation dataset.
            **fit_kwargs: Keyword arguments forwarded to ``model.fit``.
                ``callbacks`` is extended with a ``WandbCallback``.
                ``epochs`` (Keras default 1) determines the step of the final
                oracle report.
        """
        # Wandb init has been called here (by the logger's register_trial).
        # Build a fresh list rather than appending to the caller's: keras-tuner's
        # search() forwards the same list object to every trial, so an in-place
        # append would stack one more WandbCallback per trial. This also handles
        # the case where no callbacks were passed.
        callbacks = list(fit_kwargs.get('callbacks') or [])
        callbacks.append(wandb.keras.WandbCallback())
        fit_kwargs['callbacks'] = callbacks
        try:
            super().run_trial(trial, *fit_args, **fit_kwargs)

            # Create second model and get metrics
            # ... model.evaluate()
            metrics = {'accuracy': random.random(), 'metric2': 5, 'metric3': 300}

            # Alert the oracle of the last 'true' value of the accuracy metric.
            # The step must be given explicitly (otherwise it is recorded at
            # step 0). Using one past the epoch count places it after every
            # per-epoch value reported by on_epoch_end.
            last_step = fit_kwargs.get('epochs', 1) + 1
            self.oracle.update_trial(
                trial.trial_id, {'accuracy': metrics['accuracy']}, step=last_step)

            wandb.config.update({'new_param': random.random()})
            wandb.run.summary['best_accuracy'] = metrics['accuracy']
            # Logging
            wandb.log(metrics)
        finally:
            # Close this trial's W&B run even if fitting failed, so the run is
            # not left open.
            wandb.join()
# In[5]:
# Name used both for the W&B project and for keras-tuner's project_name below.
PROJECT_NAME: str = 'test-project'
# Number of trials (hyperparameter combinations) the random search may run.
NB_MODELS: int = 5
# Training epochs per trial, passed to tuner.search().
EPOCHS: int = 4
# In[6]:
# Random search over the hyperparameters declared in create_model, maximising
# the 'accuracy' value that CustomTuner reports, for at most NB_MODELS trials.
search_oracle = kt.oracles.RandomSearch(
    objective=kt.Objective('accuracy', 'max'),
    max_trials=NB_MODELS,
)
tuner = CustomTuner(
    oracle=search_oracle,
    hypermodel=create_model,
    logger=WandbLogger(),
    directory='.kt/',
    project_name=PROJECT_NAME,
)
# In[7]:
ds_X, ds_y = create_datasets()
# In[8]:
# TensorBoard callback for each fit.
fit_callbacks = [tf.keras.callbacks.TensorBoard()]
tuner.search(x=ds_X, y=ds_y, epochs=EPOCHS, callbacks=fit_callbacks)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment