@patientzero
Created October 10, 2019 14:48
Example for the usage of sklearn GroupKFold
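```python
from pathlib import Path

from sklearn.model_selection import GroupKFold

# define number of splits
n_splits = 10
# root directory of the images (placeholder path; point it at your dataset)
data_dir = Path('data')
# all images as a list of paths
pics = list(data_dir.glob('**/*.png'))
# matching labels: the name of each image's parent directory
labels = [pic.parent.stem for pic in pics]
# group IDs (second '_'-separated token of the file name); images sharing
# a group ID never end up in both the train and the test part of a split
groups = [pic.stem.split('_')[1] for pic in pics]
gkf = GroupKFold(n_splits=n_splits)
# Iterate over splits:
for train, test in gkf.split(pics, labels, groups=groups):
    # train & test contain indices into pics/labels
    # do training here
    pass
```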
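To check the grouping behaviour without a real image folder, the sketch below builds hypothetical file paths in memory with `PurePath`, assuming the layout `<label>/<prefix>_<group>_<suffix>.png` that the parsing above implies. It derives labels and groups the same way, runs `GroupKFold`, and asserts that no group ID shows up on both sides of any split. The folder names, subject IDs and `n_splits=5` are illustrative choices, not part of the original gist.

```python
from pathlib import PurePath

import numpy as np
from sklearn.model_selection import GroupKFold

# Hypothetical paths: 2 labels x 5 subjects (groups) x 2 images each
pics = [
    PurePath(f"{label}/img_{subject}_{i}.png")
    for label in ("cat", "dog")
    for subject in ("s01", "s02", "s03", "s04", "s05")
    for i in range(2)
]
# Same parsing as in the gist: label = parent folder, group = 2nd token
labels = np.array([pic.parent.stem for pic in pics])
groups = np.array([pic.stem.split("_")[1] for pic in pics])

gkf = GroupKFold(n_splits=5)
fold_test_groups = []
for fold, (train_idx, test_idx) in enumerate(
    gkf.split(np.arange(len(pics)), labels, groups=groups)
):
    train_groups = set(groups[train_idx])
    test_groups = set(groups[test_idx])
    # GroupKFold keeps every group entirely in train or entirely in test
    assert train_groups.isdisjoint(test_groups)
    fold_test_groups.append(test_groups)
    print(fold, sorted(test_groups), len(train_idx), len(test_idx))
```

With five equally sized groups and five splits, each fold holds out exactly one subject, so every group is used as test data once.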