Skip to content

Instantly share code, notes, and snippets.

@oliver-batey
Last active March 22, 2022 11:48
Show Gist options
  • Save oliver-batey/c1ce37d220d5d99f89fa726339cf5328 to your computer and use it in GitHub Desktop.
Save oliver-batey/c1ce37d220d5d99f89fa726339cf5328 to your computer and use it in GitHub Desktop.
Building a sampling distribution
from typing import List

import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    auc,
    precision_score,
    recall_score,
    roc_curve,
)
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.pipeline import Pipeline
def evaluate_model(
    data: pd.DataFrame,
    pipeline: Pipeline,
    features: List[str],
    target: str,
    test_size: float = 0.3,
):
    """Fit ``pipeline`` on one random stratified split and score the held-out rows.

    Each call draws a fresh train/test split (no random seed is set), so
    calling this function repeatedly yields a sampling distribution of the
    returned metrics.

    Args:
        data: Table holding both the feature columns and the target column.
        pipeline: Estimator pipeline; it is re-fitted in place on every call
            and must provide ``fit``, ``predict`` and ``predict_proba``.
        features: Names of the columns in ``data`` used as model inputs.
        target: Name of the column in ``data`` holding the class labels.
        test_size: Passed to ``StratifiedShuffleSplit`` as the size of the
            held-out test set.

    Returns:
        A tuple ``(recall, precision, accuracy, auroc)`` computed on the
        test set.
    """
    X, y = data[features], data[target]

    # Train-test split with no random seed, so a different split is
    # generated each time the function is called.
    split = StratifiedShuffleSplit(n_splits=1, test_size=test_size)

    # n_splits=1, so the generator yields exactly one (train, test) pair.
    # The indices are positional, hence .iloc rather than label-based .loc:
    # .loc only happens to work when data has a default 0..n-1 index.
    train_index, test_index = next(split.split(X, y))
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]

    # Train model; fit overwrites any previous training on each call.
    pipeline.fit(X_train, y_train)

    # Evaluate hard predictions on the test set.
    test_predictions = pipeline.predict(X_test)
    recall = recall_score(y_test, test_predictions)
    precision = precision_score(y_test, test_predictions)
    acc = accuracy_score(y_test, test_predictions)

    # Calculate fpr and tpr for each threshold to get an AUROC.
    # Column 1 of predict_proba is used as the positive-class score
    # (assumes a binary target whose positive label is 1 -- confirm).
    test_probabilities = pipeline.predict_proba(X_test)[:, 1]
    fpr, tpr, _ = roc_curve(
        y_test, test_probabilities, pos_label=1, drop_intermediate=True
    )
    auroc = auc(fpr, tpr)

    return recall, precision, acc, auroc
# Call evaluate_model in a loop and record results. Every call draws a new
# random train/test split, so each array ends up holding n_splits samples
# of the corresponding metric.
n_splits = 100
Recall = np.zeros(n_splits)
Precision = np.zeros(n_splits)
Accuracy = np.zeros(n_splits)
AUC = np.zeros(n_splits)
for j in range(n_splits):
    # Unpack in the order evaluate_model returns its values:
    # (recall, precision, accuracy, auroc).
    recall, precision, acc, auroc = evaluate_model(data, pipeline, features, target)
    Recall[j] = recall
    Precision[j] = precision
    Accuracy[j] = acc
    AUC[j] = auroc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment