# GitHub Gist by @mocquin, created February 12, 2024 19:17
# Gist: mocquin/905f8b9eae15f6c000ccca2a8e7a242c
# https://gist.github.com/mocquin/905f8b9eae15f6c000ccca2a8e7a242c
# File: regression1.py
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.dummy import DummyRegressor
from sklearn.model_selection import train_test_split

# Plain-Python equivalent of the IPython magic ``%matplotlib qt``: a bare
# ``%`` line is a SyntaxError when this .py file is run by a regular
# interpreter. Outside IPython ``get_ipython`` is undefined, so we simply keep
# Matplotlib's default backend.
try:
    get_ipython().run_line_magic("matplotlib", "qt")  # noqa: F821
except NameError:
    pass

# We first create a toy dataset, with 100 samples and a single feature
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=0)

# Split the dataset into train/test sets. Seeding the split (like the dataset
# above) makes every score and figure below reproducible from run to run.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Create a dummy regressor with the mean strategy
dummy = DummyRegressor(strategy="mean")
# Fit the dummy model with the training set: although we MUST pass both
# X_train and y_train, the model only uses the values in y_train (their mean).
dummy.fit(X_train, y_train)

# Compute prediction: as it is a DummyRegressor with "mean", all the values
# in the y_pred vector will be equal to the mean of y_train
y_pred = dummy.predict(X_test)
# R^2 on the test set. For a constant prediction it is at most 0: no constant
# beats the mean of y_test, and the constant used here is the mean of y_train.
score = dummy.score(X_test, y_test)

# Let's visualize how the model behaves
fig, ax = plt.subplots()
ax.scatter(X_test, y_test, label="y_test")
ax.scatter(X_test, y_pred, label="y_pred")
# Keep the model name as plain text and use mathtext only for R^2; wrapping
# plain text in mathtext's \text{} is not supported by every Matplotlib version.
ax.set_title(f"DummyRegressor(strategy='mean'): $R^2$={score:.2f}")
ax.legend()
fig.tight_layout()
# Say we want to evaluate our linear regression model and compare it to the
# dummy baseline: a useful model should reach a clearly higher R^2 on the
# same test set.
from sklearn.linear_model import LinearRegression

model = LinearRegression()
model.fit(X_train, y_train)

fig, ax = plt.subplots()
ax.scatter(X_test, y_test, label="Ground truth")
# Reuse the baseline predictions and score computed earlier
# (y_pred == dummy.predict(X_test), score == dummy.score(X_test, y_test)).
ax.scatter(X_test, y_pred, label=f"DummyRegressor: $R^2$={score:.2f}")
ax.scatter(
    X_test,
    model.predict(X_test),
    label=f"LinearRegression: $R^2$={model.score(X_test, y_test):.2f}",
)
ax.set_title("Comparison between our model and a dummy model")
ax.legend()
fig.tight_layout()
# Display the figures when run as a plain script (blocks until the windows are
# closed). In interactive mode, e.g. after %matplotlib qt in IPython, it
# returns immediately.
plt.show()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.