Skip to content

Instantly share code, notes, and snippets.

@ground0state
Created September 2, 2020 17:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ground0state/823df52437336eaf1c47adcb629e687b to your computer and use it in GitHub Desktop.
Save ground0state/823df52437336eaf1c47adcb629e687b to your computer and use it in GitHub Desktop.
class LayeredFoldWrapper(Dataset):
    """Expose one train or validation split of an interleaved K-fold partition.

    Sample ``i`` of the wrapped dataset is in the validation set of fold
    ``fold`` when ``i % n_splits == n_splits - fold - 1``; every other sample
    is in the training set of that fold. Each fold's validation samples are
    therefore spread evenly over the whole index range.

    Args:
        dataset: The dataset to wrap. Must support ``len()`` and integer
            indexing.
        n_splits: Number of folds.
        fold: Which fold to expose, ``0 <= fold <= n_splits - 1``.
        valid: If True, expose the validation indices of ``fold``;
            otherwise expose its training indices.

    Raises:
        ValueError: If ``fold`` is outside ``[0, n_splits - 1]``.
    """

    def __init__(self, dataset, n_splits=5, fold=0, valid=False):
        self.dataset = dataset
        self.n_splits = n_splits
        self.fold = fold
        self.valid = valid
        n_samples = len(dataset)
        self.valid_index = list(self._valid_index(n_samples, n_splits, fold))
        valid_set = set(self.valid_index)
        # Build the training indices in ascending order; iterating a set
        # difference gives no ordering guarantee.
        self.train_index = [i for i in range(n_samples) if i not in valid_set]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self._get_index_list(self.valid))

    def __getitem__(self, i):
        """Return the ``i``-th sample of the selected split."""
        return self.dataset[self._get_index_list(self.valid)[i]]

    def _valid_index(self, N, n_splits, fold):
        """Return the validation indices of ``fold`` among ``N`` samples.

        Args:
            N: Total number of samples.
            n_splits: Number of folds.
            fold: Fold selector, ``0 <= fold <= n_splits - 1``.

        Returns:
            A ``range`` over every ``n_splits``-th index in ``[0, N)``,
            starting at ``n_splits - fold - 1``.

        Raises:
            ValueError: If ``fold`` is out of range.
        """
        # Raise instead of assert: assertions are stripped under ``python -O``.
        if not 0 <= fold <= n_splits - 1:
            raise ValueError(
                f"fold must satisfy 0 <= fold <= n_splits - 1, "
                f"got fold={fold}, n_splits={n_splits}"
            )
        # The stop bound is exclusive at N so no index past the end of the
        # dataset can be produced.
        return range(n_splits - fold - 1, N, n_splits)

    def _get_index_list(self, valid):
        """Return the validation index list if ``valid`` else the training list."""
        return self.valid_index if valid else self.train_index
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment