Skip to content

Instantly share code, notes, and snippets.

@t-vi
Last active August 18, 2017 16:08
Show Gist options
  • Star 7 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save t-vi/9f6118ff84867e89f3348707c7a1271f to your computer and use it in GitHub Desktop.
Save t-vi/9f6118ff84867e89f3348707c7a1271f to your computer and use it in GitHub Desktop.
Torch validation set split (MNIST example)
import os

import torch.utils.data
from torchvision import datasets, transforms
class PartialDataset(torch.utils.data.Dataset):
    """A read-only view onto a contiguous slice of a parent Dataset.

    Exposes items ``parent_ds[offset] .. parent_ds[offset + length - 1]``
    as indices ``0 .. length - 1`` without copying any data.
    """

    def __init__(self, parent_ds, offset, length):
        """Create a view of ``length`` items of ``parent_ds`` starting at ``offset``.

        Args:
            parent_ds: the underlying dataset (anything supporting ``len`` and indexing).
            offset: index in ``parent_ds`` where this view begins.
            length: number of items exposed by this view.

        Raises:
            ValueError: if the parent dataset has fewer than ``offset + length`` items.
        """
        # Validate with an explicit raise rather than `assert`: asserts are
        # stripped when Python runs with -O, silently disabling the check.
        if len(parent_ds) < offset + length:
            raise ValueError("Parent Dataset not long enough")
        self.parent_ds = parent_ds
        self.offset = offset
        self.length = length
        super(PartialDataset, self).__init__()

    def __len__(self):
        """Return the number of items in this view."""
        return self.length

    def __getitem__(self, i):
        """Return item ``i`` of the view, i.e. ``parent_ds[i + offset]``."""
        return self.parent_ds[i + self.offset]
def validation_split(dataset, val_share=0.1):
    """
    Split a (training and validation combined) dataset into training and validation.

    Note that to be statistically sound, the items in the dataset should be
    statistically independent (e.g. not sorted by class, not several instances
    of the same sample that could end up in either set).

    Args:
        dataset: ("training") dataset to split into training and validation.
        val_share: fraction of validation data (should be 0 < val_share < 1,
            default: 0.1).

    Returns:
        A pair ``(train_ds, val_ds)``: the first ``len(dataset) * (1 - val_share)``
        items as the training view, and the remaining items as the validation view.
    """
    # Deterministic split point: everything before it is training data.
    val_offset = int(len(dataset) * (1 - val_share))
    train_ds = PartialDataset(dataset, 0, val_offset)
    val_ds = PartialDataset(dataset, val_offset, len(dataset) - val_offset)
    return train_ds, val_ds
# --- Example: split the MNIST training set into train/validation loaders ---

mnist_train_ds = datasets.MNIST(
    os.path.expanduser('~/data/datasets/mnist'),
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        # Per-channel mean/std used by the standard PyTorch MNIST example.
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

train_ds, val_ds = validation_split(mnist_train_ds)

# NOTE(review): the upstream MNIST example defines `kwargs` (DataLoader options
# such as num_workers/pin_memory) from its CUDA setup; it was undefined in this
# file, which made the loaders raise NameError. Default to no extra options.
kwargs = {}

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=64, shuffle=True, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment