Created February 12, 2024 at 19:30.
-
-
Save mocquin/cc06e0753bf74aff2cd177e2b751b8a2 to your computer and use it in GitHub Desktop.
regression2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.datasets import make_regression
from sklearn.dummy import DummyRegressor

# This was originally the IPython magic `%matplotlib qt`, which is a
# SyntaxError in a plain .py file. Apply it only when running under
# IPython/Jupyter (where `get_ipython` exists as a builtin); a plain Python
# interpreter keeps matplotlib's default backend.
try:
    get_ipython().run_line_magic("matplotlib", "qt")  # noqa: F821
except NameError:
    pass

# Synthetic single-feature regression dataset (fixed seed for reproducibility).
X, y = make_regression(n_samples=100, n_features=1, noise=10, random_state=42)

# Parameter passed to DummyRegressor; only used by the 'constant' strategy.
CONSTANT = 30
# Parameter passed to DummyRegressor; only used by the 'quantile' strategy.
QUANTILE = 0.2
# DummyRegressor strategies compared by plot_dummy_strategy_regressor_1d.
STRATEGIES = ['mean', 'median', 'quantile', 'constant']
def plot_dummy_strategy_regressor_1d(
    X, y, target_name="target",
):
    """Plot ground-truth targets alongside each DummyRegressor strategy's predictions.

    For every strategy in the module-level ``STRATEGIES`` list, a
    ``DummyRegressor`` is built with the module-level ``CONSTANT`` and
    ``QUANTILE`` values and fitted on ``(X, y)``. It then predicts on ``X``.
    The ground truth and all predictions are stacked into one long-format
    DataFrame and drawn with ``sns.relplot`` as a scatter plot coloured by
    strategy. Each strategy label carries the model's R^2 score on the
    fitting data.

    Parameters
    ----------
    X : array-like
        Inputs; assumed to hold a single feature, e.g. shape (n_samples, 1)
        as produced by ``make_regression(n_features=1)`` -- TODO confirm for
        other callers.
    y : array-like
        Targets, one per sample.
    target_name : str, default "target"
        Column name used for the target / prediction values.

    Returns
    -------
    fg
        The object returned by ``sns.relplot``.
    df : pandas.DataFrame
        Long-format data with columns ``target_name``, ``"strategy"`` and
        ``"x"`` and a fresh 0..n-1 index.
    """
    # Flatten the single input feature once so it fits in one 1-D column.
    x_values = np.ravel(X)
    dfs = []

    # Ground-truth points.
    df_true = pd.DataFrame({target_name: np.ravel(y)})
    df_true['strategy'] = "ground_true"
    df_true['x'] = x_values
    dfs.append(df_true)

    # One set of predictions per dummy strategy.
    for strategy in STRATEGIES:
        # `constant` is only used by the 'constant' strategy and `quantile`
        # only by the 'quantile' strategy; other strategies ignore them.
        dummy = DummyRegressor(strategy=strategy, constant=CONSTANT, quantile=QUANTILE)
        dummy.fit(X, y)
        preds = dummy.predict(X)
        df = pd.DataFrame({target_name: np.ravel(preds)})

        label = strategy
        if strategy == "constant":
            label = f"constant(={CONSTANT})"
        elif strategy == "quantile":
            label = f"quantile(={QUANTILE})"
        # The original scored on undefined `X_test`/`y_test` (NameError).
        # There is no held-out split here, so score on the fitting data.
        label += f"($R^2$={dummy.score(X, y):.2f})"
        df['strategy'] = label
        df['x'] = x_values
        dfs.append(df)

    # ignore_index gives the stacked frame unique row labels instead of
    # repeating 0..n-1 once per strategy.
    df = pd.concat(dfs, ignore_index=True)
    # Plot every point, coloured by strategy.
    fg = sns.relplot(data=df, kind='scatter', x='x', y=target_name, hue='strategy')
    return fg, df
# Build the comparison plot for the synthetic dataset defined above; keep the
# returned grid and long-format DataFrame at module level for inspection.
fg_1d, df_regressor_1d = plot_dummy_strategy_regressor_1d(X, y, target_name="target")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.