Skip to content

Instantly share code, notes, and snippets.

@skrish13
Forked from kevinzakka/data_loader.py
Created August 8, 2017 20:58
Show Gist options
  • Save skrish13/e356823ed80e042f7aeb778c64d75f40 to your computer and use it in GitHub Desktop.
Save skrish13/e356823ed80e042f7aeb778c64d75f40 to your computer and use it in GitHub Desktop.
Train, Validation and Test Split for torchvision Datasets
# This is an example for the CIFAR-10 dataset.
# There's a function for creating a train and validation iterator.
# There's also a function for creating a test iterator.
# Inspired by https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4
from utils import plot_images
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True,
                           show_sample=False,
                           num_workers=4,
                           pin_memory=False):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset. A sample
    3x3 grid of the images can be optionally displayed.

    Both loaders read the CIFAR-10 training split (downloaded into
    `data_dir` if it is not already there). The split is partitioned
    into two disjoint index sets, each served through its own
    `SubsetRandomSampler`.

    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - augment: whether to apply data augmentation (random 32x32 crop
      with 4 pixels of padding, random horizontal flip). Only applied
      on the train split.
    - random_seed: seed passed to `np.random.seed` before the indices
      are shuffled. Only used when `shuffle` is truthy.
    - valid_size: fraction of the training set used for the
      validation set. Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the train/validation indices before
      splitting them.
    - show_sample: plot a 3x3 sample grid of the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.

    Raises
    ------
    - AssertionError: if `valid_size` is outside [0, 1].
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert 0 <= valid_size <= 1, error_msg

    # NOTE(review): these mean/std values look like the widely used
    # ImageNet statistics rather than CIFAR-10 ones -- confirm intended.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # define transforms
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize
        ])

    # load the dataset: the same underlying training split is wrapped
    # twice so that augmentation never touches the validation samples
    train_dataset = datasets.CIFAR10(root=data_dir, train=True,
                                     download=True, transform=train_transform)
    valid_dataset = datasets.CIFAR10(root=data_dir, train=True,
                                     download=True, transform=valid_transform)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # seeding the global NumPy RNG makes the split reproducible
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    # the first `split` indices go to validation, the rest to training
    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               sampler=train_sampler,
                                               num_workers=num_workers,
                                               pin_memory=pin_memory)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=batch_size,
                                               sampler=valid_sampler,
                                               num_workers=num_workers,
                                               pin_memory=pin_memory)

    # visualize some images
    if show_sample:
        # The sample is drawn from the whole training split and goes
        # through train_transform, so the plotted values are normalized.
        sample_loader = torch.utils.data.DataLoader(train_dataset,
                                                    batch_size=9,
                                                    shuffle=shuffle,
                                                    num_workers=num_workers,
                                                    pin_memory=pin_memory)
        data_iter = iter(sample_loader)
        # built-in next(): iterators have no .next() method on Python 3
        images, labels = next(data_iter)
        X = images.numpy()
        # channels-first (N, C, H, W) -> channels-last (N, H, W, C)
        X = np.transpose(X, [0, 2, 3, 1])
        plot_images(X, labels)

    return (train_loader, valid_loader)
def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True,
                    num_workers=4,
                    pin_memory=False):
    """
    Create a multi-process iterator over the CIFAR-10 test split.

    The split is downloaded into `data_dir` if it is not already there.
    Every image is converted to a tensor and normalized per channel;
    no augmentation is applied.

    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - shuffle: forwarded to the DataLoader; whether to shuffle the
      dataset after every epoch.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - data_loader: test set iterator.
    """
    # tensor conversion followed by per-channel normalization
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # train=False selects the held-out test split
    test_set = datasets.CIFAR10(root=data_dir, train=False,
                                download=True, transform=preprocessing)

    return torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
import matplotlib.pyplot as plt
# Human-readable CIFAR-10 class names. plot_images indexes this list with
# the integer class label, so the order of the entries matters.
label_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
def plot_images(images, cls_true, cls_pred=None):
    """
    Show nine images on a 3x3 grid, each captioned with its class name.

    Params
    ------
    - images: batch of exactly nine images, indexed as
      images[i, :, :, :] and handed to `imshow` as-is.
    - cls_true: the nine true class labels, used as indices
      into `label_names`.
    - cls_pred: optional predicted class labels. When given, each
      caption shows the true and the predicted class name; otherwise
      it shows the true class name followed by the raw label.
    """
    assert len(images) == len(cls_true) == 9

    # One axes object per image, arranged 3 by 3.
    _, grid = plt.subplots(3, 3)

    for idx, axis in enumerate(grid.flat):
        axis.imshow(images[idx, :, :, :], interpolation='spline16')

        # Translate the integer label into its class name.
        true_name = label_names[cls_true[idx]]
        if cls_pred is not None:
            pred_name = label_names[cls_pred[idx]]
            caption = "True: {0}\nPred: {1}".format(true_name, pred_name)
        else:
            caption = "{0} ({1})".format(true_name, cls_true[idx])

        axis.set_xlabel(caption)
        # Hide the tick marks so only the image and its caption are shown.
        axis.set_xticks([])
        axis.set_yticks([])

    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment