Skip to content

Instantly share code, notes, and snippets.

@mocquin
Created February 12, 2024 18:32
Show Gist options
  • Save mocquin/773ee5f97b6f925bb8b7474c29c18081 to your computer and use it in GitHub Desktop.
Save mocquin/773ee5f97b6f925bb8b7474c29c18081 to your computer and use it in GitHub Desktop.
classif1.py
%matplotlib qt
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import train_test_split
# Load the penguins example dataset through seaborn
penguins_df = sns.load_dataset("penguins")

# Features (X): the four numeric measurements; target (y): the species label
feature_columns = [
    'bill_length_mm',
    'bill_depth_mm',
    'flipper_length_mm',
    'body_mass_g',
]
X = penguins_df.loc[:, feature_columns]
y = penguins_df.loc[:, 'species']

# Quick look at the relative size of each class
class_shares = y.value_counts(normalize=True)
print(class_shares)
def plot_dummy_strategy_classifier(
    X, y, target_name="species",
    strategies=("most_frequent", "prior", "stratified", "uniform", "constant"),
    constant="Chinstrap",
    random_state=None,
):
    """Fit one ``DummyClassifier`` per strategy and collect its test-set output.

    Despite its name, this function draws nothing: it returns two tidy
    DataFrames that the caller can hand to a plotting library.

    Parameters
    ----------
    X : array-like
        Feature matrix, split into train and test parts.
    y : array-like
        Target labels; also used to stratify the train/test split.
    target_name : str, default "species"
        Name of the label column in the returned predictions frame
        (used for the ground truth and for every strategy).
    strategies : iterable of str
        ``DummyClassifier`` strategy names to evaluate.
    constant : label, default "Chinstrap"
        Value forwarded to ``DummyClassifier(constant=...)``.
    random_state : int or None, default None
        Seed forwarded to both the train/test split and the classifiers,
        so results can be made reproducible.

    Returns
    -------
    (df, df_proba) : tuple of pandas.DataFrame
        ``df`` has columns ``[target_name, "strategy"]`` and stacks the test
        labels (strategy ``"ground_truth"``) followed by the predictions of
        each strategy. ``df_proba`` has one column per fitted class plus
        ``"strategy"`` and stacks the ``predict_proba`` output of each
        strategy.

    Raises
    ------
    ValueError
        If ``strategies`` is empty.
    """
    strategies = list(strategies)
    if not strategies:
        raise ValueError("strategies must contain at least one strategy name")

    # stratify=y keeps the class distribution the same in the train set and
    # in the test set
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=random_state
    )

    # Ground truth goes first, under the same column name as the predictions,
    # so the concatenation below yields a single label column.
    truth = pd.Series(y_test, name=target_name).reset_index(drop=True).to_frame()
    truth["strategy"] = "ground_truth"

    prediction_frames = [truth]
    proba_frames = []
    for strategy in strategies:
        dummy = DummyClassifier(
            strategy=strategy, constant=constant, random_state=random_state
        )
        dummy.fit(X_train, y_train)

        predicted = pd.DataFrame({target_name: dummy.predict(X_test)})
        predicted["strategy"] = strategy
        prediction_frames.append(predicted)

        # Label the probability columns with the classes the estimator
        # actually learned rather than a hard-coded list.
        proba = pd.DataFrame(dummy.predict_proba(X_test), columns=dummy.classes_)
        proba["strategy"] = strategy
        proba_frames.append(proba)

    # ignore_index avoids duplicate index labels in the stacked frames.
    df = pd.concat(prediction_frames, ignore_index=True)
    df_proba = pd.concat(proba_frames, ignore_index=True)
    return df, df_proba
# Fit every dummy strategy once; the two frames feed all figures below.
df, df_proba = plot_dummy_strategy_classifier(X, y)

# Fixed class order so bar positions and colours are reproducible
species_order = ["Adelie", "Gentoo", "Chinstrap"]

# Figure 1: counts of predicted labels, one panel per strategy (plus ground truth)
sns.catplot(
    data=df, kind='count', x="species", col='strategy', hue="species",
    hue_order=species_order, order=species_order,
)

# Figure 2: heatmap of predicted probabilities, one panel per strategy
strategy_names = df_proba['strategy'].unique()
# squeeze=False keeps `axes` a 2-D array even when there is a single strategy
fig, axes = plt.subplots(1, len(strategy_names), squeeze=False)
# [left, bottom, width, height] in figure coordinates; adjust to move the colorbar
cbar_ax = fig.add_axes([0.91, 0.15, 0.02, 0.7])
for ax, strat in zip(axes.ravel(), strategy_names):
    sns.heatmap(
        df_proba.loc[df_proba['strategy'] == strat, species_order],
        ax=ax, cbar=False, vmin=0, vmax=1, yticklabels=False,
    )
    ax.set_title(f'strategy = {strat}')
# Every heatmap uses the same vmin/vmax, so one mappable serves the shared colorbar
fig.colorbar(ax.collections[0], cax=cbar_ax)

# Figure 3: distribution of predicted probabilities per class and strategy.
# df_proba is wide (one column per class); melt it so "species" exists as a
# column that can be mapped to hue.
proba_long = df_proba.melt(
    id_vars="strategy", var_name="species", value_name="probability",
)
sns.displot(
    proba_long, x="probability", col="strategy",
    hue="species", hue_order=species_order,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment