Skip to content

Instantly share code, notes, and snippets.

@desilinguist
Created September 8, 2022 23:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save desilinguist/f2998789053286a83ab4585706efe1ed to your computer and use it in GitHub Desktop.
Save desilinguist/f2998789053286a83ab4585706efe1ed to your computer and use it in GitHub Desktop.
`generate_learning_curve_plots()` from SKLL but using `catplot()`
def generate_learning_curve_plots(experiment_name: str,
                                  output_dir: str,
                                  learning_curve_tsv_file: str) -> None:
    """
    Generate the learning curve plots given the TSV output
    file from a learning curve experiment.

    One PNG file named ``<experiment_name>_<featureset_name>.png`` is
    written to ``output_dir`` for every featureset found in the TSV file.
    Each figure is a grid with one row per metric and one column per
    learner. Every panel shows the mean training and test scores against
    the training set size, with a shaded band of plus/minus the matching
    standard deviation column around each curve.

    Parameters
    ----------
    experiment_name : str
        The name of the experiment.
    output_dir : str
        Path to the output directory for the plots.
    learning_curve_tsv_file : str
        The path to the learning curve TSV file.

    Notes
    -----
    The TSV file must contain the columns ``featureset_name``,
    ``learner_name``, ``metric``, ``training_set_size``,
    ``train_score_mean``, ``test_score_mean``, ``train_score_std``
    and ``test_score_std``.
    """
    # use pandas to read in the TSV file into a data frame
    # and massage it from wide to long format for plotting: the two
    # mean-score columns are stacked into "variable"/"value" columns
    # and every other column is kept as an identifier
    df = pd.read_csv(learning_curve_tsv_file, sep='\t')
    score_columns = ['train_score_mean', 'test_score_mean']
    df_melted = pd.melt(df, id_vars=[c for c in df.columns
                                     if c not in score_columns])

    # make sure the "variable" column is categorical since it will be
    # mapped to hue levels in the learning curve below
    df_melted["variable"] = df_melted["variable"].astype("category")

    # if there are any training sizes greater than 1000,
    # then we should probably rotate the tick labels
    # since otherwise the labels are not clearly rendered
    rotate_labels = bool((df['training_set_size'] >= 1000).any())

    # the colors do not depend on the featureset, so look them up once
    train_color, test_color = sns.color_palette(palette="Set1", n_colors=2)

    # (hue level, matching standard deviation column, color) for each curve
    curves = [('train_score_mean', 'train_score_std', train_color),
              ('test_score_mean', 'test_score_std', test_color)]

    # set up and draw the actual learning curve figures, one for
    # each of the featuresets
    for fs_name, df_fs in df_melted.groupby('featureset_name'):

        with sns.axes_style('whitegrid', {"grid.linestyle": ':',
                                          "xtick.major.size": 3.0}):

            # `catplot()` creates and sizes its own figure (via `height`
            # and `aspect`), so no figure is created up front
            # NOTE(review): `scale` may be deprecated in newer seaborn
            # releases - confirm against the installed version
            g = sns.catplot(data=df_fs, row="metric", col="learner_name",
                            x="training_set_size", y="value", hue="variable",
                            kind="point", height=2.5, aspect=1, margin_titles=True,
                            sharex=False, sharey=False, legend_out=False,
                            scale=.5, errorbar=None,
                            palette={"train_score_mean": train_color,
                                     "test_score_mean": test_color})

            # compute ylimits for this feature set for each objective
            ylimits = _compute_ylimits_for_featureset(df_fs, g.row_names)

            # blank out any text artists already attached to the axes
            # (presumably the margin titles - TODO confirm) since the
            # row name is used as the y-axis label instead
            for ax in g.axes.flat:
                plt.setp(ax.texts, text="")
            g = (g.set_titles(row_template='', col_template='{col_name}')
                 .set_axis_labels('Training Examples', 'Score'))
            if rotate_labels:
                g = g.set_xticklabels(rotation=60)

            for i, row_name in enumerate(g.row_names):
                for j, col_name in enumerate(g.col_names):
                    ax = g.axes[i][j]
                    ax.set(ylim=ylimits[row_name])

                    # rows of this featureset that belong to this panel
                    in_panel = ((df_fs['learner_name'] == col_name) &
                                (df_fs['metric'] == row_name))

                    # shade mean +/- standard deviation around each curve;
                    # the x values are positions 0..n-1 rather than the
                    # training sizes themselves
                    # NOTE(review): assumes the rows are already ordered
                    # the same way as the plotted points - TODO confirm
                    for variable, std_column, color in curves:
                        df_curve = df_fs[in_panel &
                                         (df_fs['variable'] == variable)]
                        ax.fill_between(list(range(len(df_curve))),
                                        df_curve['value'] - df_curve[std_column],
                                        df_curve['value'] + df_curve[std_column],
                                        alpha=0.1,
                                        color=color)

                    if j == 0:
                        ax.set_ylabel(row_name)
                        if i == 0:
                            # set up the legend handles for this plot
                            plot_handles = [matplotlib.lines.Line2D([],
                                                                    [],
                                                                    color=c,
                                                                    label=l,
                                                                    linestyle='-')
                                            for c, l in zip([train_color, test_color],
                                                            ['Training', 'Cross-validation'])]
                            ax.legend(handles=plot_handles,
                                      loc=4,
                                      fancybox=True,
                                      fontsize='x-small',
                                      ncol=1,
                                      frameon=True)

            # save and close the grid's own figure explicitly instead of
            # relying on whichever figure pyplot considers "current"
            g.figure.tight_layout(w_pad=1)
            g.figure.savefig(join(output_dir, f'{experiment_name}_{fs_name}.png'),
                             dpi=300)

            # explicitly close figure to save memory
            plt.close(g.figure)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment