Created
August 30, 2022 15:31
-
-
Save andrea-dagostino/d6a5ffd62825bb6da5861d237709df25 to your computer and use it in GitHub Desktop.
cross val eng
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from sklearn import datasets
from sklearn import metrics
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
# K-fold cross-validation of a random forest on a synthetic dataset.
# For every fold the script reports accuracy and ROC AUC on the held-out
# part, then prints the mean AUC across all folds.

# Create a synthetic binary-classification dataset (seeded, so repeatable).
X, y = datasets.make_classification(
    n_samples=2000, n_features=20, n_classes=2, random_state=42
)

# Pick the number of folds with Sturges' rule: k = 1 + log2(n).
# The base-2 logarithm is what the rule prescribes; a natural log would
# yield fewer folds than intended.
sturges = int(1 + np.log2(len(X)))
kf = KFold(n_splits=sturges, shuffle=True, random_state=42)

aucs = []

# CROSS-VALIDATION LOOP
# kf.split yields one (train indices, validation indices) pair per fold.
for fold, (train_idx, val_idx) in enumerate(kf.split(X, y)):
    # Slice the feature matrix and the labels with the fold's indices.
    X_tr = X[train_idx]
    y_tr = y[train_idx]

    X_val = X[val_idx]
    y_val = y[val_idx]

    # ----
    # apply any fold-specific logic here (e.g. preprocessing fitted on
    # the training part only)
    # ...
    # ----

    # Train the model; seeded like the dataset and the splitter so that
    # repeated runs print the same scores.
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_tr, y_tr)

    # Hard labels feed the accuracy; positive-class probabilities (column 1)
    # feed the ROC AUC.
    pred = clf.predict(X_val)
    pred_prob = clf.predict_proba(X_val)[:, 1]
    acc_score = metrics.accuracy_score(y_val, pred)
    auc_score = metrics.roc_auc_score(y_val, pred_prob)

    print(f"======= Fold {fold} ========")
    print(
        f"Accuracy on the validation set is {acc_score:0.4f} and AUC is {auc_score:0.4f}"
    )

    # Keep the fold's AUC for the final average.
    aucs.append(auc_score)

# Mean of the per-fold validation AUCs.
general_auc_score = np.mean(aucs)
print(f'\nOur out of fold AUC score is {general_auc_score:0.4f}')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment