Skip to content

Instantly share code, notes, and snippets.

@gabraganca
Created April 5, 2019 12:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gabraganca/720f70c2f5fa8857150d67325a3abb38 to your computer and use it in GitHub Desktop.
Save gabraganca/720f70c2f5fa8857150d67325a3abb38 to your computer and use it in GitHub Desktop.
# Collect the Grid Search results into a DataFrame.
# NOTE(review): `gs` looks like a fitted Surprise GridSearchCV (it exposes
# `cv_results` without sklearn's trailing underscore) and `param_grid` the grid
# it was built from — confirm against the code that defines them.
df_results = pd.DataFrame.from_dict(gs.cv_results)
# Strip the "param_" prefix so hyper-parameter columns can be referenced by
# their plain names ("n_epochs", "n_factors", "lr_all") below. regex=False
# makes the literal (non-regex) match explicit across pandas versions.
df_results.columns = df_results.columns.str.replace('param_', '', regex=False)

# Plot the heat maps: one row per n_epochs value, one column per metric.
n_epochs = len(param_grid['n_epochs'])
# squeeze=False keeps `axes` 2-D even when the grid has a single n_epochs
# value; without it plt.subplots returns a 1-D array and the row/column
# iteration below fails.
fig, axes = plt.subplots(
    nrows=n_epochs,
    ncols=3,
    figsize=(22, 6 * n_epochs),
    squeeze=False,
)
for ax_row, n_epoch in zip(axes, param_grid['n_epochs']):
    # Results for this epoch count; computed once per row, not per metric.
    epoch_results = df_results.query(f'n_epochs == {n_epoch}')
    for ax, metric in zip(ax_row, ['mae', 'rmse', 'time']):
        # Error metrics are read from the test-fold means; 'time' is the mean
        # fit time.
        parameter = f'mean_test_{metric}' if metric != 'time' else f'mean_fit_{metric}'
        ax = sns.heatmap(
            # pivot_table aggregates duplicate (lr_all, n_factors) cells with
            # its default mean, e.g. if the grid varies other parameters too.
            epoch_results.pivot_table(columns='n_factors', index='lr_all', values=parameter),
            annot=True,
            fmt='0.4f',
            # Use this metric's global min/max so colours are comparable
            # across the different n_epochs rows.
            vmin=df_results[parameter].min(),
            vmax=df_results[parameter].max(),
            ax=ax,
            cmap='viridis',
        )
        # Display label: "Time" for timing, "MAE"/"RMSE" for error metrics.
        label = metric.capitalize() if metric == 'time' else metric.upper()
        ax.set_title(f'# Epochs: {n_epoch} | metric: {label}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment