@budui · Created April 26, 2019
Randomly split a torch dataset into `train` and `val` datasets. Two options below: `torch.utils.data.random_split`, or manual index shuffling with `SubsetRandomSampler`.
import torch

# Option 1: let random_split compute the subsets.
train_size = int(0.8 * len(full_dataset))  # full_dataset: any map-style Dataset
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, test_size])
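# Sketch, assuming a recent PyTorch: pass a seeded torch.Generator so the
# split is reproducible across runs, then wrap each subset in a DataLoader.
generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, test_size], generator=generator)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)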
# Option 2: split indices manually and sample them with SubsetRandomSampler.
import numpy as np
from torch.utils.data import SubsetRandomSampler

dataset = MyCustomDataset(my_path)  # placeholder for your own Dataset
batch_size = 16
validation_split = 0.2
shuffle_dataset = True
random_seed = 42
# Create data indices for the training and validation splits.
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
# Create PyTorch data samplers and loaders.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)
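# Sanity-check sketch: the two index lists should partition the dataset.
assert len(train_indices) + len(val_indices) == dataset_size
assert set(train_indices).isdisjoint(val_indices)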
# Usage example:
num_epochs = 10
for epoch in range(num_epochs):
    # Train:
    for batch_index, (faces, labels) in enumerate(train_loader):
        ...  # training step goes here
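    # Validation pass (sketch): `model` stands in for whatever nn.Module you
    # are training; eval mode plus no_grad is the usual pattern here.
    model.eval()
    with torch.no_grad():
        for faces, labels in validation_loader:
            ...  # compute validation loss / metrics here
    model.train()  # back to training mode for the next epoch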