Skip to content

Instantly share code, notes, and snippets.

@adrialuzllompart
Last active October 24, 2020 17:47
Show Gist options
  • Save adrialuzllompart/c916c4ce3782a98ab5c92fe82ce0d293 to your computer and use it in GitHub Desktop.
Save adrialuzllompart/c916c4ce3782a98ab5c92fe82ce0d293 to your computer and use it in GitHub Desktop.
def plot_learning_curves(estimator, X_train, y_train, X_val, y_val,
suptitle='', title='', xlabel='', ylabel=''):
"""
Plots learning curves for a given estimator.
Parameters
----------
estimator : sklearn estimator
X_train : pd.DataFrame
training set (features)
y_train : pd.Series
training set (response)
X_val : pd.DataFrame
validation set (features)
y_val : pd.Series
validation set (response)
suptitle : str
Chart suptitle
title: str
Chart title
xlabel: str
Label for the X axis
ylabel: str
Label for the y axis
Returns
-------
Plot of learning curves
"""
# create lists to store train and validation scores
train_score = []
val_score = []
# create ten incremental training set sizes
training_set_sizes = np.linspace(5, len(X_train), 10, dtype='int')
# for each one of those training set sizes
for i in training_set_sizes:
# fit the model only using that many training examples
estimator.fit(X_train.iloc[0:i, :], y_train.iloc[0:i])
# calculate the training accuracy only using those training examples
train_accuracy = metrics.accuracy_score(
y_train.iloc[0:i],
estimator.predict(X_train.iloc[0:i, :])
)
# calculate the validation accuracy using the whole validation set
val_accuracy = metrics.accuracy_score(
y_val,
estimator.predict(X_val)
)
# store the scores in their respective lists
train_score.append(train_accuracy)
val_score.append(val_accuracy)
# plot learning curves
fig, ax = plt.subplots(figsize=(14, 9))
ax.plot(training_set_sizes, train_score, c='gold')
ax.plot(training_set_sizes, val_score, c='steelblue')
# format the chart to make it look nice
fig.suptitle(suptitle, fontweight='bold', fontsize='20')
ax.set_title(title, size=20)
ax.set_xlabel(xlabel, size=16)
ax.set_ylabel(ylabel, size=16)
ax.legend(['training set', 'validation set'], fontsize=16)
ax.tick_params(axis='both', labelsize=12)
ax.set_ylim(0, 1)
def percentages(x, pos):
"""The two args are the value and tick position"""
if x < 1:
return '{:1.0f}'.format(x*100)
return '{:1.0f}%'.format(x*100)
def numbers(x, pos):
"""The two args are the value and tick position"""
if x >= 1000:
return '{:1,.0f}'.format(x)
return '{:1.0f}'.format(x)
y_formatter = FuncFormatter(percentages)
ax.yaxis.set_major_formatter(y_formatter)
x_formatter = FuncFormatter(numbers)
ax.xaxis.set_major_formatter(x_formatter)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment