@1pha
Created December 1, 2020 01:58
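from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset

# X_train and y_multi_train are assumed to be torch tensors defined earlier (not shown here)
kfold = KFold(n_splits=10)
for i, (trn_idx, val_idx) in enumerate(kfold.split(X_train)):
    print("Working on fold {}/{}".format(i + 1, kfold.get_n_splits()))
    # Select this fold's rows from the full training tensors
    train_multi_ds = TensorDataset(X_train[trn_idx], y_multi_train[trn_idx])
    val_multi_ds = TensorDataset(X_train[val_idx], y_multi_train[val_idx])
    # Shuffle the training batches only; validation order does not need to change
    train_multi_loader = DataLoader(train_multi_ds, batch_size=128, shuffle=True)
    val_multi_loader = DataLoader(val_multi_ds, batch_size=128, shuffle=False)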
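To make the fold bookkeeping above easier to check without PyTorch or scikit-learn installed, here is a minimal, dependency-free sketch of what the loop iterates over. The helpers `kfold_indices` and `batches` are hypothetical names for illustration only. `kfold_indices` assumes contiguous, unshuffled folds and follows the fold-size rule documented for scikit-learn's `KFold(shuffle=False)`: the first `n_samples % n_splits` folds get one extra sample. `batches` mimics how a `DataLoader` with `shuffle=False` and the default `drop_last=False` chunks a dataset, keeping a smaller final batch.

```python
from typing import Iterator, List, Tuple


def kfold_indices(n_samples: int, n_splits: int) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, val_idx) for contiguous, unshuffled folds.

    Hypothetical helper: fold sizes follow the rule documented for
    scikit-learn's KFold(shuffle=False), where the first
    n_samples % n_splits folds receive one extra sample.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError("need 2 <= n_splits <= n_samples")
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        size = base + (1 if fold < extra else 0)
        stop = start + size
        val_idx = list(range(start, stop))  # this fold is held out
        train_idx = list(range(0, start)) + list(range(stop, n_samples))  # everything else
        yield train_idx, val_idx
        start = stop


def batches(indices: List[int], batch_size: int) -> List[List[int]]:
    """Split indices into consecutive batches, keeping a smaller last batch
    (hypothetical stand-in for DataLoader with shuffle=False, drop_last=False)."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


if __name__ == "__main__":
    folds = list(kfold_indices(25, 10))
    for i, (trn_idx, val_idx) in enumerate(folds):
        print("fold {}/{}: {} train, {} val".format(i + 1, len(folds), len(trn_idx), len(val_idx)))
```

Every sample lands in exactly one validation fold, and each fold's training set is the complement of its validation set. That is the property the gist relies on when it indexes `X_train` and `y_multi_train` with `trn_idx` and `val_idx`.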