Last active
July 6, 2023 21:14
-
-
Save Hehehe421/e6a17cf9aa2a2b5d289a4dec9049c92f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 3. Report the learning curve
import matplotlib.pyplot as plt | |
import numpy as np | |
from sklearn.model_selection import learning_curve | |
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None, n_jobs=1,
                        train_sizes=np.linspace(.05, 1., 20), verbose=0,
                        plot=True, scoring=None):
    """Compute (and optionally plot) a learning curve for ``estimator``.

    Runs ``sklearn.model_selection.learning_curve`` on ``X``/``y``. When
    ``plot`` is true, it draws the mean training and cross-validation scores
    against the number of training examples, each with a band of +/- one
    standard deviation across CV splits.

    Parameters
    ----------
    estimator : object
        Estimator passed unchanged to ``learning_curve``.
    title : str
        Title of the plot.
    X, y : array-like
        Data and targets passed unchanged to ``learning_curve``.
    ylim : tuple of two floats, optional
        ``(bottom, top)`` y-axis limits. If None, the axis is autoscaled.
    cv, n_jobs, verbose : optional
        Forwarded unchanged to ``learning_curve``.
    train_sizes : array-like
        Training-set sizes forwarded to ``learning_curve``. The default is
        20 evenly spaced fractions from 0.05 to 1.0.
    plot : bool
        If True, draw the figure and call ``plt.show()``.
    scoring : optional
        Forwarded to ``learning_curve``. ``None`` (the default) keeps
        ``learning_curve``'s own default scoring, matching the behaviour of
        earlier versions of this function.

    Returns
    -------
    midpoint : float
        Midpoint between the upper edge of the training band
        (mean + std) and the lower edge of the cross-validation band
        (mean - std), taken at the largest training size.
    diff : float
        Training upper edge minus cross-validation lower edge at that size.
        NOTE(review): presumably intended as a train/validation gap
        indicator; confirm with the caller.
    """
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_sizes,
        verbose=verbose, scoring=scoring)

    # learning_curve returns one row per training size and one column per
    # CV split, so aggregate across axis 1.
    train_scores_mean = np.mean(train_scores, axis=1)
    train_scores_std = np.std(train_scores, axis=1)
    test_scores_mean = np.mean(test_scores, axis=1)
    test_scores_std = np.std(test_scores, axis=1)

    if plot:
        _, ax = plt.subplots()
        ax.set_title(title)
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.set_xlabel("Training examples")
        ax.set_ylabel("Score")
        ax.grid()
        # The previous version called invert_yaxis() twice; the two calls
        # cancelled out, so the axis keeps its normal orientation here too.

        # Shaded +/- one-std bands: training in blue, cross-validation in red.
        ax.fill_between(train_sizes, train_scores_mean - train_scores_std,
                        train_scores_mean + train_scores_std,
                        alpha=0.1, color="b")
        ax.fill_between(train_sizes, test_scores_mean - test_scores_std,
                        test_scores_mean + test_scores_std,
                        alpha=0.1, color="r")
        # Mean curves drawn as circle markers joined by lines ('o-').
        ax.plot(train_sizes, train_scores_mean, 'o-', color="b",
                label="Training score")
        ax.plot(train_sizes, test_scores_mean, 'o-', color="r",
                label="Cross-validation score")
        ax.legend(loc="best")
        plt.show()

    # Band edges at the largest training size (last entry).
    train_upper = train_scores_mean[-1] + train_scores_std[-1]
    test_lower = test_scores_mean[-1] - test_scores_std[-1]
    midpoint = (train_upper + test_lower) / 2
    diff = train_upper - test_lower
    return midpoint, diff
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment