@ovuruska
Created May 1, 2021 17:16
Train-validation split using the PyTorch DataLoader API.
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import train_test_split


def train_val_dataset(dataset, val_split=0.2, batch_size=16):
    # Randomly split the dataset indices into train and validation subsets.
    train_idx, val_idx = train_test_split(list(range(len(dataset))), test_size=val_split)
    datasets = {'train': Subset(dataset, train_idx), 'val': Subset(dataset, val_idx)}
    # Shuffle and drop the last partial batch for training only, so that
    # validation sees every sample in a fixed order.
    dataloaders = {
        x: DataLoader(datasets[x], batch_size=batch_size, num_workers=8,
                      shuffle=(x == 'train'), drop_last=(x == 'train'),
                      persistent_workers=False)
        for x in ['train', 'val']
    }
    return dataloaders['train'], dataloaders['val']
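
If pulling in scikit-learn just for the index split is undesirable, roughly the same result can be sketched with `torch.utils.data.random_split`. This is a minimal variant, not the gist author's code: the function name `split_loaders`, the `seed` and `num_workers` parameters, and the toy `TensorDataset` are all assumptions made for illustration. The validation loader here is not shuffled and keeps its last partial batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def split_loaders(dataset, val_split=0.2, batch_size=16, seed=0, num_workers=0):
    """Hypothetical helper: seeded train/val split without scikit-learn."""
    n_val = int(len(dataset) * val_split)
    n_train = len(dataset) - n_val
    # A seeded generator makes the split reproducible across runs.
    generator = torch.Generator().manual_seed(seed)
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=generator)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              drop_last=True, num_workers=num_workers)
    # Validation keeps every sample and a fixed order.
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    return train_loader, val_loader


# Toy data (assumed for illustration): 100 samples, 4 features, binary labels.
toy = TensorDataset(torch.randn(100, 4), torch.randint(0, 2, (100,)))
train_loader, val_loader = split_loaders(toy)
print(len(train_loader.dataset), len(val_loader.dataset))
```

Passing a fixed `seed` keeps the train and validation subsets identical from run to run, which the scikit-learn version would need `random_state` to achieve.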