Skip to content

Instantly share code, notes, and snippets.

@ClementC
Created June 8, 2021 20:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ClementC/1a33a936329b3142807e976ccc1e0e62 to your computer and use it in GitHub Desktop.
Small Python snippet to generate an illustration of the overfitting phenomenon
import numpy as np
import matplotlib.pyplot as plt
import warnings
# Seed the global RNG so the first run produces a nice-looking plot.
# Comment this line out to get a different random draw on every run.
np.random.seed(1)
# np.polyfit emits RankWarning for ill-conditioned fits; silence it so the
# figure output stays clean. NumPy 2.0 removed np.RankWarning from the top-level
# namespace (it lives in np.exceptions since NumPy 1.25), so look it up where
# it exists instead of hard-coding either location.
_RankWarning = getattr(np, "exceptions", np).RankWarning
warnings.simplefilter('ignore', _RankWarning)
# Under IPython/Jupyter, render figures inline. `%matplotlib inline` is not
# valid Python syntax, so call the magic explicitly; in a plain Python
# interpreter `get_ipython` is undefined and this step is deliberately skipped.
try:
    get_ipython().run_line_magic("matplotlib", "inline")
except NameError:
    pass
# Main parameters
x_noise = 0.2  # std-dev of the jitter applied to the training x positions
y_noise = 0.3  # std-dev of the observation noise added to every y value
x_min, x_max = -2, 2
n_test = 8  # number of testing points
# Generate the training and testing data with noise (the true model is y = x ** 2).
# Training x values are the integers x_min..x_max, each jittered by x_noise.
train_x = [elem + x_noise * np.random.normal() for elem in range(x_min, x_max + 1)]
# Testing x values are drawn uniformly from [x_min, x_max), the nominal span
# of the training grid, so both sides of the curve are sampled symmetrically.
test_x = (x_max - x_min) * np.random.rand(n_test) + x_min
# Training and testing targets share the same noise level on y.
train_y = [elem**2 + y_noise * np.random.normal() for elem in train_x]
test_y = [elem**2 + y_noise * np.random.normal() for elem in test_x]
# Compute a perfect polynomial fit: with len(train_x) distinct points, the
# polynomial of degree len(train_x) - 1 passes through every one of them.
# Any higher degree makes the least-squares system rank-deficient.
p = np.poly1d(np.polyfit(train_x, train_y, len(train_x) - 1))
# Plot everything: the noisy data sets, the true curve, and the two badly
# fitted models (interpolating polynomial and constant mean).
fig, ax = plt.subplots(figsize=(12, 8))

# Both data sets share the same marker look; only colour and label differ.
scatter_style = dict(edgecolor="k", s=80, alpha=0.5)
ax.scatter(train_x, train_y, color="blue", **scatter_style, label="Training set")
ax.scatter(test_x, test_y, color="green", **scatter_style, label="Testing set")

# Curves span one unit beyond [x_min, x_max] on each side.
x_lo, x_hi = x_min - 1, x_max + 1
x_ = np.linspace(x_lo, x_hi, 100)
ax.plot(x_, x_ ** 2, color="blue", alpha=0.5, label="Best fit")  # true model y = x ** 2
ax.plot(x_, p(x_), color="red", alpha=0.5, label="Overfitted model")
# A constant prediction equal to the mean training target.
ax.hlines(np.mean(train_y), x_lo, x_hi, color="orange", alpha=0.5, label="Underfitted model")

# Minimalist axes: "y" is placed as figure text near the top-left corner,
# the top/right spines and all ticks and tick labels are hidden.
plt.figtext(0.1, 0.9, '$y$')
ax.set_xlabel(r"$x$")
for side in ("right", "top"):
    ax.spines[side].set_visible(False)
ax.set_ylim(-1, x_hi ** 2 + 0.5)
ax.set_xlim(x_min - 1.5, x_max + 1.5)
ax.tick_params(left=False, right=False, labelleft=False,
               labelbottom=False, bottom=False)
ax.legend(loc=0, frameon=False);  # trailing ";" stops a notebook from echoing the Legend object
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment