Skip to content

Instantly share code, notes, and snippets.

@amysteier
Last active December 8, 2021 18:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amysteier/21dc1c45d3476810f5f2a143609c1906 to your computer and use it in GitHub Desktop.
Save amysteier/21dc1c45d3476810f5f2a143609c1906 to your computer and use it in GitHub Desktop.
Example Optuna objective function
import time
import optuna
from gretel_client import projects
from gretel_client import create_project
from gretel_client.config import RunnerMode
from gretel_client.helpers import poll
def _suggest_params(trial, params):
    """Sample this trial's hyperparameters and write them into ``params`` in place."""
    params['rnn_units'] = trial.suggest_int(name="rnn_units", low=64, high=1024, step=64)
    params['dropout_rate'] = trial.suggest_float("dropout_rate", .1, .75)
    params['gen_temp'] = trial.suggest_float("gen_temp", .8, 1.2)
    params['learning_rate'] = trial.suggest_float("learning_rate", .0005, 0.01, step=.0005)
    params['reset_states'] = trial.suggest_categorical(
        "reset_states", choices=[True, False])
    params['vocab_size'] = trial.suggest_categorical(
        "vocab_size", choices=[0, 5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000])


def _wait_for_model(model, poll_interval=60):
    """Poll ``model`` every ``poll_interval`` seconds until it leaves a running status.

    Returns the first status string observed outside the running set.
    """
    # Statuses meaning the job has not finished yet. "created" must be included,
    # otherwise a job still queued after submit is mistaken for a finished one.
    # NOTE(review): assumes these are exactly gretel_client's non-terminal states
    # (created/pending/active) — verify against its job Status enum.
    running_statuses = ("created", "pending", "active")
    while True:
        time.sleep(poll_interval)
        # NOTE(review): relies on private gretel_client internals (_poll_job_endpoint,
        # the instance's _data dict); may break on client upgrades.
        model._poll_job_endpoint()
        status = model.__dict__['_data']['model']['status']
        if status not in running_statuses:
            return status


def _quality_score(model, status):
    """Return the model's Synthetic Data Quality Score, or 0 if unavailable.

    0 is returned when the job did not end as "completed" or when
    ``peek_report()`` returns a falsy value.
    """
    if status != "completed":
        return 0
    report = model.peek_report()
    if not report:
        return 0
    return report['synthetic_data_quality_score']['score']


def objective(trial: optuna.Trial):
    """Optuna objective: train one Gretel synthetic model and return its quality score.

    Samples this trial's hyperparameters into the module-level ``config``
    (mutated in place), submits a model trained on the module-level ``dataset``
    inside a fresh Gretel project, and waits for the job to finish. The project
    is deleted afterwards, including when an exception is raised.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        The report's ``synthetic_data_quality_score`` score when the model
        completes and a report is available; otherwise 0.

    Raises:
        KeyError: if the report lacks the expected score keys.
        Exceptions from the Gretel client propagate (after project cleanup).
    """
    # NOTE(review): this mutates the shared module-level `config`, so concurrent
    # trials (e.g. study.optimize(..., n_jobs>1)) would race on it — confirm
    # trials run sequentially.
    _suggest_params(trial, config['models'][0]['synthetics']['params'])

    # Create a new Gretel project, named with the current Unix time in seconds.
    seconds = int(time.time())
    project_name = "Tuning Experiment" + str(seconds)
    project = create_project(display_name=project_name)
    try:
        # Create a Gretel synthetic model and upload the training data with it.
        # NOTE(review): `dataset` is presumably a file path or DataFrame — confirm.
        model = project.create_model_obj(model_config=config)
        model.data_source = dataset
        model.submit(upload_data_source=True)

        status = _wait_for_model(model)
        # Return the model Synthetic Quality Score (0 on failure / no report).
        return _quality_score(model, status)
    finally:
        # Always clean up, so failed or interrupted trials don't leak projects.
        project.delete()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment