@shinichi-takayanagi
Last active November 12, 2021 22:27
Use a Calibrated Model with scikit-learn
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.datasets import make_classification
# See: https://scikit-learn.org/stable/modules/generated/sklearn.calibration.CalibratedClassifierCV.html
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
# Dummy data (NumPy arrays): a balanced binary classification problem
X, y = make_classification(n_samples=1000, n_classes=2, weights=[0.5, 0.5], random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=2)
# Make calibrated model (sigmoid = Platt scaling, fitted with 5-fold cross-validation)
model = LogisticRegression(random_state=71)
calibrated = CalibratedClassifierCV(model, method='sigmoid', cv=5)
calibrated.fit(X_train, y_train)
# predict probabilities
probs = calibrated.predict_proba(X_test)[:, 1]
# Plot a reliability diagram to check the calibration
from matplotlib import pyplot
# reliability diagram: fraction of positives (fop) and mean predicted probability (mpv) per bin
fop, mpv = calibration_curve(y_test, probs, n_bins=10)
# plot diagonal line
pyplot.plot([0, 1], [0, 1], linestyle='--')
# plot calibrated reliability
pyplot.plot(mpv, fop, marker='.')
pyplot.show()
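The reliability diagram above is a visual check only. As a rough numeric complement, the sketch below compares the Brier score (mean squared error of the predicted probabilities, lower is better) of a plain LogisticRegression against the calibrated wrapper. The comparison, the `baseline` model, and the use of `brier_score_loss` are not part of the original gist; they are assumptions added for illustration, and the data here uses the default class balance of `make_classification`.

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import brier_score_loss
from sklearn.model_selection import train_test_split

# Balanced synthetic binary problem (default class weights), split 50/50.
X, y = make_classification(n_samples=1000, n_classes=2, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=2)

# Hypothetical baseline: the same model without the calibration wrapper.
baseline = LogisticRegression(random_state=71).fit(X_train, y_train)
baseline_probs = baseline.predict_proba(X_test)[:, 1]

# Calibrated model, configured as in the gist.
calibrated = CalibratedClassifierCV(LogisticRegression(random_state=71), method='sigmoid', cv=5)
calibrated.fit(X_train, y_train)
calibrated_probs = calibrated.predict_proba(X_test)[:, 1]

# Brier score: mean of (predicted probability - true label)^2 over the test set.
brier_baseline = brier_score_loss(y_test, baseline_probs)
brier_calibrated = brier_score_loss(y_test, calibrated_probs)
print(f"Brier score, uncalibrated: {brier_baseline:.4f}")
print(f"Brier score, calibrated:   {brier_calibrated:.4f}")
```

Logistic regression is often reasonably well calibrated already, so the two scores may be close; the gap is usually larger for models such as SVMs or boosted trees.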