Created
May 17, 2024 09:48
-
-
Save ArthurDelannoyazerty/c7cfc9e4438d1708203c29931f486219 to your computer and use it in GitHub Desktop.
Split a PyTorch dataset into train and test sets.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_data_loader(
    dirpath: str,
    test_split: float = 0.1,
    shuffle_dataset: bool = True,
    batch_size: int = 1,
    seed: "int | None" = None,
):
    """Create a dataset and split (and optionally shuffle) it into train/test loaders.

    The dataset at *dirpath* is loaded once. Its indices are optionally
    shuffled, and the first ``floor(test_split * len(dataset))`` of them
    become the test subset; the remaining indices form the training subset.
    Both loaders wrap the same dataset object and draw from their subset
    through a ``SubsetRandomSampler``. As a result, sample order inside each
    subset is re-randomized on every pass, even when ``shuffle_dataset`` is
    False. ``shuffle_dataset`` only controls *which* indices land in each
    subset.

    Args:
        dirpath: Path passed to ``CustomImageDatasetNumpy``.
        test_split: Fraction of samples reserved for the test loader; must be
            in [0, 1].
        shuffle_dataset: Shuffle indices before splitting. When False, the
            test subset is the first indices of the dataset.
        batch_size: Batch size used by both loaders (1 is the ``DataLoader``
            default).
        seed: Optional seed for the pre-split shuffle. When None, NumPy's
            global RNG is used, so ``np.random.seed`` still controls it.

    Returns:
        A ``(train_loader, test_loader)`` tuple of ``DataLoader`` objects.

    Raises:
        ValueError: If ``test_split`` is outside [0, 1] (or NaN).
    """
    # Out-of-range fractions would yield a negative or oversized split index
    # and silently produce meaningless subsets.
    if not 0.0 <= test_split <= 1.0:
        raise ValueError(f"test_split must be in [0, 1], got {test_split!r}")

    dataset = CustomImageDatasetNumpy(dirpath)
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(test_split * dataset_size))

    if shuffle_dataset:
        if seed is None:
            # Keep the global-RNG path so existing np.random.seed() calls
            # continue to govern the split.
            np.random.shuffle(indices)
        else:
            np.random.default_rng(seed).shuffle(indices)

    train_indices, test_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=train_sampler
    )
    test_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=test_sampler
    )
    return train_loader, test_loader
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.