-
-
Save kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb to your computer and use it in GitHub Desktop.
""" | |
Create train, valid, test iterators for CIFAR-10 [1]. | |
Easily extended to MNIST, CIFAR-100 and Imagenet. | |
[1]: https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4 | |
""" | |
import torch | |
import numpy as np | |
from utils import plot_images | |
from torchvision import datasets | |
from torchvision import transforms | |
from torch.utils.data.sampler import SubsetRandomSampler | |
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True,
                           show_sample=False,
                           num_workers=4,
                           pin_memory=False):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset. A sample
    3x3 grid of the images can be optionally displayed.

    Both loaders read the CIFAR-10 *training* split; the validation set is
    a disjoint subset of its indices, so the test split is never used here.

    Params
    ------
    - data_dir: path directory to the dataset (downloaded there if missing).
    - batch_size: how many samples per batch to load.
    - augment: whether to apply random 32x32 crops (with 4px padding) and
      random horizontal flips. Only applied on the train split.
    - random_seed: seed for the train/validation index shuffle. A private
      RNG is used, so NumPy's global random state is not reseeded.
    - valid_size: fraction of the training set used for the validation
      set. Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the indices before splitting them into
      train and validation subsets (if False, the first `valid_size`
      fraction of the training split becomes the validation set). Also
      controls shuffling of the optional sample grid.
    - show_sample: plot a 3x3 sample grid of training images.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    # Per-channel normalization statistics used for CIFAR-10 in this file.
    # Kept as named lists so the sample plot below can undo the normalization.
    mean = [0.4914, 0.4822, 0.4465]
    std = [0.2023, 0.1994, 0.2010]
    normalize = transforms.Normalize(mean=mean, std=std)

    # define transforms
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    # Two dataset objects over the same training split, so the train and
    # validation subsets can use different transforms.
    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )
    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # A dedicated RandomState seeded with random_seed produces the same
        # permutation as np.random.seed(...) + np.random.shuffle(...), but
        # without reseeding NumPy's global RNG as a side effect.
        np.random.RandomState(random_seed).shuffle(indices)

    # Disjoint index sets: the first `split` indices go to validation.
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=9, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin_memory,
        )
        # Use the built-in next(): DataLoader iterators no longer expose a
        # `.next()` method on current PyTorch releases.
        images, labels = next(iter(sample_loader))
        # ToTensor yields channels-first images; imshow wants channels-last.
        X = images.numpy().transpose([0, 2, 3, 1])
        # Undo the normalization so pixel values are back in [0, 1];
        # otherwise imshow clips the normalized floats and colours distort.
        X = np.clip(X * np.asarray(std) + np.asarray(mean), 0.0, 1.0)
        plot_images(X, labels)

    return (train_loader, valid_loader)
def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True,
                    num_workers=4,
                    pin_memory=False):
    """
    Utility function for loading and returning a multi-process
    test iterator over the CIFAR-10 dataset.

    Test images are normalized with the same per-channel statistics as the
    train/valid loaders (get_train_valid_loader), so a model evaluated on
    this loader sees inputs on the same scale it was trained on.

    Params
    ------
    - data_dir: path directory to the dataset (downloaded there if missing).
    - batch_size: how many samples per batch to load.
    - shuffle: whether to shuffle the dataset after every epoch. Shuffling
      does not affect aggregate test metrics; pass False for a deterministic
      iteration order.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - data_loader: test set iterator.
    """
    # Must match the normalization used in get_train_valid_loader; using
    # different statistics at test time shifts the input distribution and
    # makes test accuracy disagree with validation accuracy.
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # define transform (no augmentation at test time)
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    dataset = datasets.CIFAR10(
        root=data_dir, train=False,
        download=True, transform=transform,
    )

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return data_loader
import matplotlib.pyplot as plt | |
# Human-readable class names, indexed by the integer label.
label_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
def plot_images(images, cls_true, cls_pred=None):
    """
    Plot up to 9 images in a 3x3 grid, captioned with their class names.

    Adapted from https://github.com/Hvass-Labs/TensorFlow-Tutorials/

    Params
    ------
    - images: array indexed as images[i, :, :, :] and passed to imshow one
      image at a time; presumably (N, H, W, C) floats in [0, 1] -- confirm
      against callers. Only the first 9 images are shown; when fewer are
      given, the remaining grid cells are left blank.
    - cls_true: sequence of integer class labels (ints, NumPy integers or
      0-d tensors), indexable in parallel with `images`.
    - cls_pred: optional sequence of predicted labels. When given, each
      caption shows both the true and the predicted class name.
    """
    _, axes = plt.subplots(3, 3)
    for i, ax in enumerate(axes.flat):
        if i >= len(images):
            # Fewer than 9 images: hide the unused cell instead of raising
            # IndexError.
            ax.axis("off")
            continue

        # plot img
        ax.imshow(images[i, :, :, :], interpolation='spline16')

        # int() so tensor labels index the list and print as plain numbers
        # (not "tensor(3)") in the caption.
        true_idx = int(cls_true[i])
        cls_true_name = label_names[true_idx]
        if cls_pred is None:
            xlabel = "{0} ({1})".format(cls_true_name, true_idx)
        else:
            cls_pred_name = label_names[int(cls_pred[i])]
            xlabel = "True: {0}\nPred: {1}".format(
                cls_true_name, cls_pred_name
            )
        ax.set_xlabel(xlabel)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
@sytelus the validation data is taken from the training set. The test set is untouched at all times.
@kevinzakka
Hey Kevin and thanks for the gist.
I had a quick question about the `valid_loader`. How do you make sure that the validation sampler sweeps all the samples in the validation set exactly once? My understanding is that it draws batches of the provided indices randomly, so if we execute
for images, labels in valid_loader: ...
to for example compute the loss and accuracy over the validation (feed batch by batch and average), it will not do it correctly as it doesn't sweep the whole set once. Am I correct?
@amobiny I think you have `sampler` and `dataloader` confused. The `dataloader` traverses the entire data set in batches. It selects the samples for each batch using the `sampler`. The `sampler` can be sequential, so say for a batch of 4 and a dataset of size 32 you'd have `[0, 1, 2, 3]`, `[4, 5, 6, 7]`, etc. until `[28, 29, 30, 31]`. In our case, the `sampler` is random and without replacement, in which case you'd have possibly something like `[17, 1, 12, 31]`, `[2, 8, 18, 28]`, etc. That would still cover the whole validation set. Does that make sense?
Why don't train & val usually have the same statistics for normalizing?
Also, the PyTorch tutorials use 0.5, as opposed to:
test:
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
)
and
train
normalize = transforms.Normalize(
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010],
)
why?
Why does get_train_valid_loader() return NoneType?
Also, the PyTorch tutorials use 0.5, as opposed to:
test:
normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], )
and
train
normalize = transforms.Normalize( mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010], )
why?
I can't speak to the choice of transform used here, but from my own testing I will say that the transform applied to the train set should be the same as that of the test set. Before doing this, I was getting inconsistent accuracies on the test set compared to the validation set. I chose to set both to
normalize = transforms.Normalize(
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010],
)
You need train=False in the line below:
https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb#file-data_loader-py-L86