Skip to content

Instantly share code, notes, and snippets.

@shashankprasanna
Created July 13, 2020 23:53
Show Gist options
  • Save shashankprasanna/dfdcea9164399eeaf723f03ee91a610d to your computer and use it in GitHub Desktop.
Save shashankprasanna/dfdcea9164399eeaf723f03ee91a610d to your computer and use it in GitHub Desktop.
from smdebug.trials import create_trial
def tensor_df(tname, smd_trial=None):
    """Return the recorded values of one tensor as a two-column DataFrame.

    Args:
        tname: Name of the tensor to look up; also used as the name of the
            value column in the result.
        smd_trial: Optional object exposing ``tensor(name).values()``.
            When omitted, the module-level ``trial`` is used, which is what
            this function originally relied on.

    Returns:
        A DataFrame with a ``steps`` column (the keys of the mapping returned
        by ``values()``) followed by a column named ``tname``.
    """
    # Resolve the global lazily so callers that pass their own trial do not
    # need a module-level `trial` to exist.
    source = trial if smd_trial is None else smd_trial
    step_values = source.tensor(tname).values()

    # Keys of the mapping become the index, each value becomes one row.
    frame = pd.DataFrame.from_dict(step_values, orient='index', columns=[tname])

    # Move the index into a regular column and give it a meaningful name.
    return frame.reset_index().rename(columns={'index': 'steps'})
def trial_perf_curves(job_name, tname, experiment_name):
    """Load one tensor recorded for a training job into a DataFrame.

    The debug output location is built from the module-level ``bucket_name``
    together with ``experiment_name`` and ``job_name``.

    Args:
        job_name: Training job name; last path component before
            ``debug-output``.
        tname: Name of the tensor to read; also the name of the single
            column in the result.
        experiment_name: Experiment prefix under the bucket.

    Returns:
        A single-column DataFrame named ``tname`` whose index is the keys of
        the mapping returned by ``trial.tensor(tname).values()``
        (presumably step numbers -- confirm against the smdebug docs).
    """
    s3_debug_path = f's3://{bucket_name}/{experiment_name}/{job_name}/debug-output'
    recorded = create_trial(s3_debug_path).tensor(tname).values()
    return pd.DataFrame.from_dict(recorded, orient='index', columns=[tname])
def get_metric_dataframe(metric, trial_comp_ds, experiment_name):
    """Collect one metric from several training jobs into a single DataFrame.

    Args:
        metric: Name of the tensor/metric to load for every job.
        trial_comp_ds: Mapping or DataFrame whose ``'DisplayName'`` entry
            iterates over the training job names to load.
        experiment_name: Experiment name forwarded to ``trial_perf_curves``.

    Returns:
        A DataFrame with one column per job (named after the job), aligned
        on the index produced by ``trial_perf_curves``. An empty DataFrame
        is returned when there are no jobs.
    """
    curves = []
    for tc_name in trial_comp_ds['DisplayName']:
        print(f'\nLoading training job: {tc_name}')
        print('--------------------------------\n')
        trial_perf = trial_perf_curves(tc_name, metric, experiment_name)
        # Label the job's single metric column with the job name so the
        # columns stay distinguishable after concatenation.
        trial_perf.columns = [tc_name]
        curves.append(trial_perf)

    # pd.concat rejects an empty list, so handle "no jobs" explicitly.
    if not curves:
        return pd.DataFrame()

    # Concatenate once instead of re-copying a growing frame per iteration.
    return pd.concat(curves, axis=1)
# Gather the 'val_acc' curve of every job listed in trial_comp_ds_jobs,
# one column per training job.
val_acc_df = get_metric_dataframe('val_acc', trial_comp_ds_jobs, experiment_name)

# Create a 15 x 10 inch figure to draw the curves on.
fig = plt.figure()
fig.set_size_inches(15, 10)

# Replace the trial names with the ones you want to plot, or drop the
# column selection entirely to plot all jobs.
val_acc_df.loc[
    :,
    [
        'cifar10-training-adam-custom-120-1594536575',
        'cifar10-training-adam-custom-60-1594536571',
        'cifar10-training-rmsprop-custom-30-1594536622',
    ],
].plot(style='-', ax=plt.gca())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment