Last active
October 24, 2020 17:47
-
-
Save adrialuzllompart/c916c4ce3782a98ab5c92fe82ce0d293 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot_learning_curves(estimator, X_train, y_train, X_val, y_val,
                         suptitle='', title='', xlabel='', ylabel='',
                         scoring=None, n_sizes=10, min_size=5):
    """
    Plots learning curves for a given estimator.

    The estimator is refit on growing prefixes of the training set (the
    first ``size`` rows, in their current order). For each prefix size it
    is scored on that same prefix (training score) and on the whole
    validation set (validation score).

    Parameters
    ----------
    estimator : sklearn estimator
        Object with ``fit`` and ``predict`` methods. It is refit in place
        for every training set size, so after the call it is left fitted on
        the largest size (the whole training set).
    X_train : pd.DataFrame
        training set (features)
    y_train : pd.Series
        training set (response)
    X_val : pd.DataFrame
        validation set (features)
    y_val : pd.Series
        validation set (response)
    suptitle : str
        Chart suptitle
    title : str
        Chart title
    xlabel : str
        Label for the X axis
    ylabel : str
        Label for the y axis
    scoring : callable, optional
        ``scoring(y_true, y_pred) -> float``. Defaults to
        ``sklearn.metrics.accuracy_score``. The y axis is fixed to [0, 1]
        and labelled as a percentage, so the metric should return values
        in that range.
    n_sizes : int, default 10
        Number of evenly spaced training set sizes to evaluate (fewer if
        some of them coincide after rounding to integers).
    min_size : int, default 5
        Smallest training set size; capped at ``len(X_train)`` and
        raised to at least 1.

    Returns
    -------
    None
        The learning curves are drawn on a new matplotlib figure.

    Raises
    ------
    ValueError
        If ``X_train`` has no rows.
    """
    if scoring is None:
        scoring = metrics.accuracy_score

    n_train = len(X_train)
    if n_train == 0:
        raise ValueError('X_train must contain at least one row')

    # Evenly spaced, strictly increasing training set sizes. Capping the
    # start at n_train keeps the sequence non-decreasing for tiny training
    # sets, and np.unique drops sizes that collapse to the same integer.
    start = max(1, min(min_size, n_train))
    training_set_sizes = np.unique(
        np.linspace(start, n_train, n_sizes, dtype='int')
    )

    # lists to store train and validation scores, one entry per size
    train_score = []
    val_score = []
    for size in training_set_sizes:
        X_subset = X_train.iloc[:size, :]
        y_subset = y_train.iloc[:size]
        # fit the model only using that many training examples
        estimator.fit(X_subset, y_subset)
        # training score only on the examples the model was fit on
        train_score.append(scoring(y_subset, estimator.predict(X_subset)))
        # validation score always uses the whole validation set
        val_score.append(scoring(y_val, estimator.predict(X_val)))

    # plot learning curves
    fig, ax = plt.subplots(figsize=(14, 9))
    ax.plot(training_set_sizes, train_score, c='gold',
            label='training set')
    ax.plot(training_set_sizes, val_score, c='steelblue',
            label='validation set')

    # format the chart to make it look nice
    fig.suptitle(suptitle, fontweight='bold', fontsize=20)
    ax.set_title(title, size=20)
    ax.set_xlabel(xlabel, size=16)
    ax.set_ylabel(ylabel, size=16)
    ax.legend(fontsize=16)
    ax.tick_params(axis='both', labelsize=12)
    ax.set_ylim(0, 1)

    def percentages(x, pos):
        """Format a y tick value in [0, 1] as a whole-number percentage.

        Only ticks at or above 1 (i.e. 100%) get a '%' sign; this looks
        like a deliberate choice to show the unit once, on the top tick.
        ``pos`` (tick position) is required by FuncFormatter but unused.
        """
        if x < 1:
            return '{:1.0f}'.format(x * 100)
        return '{:1.0f}%'.format(x * 100)

    def numbers(x, pos):
        """Format an x tick value with no decimals.

        Values of 1000 or more get thousands separators.
        ``pos`` (tick position) is required by FuncFormatter but unused.
        """
        if x >= 1000:
            return '{:1,.0f}'.format(x)
        return '{:1.0f}'.format(x)

    ax.yaxis.set_major_formatter(FuncFormatter(percentages))
    ax.xaxis.set_major_formatter(FuncFormatter(numbers))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment