Skip to content

Instantly share code, notes, and snippets.

@ogrisel
Last active June 24, 2020 14:30
Show Gist options
  • Save ogrisel/bfadf2dffef144d401b6e7c52d744219 to your computer and use it in GitHub Desktop.
Save ogrisel/bfadf2dffef144d401b6e7c52d744219 to your computer and use it in GitHub Desktop.
import numpy as np
import pytest
from sklearn.datasets import load_breast_cancer
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
# Load the breast cancer data as plain (features, target) arrays.
X, y = load_breast_cancer(return_X_y=True)

# Create a highly imbalanced dataset: keep every sample whose target is 0
# but only the first 25 samples whose target is 1.
negative_rows = np.flatnonzero(y == 0)
positive_rows = np.flatnonzero(y == 1)
kept_rows = np.concatenate([negative_rows, positive_rows[:25]])
X = X[kept_rows]
y = y[kept_rows]
# Shuffle with a fixed seed so the two classes are no longer stored in
# contiguous runs.
X, y = shuffle(X, y, random_state=42)

# Only use 2 features to make the problem even harder.
X = X[:, :2]

# Swap the integer targets for string labels held in an object array: the
# subsampled class (target 1) becomes "cancer", everything else becomes
# "not cancer".
# NOTE(review): the scikit-learn documentation appears to list target 1 as
# the benign class, so these label names look illustrative rather than
# medically accurate -- confirm before reusing them elsewhere.
labels = np.empty(y.shape, dtype=object)
labels[y == 1] = "cancer"
labels[y != 1] = "not cancer"
y = labels

# Stratified split with a fixed seed so both parts keep the class ratio.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    stratify=y,
    random_state=0,
)
# Baseline model fitted on the training split; it is only consumed by the
# disabled sanity check below.
classifier = LogisticRegression().fit(X_train, y_train)

# The sanity check below is intentionally left disabled. It verifies that
# the positive class is classes_[0] and that we are betrayed by the class
# imbalance, then compares roc_auc_score against the area under roc_curve.
# NOTE(review): roc_auc_score does not seem to accept a `pos_label`
# argument in released scikit-learn versions -- confirm before re-enabling.
#
# assert classifier.classes_.tolist() == ["cancer", "not cancer"]
# pos_label = "cancer"
# pos_idx = classifier.classes_.tolist().index(pos_label)
# y_pred = classifier.predict_proba(X_test)
# # y_pred = classifier.decision_function(X_test)
# if y_pred.ndim == 2:
#     # predict_proba
#     y_pred = y_pred[:, pos_idx]
# else:
#     # decision_function
#     if pos_idx == 0:
#         y_pred *= -1
# fpr, tpr, _ = roc_curve(y_test, y_pred, pos_label=pos_label)
# roc_auc = roc_auc_score(y_test, y_pred, pos_label=pos_label)
# assert roc_auc == pytest.approx(np.trapz(tpr, fpr))
# Grid-search the C parameter of a logistic regression with 5-fold
# cross-validation, ranking candidates with the "roc_auc" scorer.
param_grid = {"C": [1e-3, 1, 1e3]}
gs = GridSearchCV(
    estimator=LogisticRegression(),
    param_grid=param_grid,
    scoring="roc_auc",
    cv=5,
)
gs.fit(X_train, y_train)

# The best cross-validated score is expected to stay above 0.9 even though
# the string-labelled positive class is rare.
best_cv_score = gs.best_score_
assert best_cv_score > 0.9
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment