Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Train, Validation and Test Split for torchvision Datasets
"""
Create train, valid, test iterators for CIFAR-10 [1].
Easily extended to MNIST, CIFAR-100 and Imagenet.
[1]: https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4
"""
import torch
import numpy as np
from utils import plot_images
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True,
                           show_sample=False,
                           num_workers=4,
                           pin_memory=False):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset. A sample
    3x3 grid (9 images) can be optionally displayed.

    Both loaders read the CIFAR-10 *training* split (train=True); the
    validation set is carved out of it by index, so
    len(loader.dataset) is the full split size for both loaders. Use
    len(loader.sampler) to get the number of samples a loader yields.

    The original author recommends pin_memory=True when using CUDA.
    NOTE(review): the accompanying advice to set num_workers to 1 on
    CUDA is not backed by anything in this file -- confirm before
    relying on it.

    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - augment: whether to apply data augmentation (random 32x32 crop
      with 4 pixels of padding, then random horizontal flip). Only
      applied on the train split.
    - random_seed: seed passed to np.random.seed before shuffling the
      indices, so the train/valid membership is reproducible. Only
      used when shuffle is True.
    - valid_size: fraction of the training set used for
      the validation set. Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the train/validation indices before
      splitting. Batches are drawn in random order either way because
      both loaders use SubsetRandomSampler.
    - show_sample: plot a 3x3 sample grid of the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.

    Raises
    ------
    - AssertionError: if valid_size is outside [0, 1].
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert 0 <= valid_size <= 1, error_msg

    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # define transforms
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    # load the dataset twice: same underlying images, but the train and
    # valid views may carry different transforms (augmentation is
    # train-only)
    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )
    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=valid_transform,
    )

    # split the index range; the first `split` indices become validation
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # seeding here fixes which samples land in each split
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=9, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin_memory,
        )
        # builtin next() instead of the Python 2 style `.next()` method,
        # which DataLoader iterators no longer provide
        images, labels = next(iter(sample_loader))
        # NCHW -> NHWC for plotting; note the images are already
        # normalized (and possibly augmented) at this point
        X = images.numpy().transpose([0, 2, 3, 1])
        plot_images(X, labels)

    return (train_loader, valid_loader)
def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True,
                    num_workers=4,
                    pin_memory=False):
    """
    Utility function for loading and returning a multi-process
    test iterator over the CIFAR-10 dataset.

    The test images are normalized with the same per-channel mean/std
    constants that get_train_valid_loader applies to the train and
    validation splits, so a model sees identically preprocessed inputs
    at evaluation time.

    The original author recommends pin_memory=True when using CUDA.
    NOTE(review): the accompanying advice to set num_workers to 1 on
    CUDA is not backed by anything in this file -- confirm before
    relying on it.

    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - shuffle: whether to shuffle the dataset after every epoch.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - data_loader: test set iterator.
    """
    # must match the constants used for the train/valid splits;
    # normalizing the test split differently skews test metrics
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # define transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    dataset = datasets.CIFAR10(
        root=data_dir, train=False,
        download=True, transform=transform,
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return data_loader
import matplotlib.pyplot as plt
# Human-readable CIFAR-10 class names; the list position is the integer
# class label used to look a name up in plot_images.
label_names = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
def plot_images(images, cls_true, cls_pred=None):
    """
    Show up to nine images in a 3x3 grid, captioned with class names.

    Adapted from https://github.com/Hvass-Labs/TensorFlow-Tutorials/

    Params
    ------
    - images: 4-D array indexed as images[i, :, :, :]; each images[i]
      is handed to imshow as-is. If it holds fewer than nine images
      the remaining grid cells are left blank; images beyond the
      ninth are ignored.
    - cls_true: sequence of true class indices into label_names, one
      per image.
    - cls_pred: optional sequence of predicted class indices. When
      given, each caption shows both the true and predicted names.
    """
    _, axes = plt.subplots(3, 3)
    num_images = len(images)

    for i, ax in enumerate(axes.flat):
        if i >= num_images:
            # fewer than nine samples: hide the unused cell instead of
            # indexing past the end of `images`
            ax.axis("off")
            continue

        # plot img
        ax.imshow(images[i, :, :, :], interpolation='spline16')

        # show true & predicted classes
        cls_true_name = label_names[cls_true[i]]
        if cls_pred is None:
            xlabel = "{0} ({1})".format(cls_true_name, cls_true[i])
        else:
            cls_pred_name = label_names[cls_pred[i]]
            xlabel = "True: {0}\nPred: {1}".format(
                cls_true_name, cls_pred_name
            )

        ax.set_xlabel(xlabel)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()
@ajwitty

This comment has been minimized.

Copy link

@ajwitty ajwitty commented Oct 20, 2017

Why load the dataset twice into 'train_dataset' and 'valid_dataset'?

@kevinzakka

This comment has been minimized.

Copy link
Owner Author

@kevinzakka kevinzakka commented Oct 22, 2017

@ajwitty train and valid might not always have the same transformations

@MattKleinsmith

This comment has been minimized.

Copy link

@MattKleinsmith MattKleinsmith commented Oct 27, 2017

If using CUDA, num_workers should be set to 1

@kevinzakka Why?

I searched for discussions and documentation about the relationship between using GPUs and setting PyTorch's num_workers, but couldn't find any.

Also, thank you for writing this gist.

@krishvishal

This comment has been minimized.

Copy link

@krishvishal krishvishal commented Nov 14, 2017

Hey @kevinzakka, can you please tell me how to use your script? Should I copy and paste it into my script, or import it from my script? Which torch modules should I import? I'm getting errors. Please help.

@MattKleinsmith

This comment has been minimized.

Copy link

@MattKleinsmith MattKleinsmith commented Dec 11, 2017

@krishnavishalv

import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
@wanglouis49

This comment has been minimized.

Copy link

@wanglouis49 wanglouis49 commented Jan 16, 2018

Hi @kevinzakka, so for the train_loader and test_loader, shuffle has to be False according to the Pytorch documentation on DataLoader. Does that mean in your way we have to sacrifice shuffling during training?

@nasyxx

This comment has been minimized.

Copy link

@nasyxx nasyxx commented Jan 20, 2018

Hi, in my opinion, the normalize should be optional, considering the mean/std in other datasets is not the same as yours (mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]), though ideally mean/std would not be too different from it, not to mention that we still have batch norm.

@kevinzakka

This comment has been minimized.

Copy link
Owner Author

@kevinzakka kevinzakka commented Jan 25, 2018

@wanglouis49 it actually does not because we use SubsetRandomSampler and according to the documentation: "Samples elements randomly from a given list of indices, without replacement."

@songkangsg

This comment has been minimized.

Copy link

@songkangsg songkangsg commented May 3, 2018

Isn't it pointless to set a fixed random seed? It does help to generate the same order of indices for splitting the training set and validation set. But the SubsetRandomSampler does not use the seed, thus each batch sampled for training will be different every time.

@kevinzakka

This comment has been minimized.

Copy link
Owner Author

@kevinzakka kevinzakka commented May 3, 2018

Isn't it pointless to set a fixed random seed? It does help to generate the same order of indices for splitting the training set and validation set. But the SubsetRandomSampler does not use the seed, thus each batch sampled for training will be different every time.

@songkangsg I'm setting the seed exactly for that purpose: to have the same validation set all the time. I don't care about the order in which I receive the validation images. The goal is to compute a mean validation accuracy and loss.

@sunkevin1214

This comment has been minimized.

Copy link

@sunkevin1214 sunkevin1214 commented May 14, 2018

The mean and std you adopted in this script are for ImageNet not CIFAR10 or CIFAR100

@kevinzakka

This comment has been minimized.

Copy link
Owner Author

@kevinzakka kevinzakka commented May 17, 2018

@sunkevin1214 nice catch! Fixed it now.

@tan1889

This comment has been minimized.

Copy link

@tan1889 tan1889 commented Jun 10, 2018

Using this I have len(train_loader.dataset) = len(val_loader.dataset)=60000, which is wrong.

@kevinzakka

This comment has been minimized.

Copy link
Owner Author

@kevinzakka kevinzakka commented Jun 14, 2018

@tan1889 that's because they both use the same underlying dataset, but a different sampler. You need to do len(train_loader.sampler) instead.

@alyato

This comment has been minimized.

Copy link

@alyato alyato commented Jun 20, 2018

@kevinzakka
I'm trying PyTorch for the first time.
I used to use Keras, where the dataset has 3 parts: train, valid, and test.
But when I check https://github.com/pytorch/examples/blob/master/mnist/main.py, it has a train function and a test function.
I cannot find a valid_dataset, only the train_loader and test_loader.
So I think that the valid_dataset doesn't exist.
This confuses me.
Could you give me some explanation? Thanks.

@huangchaoxing

This comment has been minimized.

Copy link

@huangchaoxing huangchaoxing commented Sep 5, 2018

The normalization should only be done on the training set. But here the normalization is applied to the whole set. That seems like a problem.

@kevinzakka

This comment has been minimized.

Copy link
Owner Author

@kevinzakka kevinzakka commented Sep 5, 2018

@huangchaoxing validation and test sets should be normalized with train set statistics.

@sytelus

This comment has been minimized.

@kevinzakka

This comment has been minimized.

Copy link
Owner Author

@kevinzakka kevinzakka commented Sep 7, 2018

@sytelus the validation data is taken from the training set. The test set is untouched at all times.

@amobiny

This comment has been minimized.

Copy link

@amobiny amobiny commented May 20, 2019

@kevinzakka
Hey Kevin and thanks for the gist.
I had a quick question about the valid_loader. How do you make sure that the validation sampler sweeps all the samples in the validation set exactly once? My understanding is that it takes batches of the provided indices randomly! so if we execute

for images, labels in valid_loader: ...

to for example compute the loss and accuracy over the validation (feed batch by batch and average), it will not do it correctly as it doesn't sweep the whole set once. Am I correct?

@kevinzakka

This comment has been minimized.

Copy link
Owner Author

@kevinzakka kevinzakka commented May 20, 2019

@amobiny I think you have sampler and dataloader confused. The dataloader traverses the entire data set in batches. It selects the samples from the batch using the sampler. The sampler can be sequential so say for a batch of 4 and a dataset of size 32 you'd have [0, 1, 2, 3], [4, 5, 6, 7], etc until [28, 29, 30, 31]. In our case, the sampler is random and without replacement, in which case you'd have possibly something like [17, 1, 12, 31], [2, 8, 18, 28], etc. that would still cover the whole validation set. Does that make sense?

@brando90

This comment has been minimized.

Copy link

@brando90 brando90 commented Jul 27, 2019

why does train & val not have same statistics usually for normalizing?

@brando90

This comment has been minimized.

Copy link

@brando90 brando90 commented Jul 27, 2019

Also, the PyTorch tutorials use 0.5, as opposed to:

test:

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

and
train

    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

why?

@RITCHIEHuang

This comment has been minimized.

Copy link

@RITCHIEHuang RITCHIEHuang commented Aug 12, 2019

Why does get_train_valid_loader() return NoneType?

@phelps-matthew

This comment has been minimized.

Copy link

@phelps-matthew phelps-matthew commented Oct 11, 2020

Also, the PyTorch tutorials use 0.5, as opposed to:

test:

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

and
train

    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

why?

I can't speak to the choice of transform used here, but from my own testing I will say that the transform applied to the train set should be the same as that of the test set. Prior to doing this, I was getting inconsistent accuracies on the test set when compared to the validation set. I chose to set both to

    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment