Created October 10, 2019 14:48
Gist: patientzero/51d4e01902a3c68ee3bfac1fb6b0ecbd
Example usage of sklearn GroupKFold
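A minimal sketch of the guarantee GroupKFold provides, on toy in-memory arrays. The arrays and group names here are made up for illustration (think of four patients with two images each). Every sample lands in exactly one test fold, and no group ever appears on both the train and the test side of a split.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# toy data (hypothetical): 8 samples belonging to 4 groups
X = np.arange(16).reshape(8, 2)
y = np.array([0, 1, 0, 1, 0, 1, 0, 1])
groups = np.array(["a", "a", "b", "b", "c", "c", "d", "d"])

gkf = GroupKFold(n_splits=2)
folds = []
for train_idx, test_idx in gkf.split(X, y, groups=groups):
    train_groups = set(groups[train_idx])
    test_groups = set(groups[test_idx])
    # a group is never split between train and test
    assert train_groups.isdisjoint(test_groups)
    folds.append((train_idx, test_idx))
    print("train groups:", sorted(train_groups), "test groups:", sorted(test_groups))
```

Which groups end up in which fold is decided by GroupKFold itself, so the exact assignment is not fixed here; only the disjointness is.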
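from pathlib import Path

from sklearn.model_selection import GroupKFold

# root directory of the dataset (hypothetical path: one subfolder per class)
data_dir = Path('data')
# define number of splits
n_splits = 10
# all image paths in a list
pics = list(data_dir.glob('**/*.png'))
# matching labels: the name of each image's parent folder
labels = [pic.parent.stem for pic in pics]
# group id per image (second '_'-separated part of the file name);
# images sharing a group id are never split between train and test
groups = [pic.stem.split('_')[1] for pic in pics]
gkf = GroupKFold(n_splits=n_splits)
# Iterate over splits:
for train, test in gkf.split(pics, labels, groups=groups):
    # train & test contain indices into pics/labels
    train_pics = [pics[i] for i in train]
    test_pics = [pics[i] for i in test]
    # do training here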
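A dry run of the same label and group extraction without touching the disk, as a sketch. It assumes the file layout the snippet implies, `<label folder>/<prefix>_<group id>_<index>.png`; the file names below are hypothetical. `PurePosixPath` only parses the strings, so `parent.stem` and `stem.split('_')[1]` can be checked before pointing the code at real data.

```python
from pathlib import PurePosixPath

from sklearn.model_selection import GroupKFold

# hypothetical file names: <label folder>/<prefix>_<group id>_<index>.png
names = [
    "cat/img_p1_0.png", "cat/img_p1_1.png",
    "dog/img_p2_0.png", "dog/img_p2_1.png",
    "cat/img_p3_0.png", "dog/img_p4_0.png",
]
# PurePosixPath only manipulates the strings; nothing is read from disk
pics = [PurePosixPath(name) for name in names]
labels = [pic.parent.stem for pic in pics]          # 'cat' / 'dog'
groups = [pic.stem.split("_")[1] for pic in pics]  # 'p1' .. 'p4'

gkf = GroupKFold(n_splits=2)
splits = []
for train, test in gkf.split(pics, labels, groups=groups):
    train_groups = {groups[i] for i in train}
    test_groups = {groups[i] for i in test}
    # the same group id never lands in both train and test
    assert train_groups.isdisjoint(test_groups)
    splits.append(([str(pics[i]) for i in train], [str(pics[i]) for i in test]))

for fold, (train_files, test_files) in enumerate(splits):
    print(f"fold {fold}: {len(train_files)} train, {len(test_files)} test")
```

GroupKFold raises a `ValueError` when `n_splits` is larger than the number of distinct groups, so `n_splits = 10` in the snippet needs at least ten different group ids in the file names.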