@ArthurDelannoyazerty
Created May 17, 2024 09:48
Torch split train test dataset.
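import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler

# CustomImageDatasetNumpy is not defined in this gist; it is assumed to be a torch Dataset defined elsewhere.
# batch_size was referenced but never defined in the original, so it is added here as a parameter (default 32 is an assumption).
def create_data_loader(dirpath:str, test_split:float=0.1, shuffle_dataset:bool=True, batch_size:int=32):
    """Create a dataset, optionally shuffle its indices, and split it into train and test loaders."""
    dataset = CustomImageDatasetNumpy(dirpath)
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    # Number of samples held out for testing
    split = int(np.floor(test_split * dataset_size))
    if shuffle_dataset:
        np.random.shuffle(indices)
    # The first `split` indices form the test set, the rest the training set
    train_indices, test_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
    return train_loader, test_loader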
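
The index split at the heart of the function above can be tried on its own, without torch or a real dataset. The following is a minimal sketch using only the standard library. The helper name `split_indices` and its `seed` parameter are hypothetical additions for illustration and are not part of the gist, which shuffles with `np.random.shuffle` and no explicit seed.

```python
import math
import random


def split_indices(dataset_size, test_split=0.1, shuffle=True, seed=None):
    """Return (train_indices, test_indices) for a dataset of the given size.

    Hypothetical helper: mirrors the gist's slicing logic, with an optional
    seed (an assumption, not in the original) for reproducible splits.
    """
    indices = list(range(dataset_size))
    # Same size rule as the gist: floor(test_split * dataset_size) test samples
    split = int(math.floor(test_split * dataset_size))
    if shuffle:
        # A dedicated Random instance gives the same order for the same seed
        random.Random(seed).shuffle(indices)
    # First `split` indices go to the test set, the rest to training
    return indices[split:], indices[:split]


train_idx, test_idx = split_indices(10, test_split=0.2, seed=0)
print(len(train_idx), len(test_idx))  # 8 2
```

The two lists are disjoint and together cover every index, so they could be handed to `SubsetRandomSampler` in the same way as `train_indices` and `val_indices` in the gist.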