`generate_learning_curve_plots()` from SKLL but using `catplot()`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def generate_learning_curve_plots(experiment_name,
                                  output_dir,
                                  learning_curve_tsv_file):
    """
    Generate the learning curve plots given the TSV output
    file from a learning curve experiment.

    One PNG is written per featureset found in the TSV file, named
    ``<experiment_name>_<featureset_name>.png`` inside ``output_dir``.
    Each PNG is a grid of point plots with one row per metric and one
    column per learner; the training and cross-validation means are
    drawn as lines and +/- one standard deviation is shaded around them.

    Parameters
    ----------
    experiment_name : str
        The name of the experiment.
    output_dir : str
        Path to the output directory for the plots.
    learning_curve_tsv_file : str
        The path to the learning curve TSV file.
    """
    # use pandas to read in the TSV file into a data frame
    # and massage it from wide to long format for plotting
    df = pd.read_csv(learning_curve_tsv_file, sep='\t')
    score_columns = ['train_score_mean', 'test_score_mean']
    df_melted = pd.melt(df, id_vars=[column for column in df.columns
                                     if column not in score_columns])

    # make sure the "variable" column is categorical since it will be
    # mapped to hue levels in the learning curve below
    df_melted["variable"] = df_melted["variable"].astype("category")

    # if there are any training sizes greater than or equal to 1000,
    # then we should probably rotate the tick labels since otherwise
    # the labels are not clearly rendered
    rotate_labels = bool((df['training_set_size'] >= 1000).any())

    # set up and draw the actual learning curve figures, one for
    # each of the featuresets
    for fs_name, df_fs in df_melted.groupby('featureset_name'):

        # NOTE: no explicit ``plt.figure()`` here; ``catplot()`` is a
        # figure-level function that always creates its own figure, and
        # ``height``/``aspect`` below already size each facet at 2.5 x 2.5
        with sns.axes_style('whitegrid', {"grid.linestyle": ':',
                                          "xtick.major.size": 3.0}):
            train_color, test_color = sns.color_palette(palette="Set1", n_colors=2)
            g = sns.catplot(data=df_fs, row="metric", col="learner_name",
                            x="training_set_size", y="value", hue="variable",
                            kind="point", height=2.5, aspect=1, margin_titles=True,
                            sharex=False, sharey=False, legend_out=False,
                            scale=.5, errorbar=None,
                            palette={"train_score_mean": train_color,
                                     "test_score_mean": test_color})

            # compute ylimits for this feature set for each objective
            ylimits = _compute_ylimits_for_featureset(df_fs, g.row_names)

            # blank out any text artists already on the axes (e.g. the
            # margin titles) before setting our own titles and labels
            for ax in g.axes.flat:
                plt.setp(ax.texts, text="")
            g = (g.set_titles(row_template='', col_template='{col_name}')
                 .set_axis_labels('Training Examples', 'Score'))
            if rotate_labels:
                g = g.set_xticklabels(rotation=60)

            # (hue level, matching std. dev. column, color) for each band
            bands = [('train_score_mean', 'train_score_std', train_color),
                     ('test_score_mean', 'test_score_std', test_color)]

            for i, row_name in enumerate(g.row_names):
                for j, col_name in enumerate(g.col_names):
                    ax = g.axes[i][j]
                    ax.set(ylim=ylimits[row_name])
                    df_ax = df_fs[(df_fs['learner_name'] == col_name) &
                                  (df_fs['metric'] == row_name)]

                    # shade +/- one standard deviation around each curve;
                    # the point plot puts the training sizes at integer
                    # positions 0..n-1, so sort by size first so that each
                    # band segment lines up with its point (assumes every
                    # facet has the same set of training sizes)
                    for variable, std_column, color in bands:
                        df_band = (df_ax[df_ax['variable'] == variable]
                                   .sort_values('training_set_size'))
                        ax.fill_between(np.arange(len(df_band)),
                                        df_band['value'] - df_band[std_column],
                                        df_band['value'] + df_band[std_column],
                                        alpha=0.1,
                                        color=color)

                    if j == 0:
                        # label each row of the grid with its metric
                        ax.set_ylabel(row_name)
                        if i == 0:
                            # set up the legend handles for this plot;
                            # only the top-left facet carries the legend
                            plot_handles = [matplotlib.lines.Line2D([],
                                                                    [],
                                                                    color=color,
                                                                    label=label,
                                                                    linestyle='-')
                                            for color, label in zip([train_color, test_color],
                                                                    ['Training', 'Cross-validation'])]
                            ax.legend(handles=plot_handles,
                                      loc=4,
                                      fancybox=True,
                                      fontsize='x-small',
                                      ncol=1,
                                      frameon=True)

            g.fig.tight_layout(w_pad=1)

            # save the grid's own figure explicitly instead of relying on
            # whichever figure pyplot currently considers active
            g.fig.savefig(join(output_dir, f'{experiment_name}_{fs_name}.png'),
                          dpi=300)

            # explicitly close the figure that catplot created to save
            # memory; otherwise one figure stays open per featureset
            plt.close(g.fig)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment