#!/usr/bin/env python
# coding: utf-8

# In[1]:

import wandb
import random
import numpy as np
import tensorflow as tf
import kerastuner as kt  # old package name; the project is published today as keras_tuner
# In[2]:

def create_datasets():
    # Tiny random dataset: 10 samples, 5 features each
    X = np.random.random((10, 5))
    y = np.random.random((10,))
    return X, y
def create_model(hp):
    # keras-tuner definition of the hyperparameter space
    hp.Int('param1', 0, 10, step=1)
    hp.Int('param2', 0, 10, step=1)
    hp.Int('param3', 0, 10, step=1)
    X = tf.keras.Input(shape=(5,))
    out = tf.keras.layers.Dense(1, activation='sigmoid')(X)  # sigmoid so binary_crossentropy is well-defined
    model = tf.keras.Model(X, out)
    model.compile(optimizer='rmsprop', loss='binary_crossentropy')
    return model
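
# Note: the hp.Int() calls above only register the search space; their sampled
# values are not wired into the architecture. A minimal sketch of how a value
# would normally be used (the 'units' parameter is hypothetical, not part of
# this gist):
#
#     units = hp.Int('units', 8, 64, step=8)
#     out = tf.keras.layers.Dense(units, activation='sigmoid')(X)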
# In[3]:

class WandbLogger(kt.Logger):
    def register_tuner(self, tuner_state):
        pass  # Called at the beginning of the search

    def exit(self):
        pass  # Called when the search exits

    def report_trial_state(self, trial_id, trial_state):
        pass  # Called to report intermediate results of a trial

    def register_trial(self, trial_id, trial_state):
        # Called at the beginning of each trial (i.e., with each new set of parameters)
        hp_config = trial_state['hyperparameters']['values']
        wandb.init(project=PROJECT_NAME, config=hp_config,
                   sync_tensorboard=True, reinit=True)
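
# For reference, trial_state is the trial serialized as a dict, so hp_config
# above is a plain mapping such as {'param1': 3, 'param2': 7, 'param3': 0}
# (illustrative values, not from an actual run).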
# In[4]:

class CustomTuner(kt.Tuner):
    def on_epoch_end(self, trial, model, epoch, logs):
        # This metric isn't produced by my first model, but the keras-tuner
        # oracle needs it at each epoch, so report a placeholder value.
        logs['accuracy'] = 0
        super(CustomTuner, self).on_epoch_end(trial, model, epoch, logs)
    def run_trial(self, trial, *fit_args, test_ds=None, **fit_kwargs):
        # wandb.init() has already been called here, by WandbLogger.register_trial()
        fit_kwargs['callbacks'].append(wandb.keras.WandbCallback())
        super(CustomTuner, self).run_trial(trial, *fit_args, **fit_kwargs)

        # Create the second model and get its metrics
        # ... model.evaluate()
        metrics = {'accuracy': random.random(), 'metric2': 5, 'metric3': 300}

        # Alert the oracle of the last 'true' value of the accuracy metric.
        # I need to tell the oracle which step it is, otherwise it puts it at step 0.
        self.oracle.update_trial(trial.trial_id,
                                 {'accuracy': metrics['accuracy']},
                                 step=EPOCHS + 1)  # model_history.epoch[-1] + 1

        wandb.config.update({'new_param': random.random()})
        wandb.run.summary['best_accuracy'] = metrics['accuracy']

        # Logging
        wandb.log(metrics)
        wandb.join()  # older name of wandb.finish(): closes the run so the next trial logs separately
# In[5]:

PROJECT_NAME = 'test-project'
NB_MODELS = 5
EPOCHS = 4
# In[6]:

tuner = CustomTuner(
    oracle=kt.oracles.RandomSearch(
        objective=kt.Objective('accuracy', 'max'),
        max_trials=NB_MODELS),
    hypermodel=create_model,
    logger=WandbLogger(),
    directory='.kt/',
    project_name=PROJECT_NAME
)
# In[7]:

ds_X, ds_y = create_datasets()

# In[8]:

tuner.search(x=ds_X, y=ds_y, epochs=EPOCHS,
             callbacks=[tf.keras.callbacks.TensorBoard()])
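
# In[9]:

# A possible follow-up once the search finishes, assuming the standard
# keras-tuner API (not part of the original gist):
best_hp = tuner.get_best_hyperparameters(num_trials=1)[0]
print(best_hp.values)    # dict of the best trial's hyperparameter values
tuner.results_summary()  # prints a summary of all trials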