@drscotthawley
Last active January 25, 2021 15:33
Swaps the validation set with a section of the training set, given a value for k.
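The fold counts quoted in the docstring follow from simple index arithmetic, which a short sketch can check. `fold_bounds` and `max_k` are hypothetical helpers (not part of the gist). They reproduce the slice `bgn, end = (k-1)*vl, k*vl` and the `k*vl <= train_X.shape[0]` assertion. The split sizes assume a hypothetical dataset of 1000 examples.

```python
def fold_bounds(k, n_val):
    # Same arithmetic as the gist: fold k (k >= 1) swaps train[(k-1)*n_val : k*n_val] into val
    return (k - 1) * n_val, k * n_val

def max_k(n_train, n_val):
    # Largest k that still satisfies the gist's check k * n_val <= n_train
    return n_train // n_val

# Hypothetical 1000-example dataset: (train size, val size) for each split
splits = {"80-20": (800, 200), "80-10-10": (800, 100), "70-15-15": (700, 150)}
for name, (n_train, n_val) in splits.items():
    K = max_k(n_train, n_val)
    bgn, end = fold_bounds(K, n_val)
    print(f"{name}: k = 0..{K} ({K + 1} folds), last fold swaps train[{bgn}:{end}]")
# 80-20: k = 0..4 (5 folds), last fold swaps train[600:800]
# 80-10-10: k = 0..8 (9 folds), last fold swaps train[700:800]
# 70-15-15: k = 0..4 (5 folds), last fold swaps train[450:600]
```

For the 70-15-15 case this also shows that train[600:700] is never swapped into val, so those 100 examples only ever serve as training data.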
# In case you didn't think to add k-fold cross-validation until late in your
# ML project,...
# This is built for a situation where datasets are arrays of, say, images.
def kfold_swap(train_X, train_Y, val_X, val_Y, k):
    """
    Swaps val with a section of train, given a value for k.

    "Duct tape" approach used to "retro-fit" k-fold cross-validation while minimally
    disturbing the rest of the code, avoiding reloading data from disk, and
    keeping RAM use manageable (e.g. np.append() is bad because it would copy all of train).
    "Not-quite in-place" swapping means only a val-sized section of train gets duplicated in storage.
    Note that train_X and train_Y are modified in place.

    For 80-20 train/val split, k can run from 0 to 4 (= 5-fold cross-val)
    For 80-10-10 train/val/test split, k can run from 0 to 8 (= 9-fold cross-val)
    For 70-15-15 train/val/test split, k can run from 0 to 4 (= 5-fold cross-val);
    exceeding k=4 will give you a failed assertion
    """
    if k > 0:  # k=0 means do nothing
        vl = val_X.shape[0]
        # sanity checks: make sure sizes are ok
        assert train_X.shape[0] == train_Y.shape[0]
        assert val_X.shape[0] == val_Y.shape[0]
        assert k*vl <= train_X.shape[0]
        bgn, end = (k-1)*vl, k*vl  # minus sign is from choice that k=0 is no-op
        # Slice along axis 0 only, so arrays of any rank work (e.g. 1-D labels).
        # .copy() is required: the slice is a view that is about to be overwritten.
        val_X, train_X[bgn:end] = train_X[bgn:end].copy(), val_X
        val_Y, train_Y[bgn:end] = train_Y[bgn:end].copy(), val_Y
    return train_X, train_Y, val_X, val_Y
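A minimal sketch of driving the folds on toy data. `swap_fold` is a hypothetical stand-in that mirrors the swap logic of `kfold_swap` (slicing along axis 0 only, so the 1-D labels used here work). The 10-example dataset, where each "image" is filled with its own index, the 8/2 split, and the loop are all assumptions for illustration.

```python
import numpy as np

def swap_fold(train_X, train_Y, val_X, val_Y, k):
    # Hypothetical stand-in mirroring kfold_swap's swap logic
    if k > 0:  # k=0 leaves the original split alone
        vl = val_X.shape[0]
        assert k * vl <= train_X.shape[0]
        bgn, end = (k - 1) * vl, k * vl
        # The RHS tuple is built first, so train_X[bgn:end] receives the *old* val_X;
        # .copy() keeps the new val from aliasing the train rows being overwritten
        val_X, train_X[bgn:end] = train_X[bgn:end].copy(), val_X
        val_Y, train_Y[bgn:end] = train_Y[bgn:end].copy(), val_Y
    return train_X, train_Y, val_X, val_Y

# Toy "images": example i is a 2x2 array filled with the value i, labeled i
X = np.arange(10, dtype=float)[:, None, None] * np.ones((1, 2, 2))
Y = np.arange(10)
train_X, train_Y = X[:8].copy(), Y[:8].copy()
val_X, val_Y = X[8:].copy(), Y[8:].copy()

seen = []
for k in range(0, 8 // 2 + 1):  # k = 0..4, i.e. 5 folds
    train_X, train_Y, val_X, val_Y = swap_fold(train_X, train_Y, val_X, val_Y, k)
    seen.append(sorted(val_Y.tolist()))
    # Every fold still holds all 10 examples, split 8/2
    assert sorted(train_Y.tolist() + val_Y.tolist()) == list(range(10))
print(seen)  # [[8, 9], [0, 1], [2, 3], [4, 5], [6, 7]]
```

In this toy run, calling the swap in order k = 0, 1, ..., 4 on the arrays returned by the previous call gives each fold exactly once, because slot k-1 of train has not yet been touched when fold k is requested. Calling it out of order on already-swapped arrays does not work: repeating k=1, for example, swaps the original val set back in. Reloading (or keeping untouched copies of) the original arrays before each fold avoids relying on call order.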