@omartinez182
Created September 10, 2020 19:39
Snippet #4 - Nested Cross-Validation Article
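```python
from sklearn.model_selection import GridSearchCV, StratifiedKFold, cross_val_score

# Assumes X, y, rf, rf_param_grid, rounds, outer_scores and nested_scores
# are defined in the earlier snippets of the article.

# Loop for each round
for i in range(rounds):
    # Define both cross-validation objects (inner & outer)
    inner_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=i)
    outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=i)

    # Non-nested parameter search and scoring: the same folds are used to
    # select hyperparameters and to score them, so this is optimistically biased
    clf = GridSearchCV(estimator=rf, param_grid=rf_param_grid, cv=inner_cv)
    clf.fit(X, y)
    outer_scores[i] = clf.best_score_

    # Nested CV with parameter optimization: the grid search is refit inside
    # each outer training fold and scored on the held-out outer fold
    nested_score = cross_val_score(clf, X=X, y=y, cv=outer_cv)
    nested_scores[i] = nested_score.mean()
```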
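The snippet delegates the outer loop to `cross_val_score`, which hides where the tuning actually happens. As a sketch, assuming the iris dataset and a small, hypothetical random forest grid stand in for the `X`, `y`, `rf` and `rf_param_grid` defined earlier in the article, the same nested procedure for a single round can be written with an explicit outer loop, next to the non-nested score for comparison:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold

# Hypothetical stand-ins for the data, model and grid from earlier snippets.
X, y = load_iris(return_X_y=True)
rf = RandomForestClassifier(n_estimators=20, random_state=0)
rf_param_grid = {"max_depth": [2, None], "min_samples_leaf": [1, 5]}


def nested_cv_score(X, y, seed):
    """Mean outer-fold accuracy of a grid search tuned only on outer training folds."""
    inner_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    fold_scores = []
    for train_idx, test_idx in outer_cv.split(X, y):
        # The hyperparameter search only ever sees the outer training fold...
        search = GridSearchCV(estimator=rf, param_grid=rf_param_grid, cv=inner_cv)
        search.fit(X[train_idx], y[train_idx])
        # ...and the refit best model is scored on the untouched outer test fold.
        fold_scores.append(search.score(X[test_idx], y[test_idx]))
    return float(np.mean(fold_scores))


def non_nested_cv_score(X, y, seed):
    """best_score_ of a grid search that tunes and evaluates on the same folds."""
    inner_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    search = GridSearchCV(estimator=rf, param_grid=rf_param_grid, cv=inner_cv)
    search.fit(X, y)
    return float(search.best_score_)


nested = nested_cv_score(X, y, seed=0)
non_nested = non_nested_cv_score(X, y, seed=0)
print(f"nested: {nested:.3f}  non-nested: {non_nested:.3f}")
```

Because `GridSearchCV` refits the best configuration on each outer training fold (`refit=True` by default), every outer test fold is scored by a model whose hyperparameters were chosen without seeing it, which is what `cross_val_score(clf, X=X, y=y, cv=outer_cv)` does internally. The non-nested `best_score_` reuses the same folds for selection and evaluation, so it tends to be optimistic; for a single seed on a small dataset the gap can be small or even reversed, which is why the article repeats the comparison over several rounds.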