Skip to content

Instantly share code, notes, and snippets.

@shashankprasanna
Created July 13, 2020 23:40
Show Gist options
  • Save shashankprasanna/42fe4bf3903a4ca735dc260df6efca18 to your computer and use it in GitHub Desktop.
# Launch one SageMaker training job per hyperparameter combination in
# `trial_hyperparameter_set`, recording each one as a Trial of
# `training_experiment`.
# Relies on names defined earlier in the file/notebook: time, Tracker, Trial,
# TensorFlow, static_hyperparams, trial_hyperparameter_set, training_experiment,
# exp_tracker, bucket_name, sm, role, sagemaker_session, datasets.
import re

# SageMaker names (training jobs, trials) may contain only ASCII letters,
# digits and hyphens. Training-job names are limited to 63 characters and
# trial names to 120.
_INVALID_NAME_CHARS = re.compile(r'[^a-zA-Z0-9]+')
_MAX_JOB_NAME_LEN = 63
_MAX_TRIAL_NAME_LEN = 120


def _bounded_name(prefix, suffix, max_len):
    """Return ``prefix + suffix``, truncating *prefix* so the result fits *max_len*.

    The suffix (the unique timestamp) is always kept intact.
    """
    return prefix[:max_len - len(suffix)] + suffix


# Metric definitions are identical for every trial, so build them once.
# `\b` keeps e.g. 'loss: ' from also matching inside 'val_loss: ' or
# 'test_loss: ' ('_' and 'l' are both word characters, so there is no
# boundary between them). The optional exponent captures values that Keras
# prints in scientific notation (e.g. 'loss: 1.2345e-04'); '[0-9\.]+' alone
# would stop at the 'e' and report the wrong magnitude.
metric_definitions = [
    {'Name': 'loss', 'Regex': r'\bloss: ([0-9\.]+(?:e[-+]?[0-9]+)?)'},
    {'Name': 'acc', 'Regex': r'\bacc: ([0-9\.]+(?:e[-+]?[0-9]+)?)'},
    {'Name': 'val_loss', 'Regex': r'\bval_loss: ([0-9\.]+(?:e[-+]?[0-9]+)?)'},
    {'Name': 'val_acc', 'Regex': r'\bval_acc: ([0-9\.]+(?:e[-+]?[0-9]+)?)'},
    {'Name': 'test_acc', 'Regex': r'\btest_acc: ([0-9\.]+(?:e[-+]?[0-9]+)?)'},
    {'Name': 'test_loss', 'Regex': r'\btest_loss: ([0-9\.]+(?:e[-+]?[0-9]+)?)'},
]

for trial_hyp in trial_hyperparameter_set:
    # Combine static and trial-specific hyperparameters; trial values override
    # static ones on key collisions.
    hyperparams = {**static_hyperparams, **trial_hyp}

    # Build a unique job name from the trial's hyperparameter values plus the
    # launch time. Characters SageMaker rejects in names (e.g. the '.' in
    # 0.01) are replaced with hyphens.
    time_append = int(time.time())
    hyp_append = "-".join(_INVALID_NAME_CHARS.sub('-', str(elm))
                          for elm in trial_hyp.values())
    job_name = _bounded_name(f'cifar10-training-{hyp_append}',
                             f'-{time_append}', _MAX_JOB_NAME_LEN)

    # Create a Tracker to record this trial's hyperparameters. Only the
    # logging needs the open tracker; its trial component is attached to the
    # Trial below.
    with Tracker.create(display_name=f"trial-metadata-{time_append}",
                        artifact_bucket=bucket_name,
                        artifact_prefix=f"{training_experiment.experiment_name}/{job_name}",
                        sagemaker_boto_client=sm) as trial_tracker:
        trial_tracker.log_parameters(hyperparams)

    # Create a new Trial and associate both the experiment-level tracker and
    # this trial's tracker with it.
    tf_trial = Trial.create(
        trial_name=_bounded_name(f'trial-{hyp_append}', f'-{time_append}',
                                 _MAX_TRIAL_NAME_LEN),
        experiment_name=training_experiment.experiment_name,
        sagemaker_boto_client=sm)
    tf_trial.add_trial_component(exp_tracker.trial_component)
    time.sleep(2)  # To prevent ThrottlingException
    tf_trial.add_trial_component(trial_tracker.trial_component)

    # Experiment config that associates the training job with the Trial.
    experiment_config = {"ExperimentName": training_experiment.experiment_name,
                         "TrialName": tf_trial.trial_name,
                         "TrialComponentDisplayName": job_name}

    # TensorFlow Estimator with the trial-specific hyperparameters.
    # NOTE(review): train_instance_count / train_instance_type are SageMaker
    # Python SDK v1 argument names; SDK v2 renamed them to instance_count /
    # instance_type. Confirm the installed SDK version before upgrading.
    tf_estimator = TensorFlow(entry_point='cifar10-training-sagemaker.py',
                              source_dir='code',
                              output_path=f's3://{bucket_name}/{training_experiment.experiment_name}/',
                              code_location=f's3://{bucket_name}/{training_experiment.experiment_name}',
                              role=role,
                              train_instance_count=1,
                              train_instance_type='ml.p3.2xlarge',
                              framework_version='1.15',
                              py_version='py3',
                              script_mode=True,
                              metric_definitions=metric_definitions,
                              sagemaker_session=sagemaker_session,
                              hyperparameters=hyperparams,
                              enable_sagemaker_metrics=True)

    # Launch the training job without waiting for it to finish. All three
    # input channels point at the same `datasets` location.
    tf_estimator.fit({'training': datasets,
                      'validation': datasets,
                      'eval': datasets},
                     job_name=job_name,
                     wait=False,
                     experiment_config=experiment_config)
    time.sleep(3)  # To prevent ThrottlingException
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment