Skip to content

Instantly share code, notes, and snippets.

@StephenFordham
Created October 23, 2023 09:51
Show Gist options
  • Star 0 — you must be signed in to star a gist.
  • Fork 0 — you must be signed in to fork a gist.
  • Save StephenFordham/732ffee7f8feab82ef7a8d23504cf0ce to your computer and use it in GitHub Desktop.
Precision-recall curve for a skillful model
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_curve, auc
import seaborn as sns
# Fit a logistic-regression classifier on a synthetic binary dataset, report
# the area under its precision-recall curve, and plot that curve against the
# no-skill baseline.
X, y = make_classification(n_samples=1000, n_classes=2, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=2)

model = LogisticRegression(solver='lbfgs')
model.fit(X_train, y_train)

# Scores for the positive class: column 1 of predict_proba.
lr_proba = model.predict_proba(X_test)[:, 1]

# The third return value (decision thresholds) is not needed for the plot.
lr_precision, lr_recall, _ = precision_recall_curve(y_test, lr_proba)
# Trapezoidal area under the PR curve, with recall on the x-axis.
lr_auc = auc(lr_recall, lr_precision)
# Print explicitly: a bare `round(lr_auc, 3)` expression is only echoed in a
# notebook and is silently discarded when this file runs as a script.
print(f"AUC = {lr_auc:.3f}")
# Console output:
# AUC = 0.927

# Fraction of positive labels in the test set; drawn below as a horizontal
# "no skill" reference line.
no_skill = (y_test == 1).mean()

# Plotting model comparison
fig, ax = plt.subplots(figsize=(10, 6))
# sort=False keeps the point order returned by precision_recall_curve. The
# default sort=True reorders points that share a recall value and distorts
# the curve's sawtooth shape.
sns.lineplot(x=lr_recall, y=lr_precision, estimator=None, sort=False,
             color='#31D19A', linewidth=2.0,
             label='Skillful model', ax=ax)
# Give y as two explicit values so the baseline spans recall 0..1 without
# relying on a scalar being broadcast against the two x values.
sns.lineplot(x=[0, 1], y=[no_skill, no_skill], estimator=None,
             color='#EE472A', linewidth=2.0, linestyle="--",
             label='No Skill', ax=ax)
# NOTE: matplotlib falls back to its default font (with a warning) if the
# 'Ubuntu' font family is not installed.
ax.set_ylabel('Precision', fontname='Ubuntu', fontsize=12)
ax.set_xlabel('Recall', fontname='Ubuntu', fontsize=12)
ax.set_title('Precision-recall Curve for a skillful model', size=14, pad=30, fontname='Ubuntu', weight='bold')
sns.despine(top=True, right=True)
# Required for the figure to appear when run as a script; harmless in notebooks.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment