Train, Validation and Test Split for torchvision MNIST Dataset
# https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb#file-data_loader-py
# This is an example for the MNIST dataset (adapted from a CIFAR-10 original).
# There's a function for creating a train and validation iterator.
# There's also a function for creating a test iterator.
# Inspired by https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4
# Adapted for MNIST by github.com/MatthewKleinsmith
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler
from utils import plot_images
def get_train_valid_loader(data_dir,
                           batch_size,
                           random_seed,
                           augment=False,
                           valid_size=0.2,
                           shuffle=True,
                           show_sample=False,
                           num_workers=1,
                           pin_memory=True):
"""
Utility function for loading and returning train and valid
multi-process iterators over the MNIST dataset. A sample
9x9 grid of the images can be optionally displayed.
If using CUDA, num_workers should be set to 1 and pin_memory to True.
Params
------
- data_dir: path directory to the dataset.
- batch_size: how many samples per batch to load.
- augment: whether to apply the data augmentation scheme
mentioned in the paper. Only applied on the train split.
- random_seed: fix seed for reproducibility.
- valid_size: percentage split of the training set used for
the validation set. Should be a float in the range [0, 1].
- shuffle: whether to shuffle the train/validation indices.
- show_sample: plot 9x9 sample grid of the dataset.
- num_workers: number of subprocesses to use when loading the dataset.
- pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
True if using GPU.
Returns
-------
- train_loader: training set iterator.
- valid_loader: validation set iterator.
"""
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std

    # define transforms
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        # MNIST digits lose their meaning when flipped horizontally, so the
        # CIFAR-10 scheme (RandomCrop(32) + RandomHorizontalFlip) is replaced
        # here with a pad-and-crop back to MNIST's 28x28.
        train_transform = transforms.Compose([
            transforms.RandomCrop(28, padding=4),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    # load the dataset twice, so the train and validation splits can use
    # different transforms
    train_dataset = datasets.MNIST(root=data_dir, train=True,
                                   download=True, transform=train_transform)
    valid_dataset = datasets.MNIST(root=data_dir, train=True,
                                   download=True, transform=valid_transform)

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    # the first `split` indices (a valid_size fraction of the data) form the
    # validation set; the remaining indices form the training set
    train_idx, valid_idx = indices[split:], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               sampler=train_sampler,
                                               num_workers=num_workers,
                                               pin_memory=pin_memory)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=batch_size,
                                               sampler=valid_sampler,
                                               num_workers=num_workers,
                                               pin_memory=pin_memory)
    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(train_dataset,
                                                    batch_size=9,
                                                    shuffle=shuffle,
                                                    num_workers=num_workers,
                                                    pin_memory=pin_memory)
        data_iter = iter(sample_loader)
        images, labels = next(data_iter)  # iterator .next() was removed in Python 3
        X = images.numpy()
        plot_images(X, labels)

    return (train_loader, valid_loader)
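# A minimal usage sketch (the data_dir, batch_size, and random_seed values
# below are illustrative, not from the original gist):
#
#     train_loader, valid_loader = get_train_valid_loader(
#         data_dir='./data', batch_size=64, random_seed=1)
#     images, labels = next(iter(train_loader))
#     # images: shape (64, 1, 28, 28); labels: shape (64,)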
def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True,
                    num_workers=1,
                    pin_memory=True):
"""
Utility function for loading and returning a multi-process
test iterator over the MNIST dataset.
If using CUDA, num_workers should be set to 1 and pin_memory to True.
Params
------
- data_dir: path directory to the dataset.
- batch_size: how many samples per batch to load.
- shuffle: whether to shuffle the dataset after every epoch.
- num_workers: number of subprocesses to use when loading the dataset.
- pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
True if using GPU.
Returns
-------
- data_loader: test set iterator.
"""
    normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std

    # define transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    dataset = datasets.MNIST(root=data_dir,
                             train=False,
                             download=True,
                             transform=transform)

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              pin_memory=pin_memory)
    return data_loader
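# A minimal usage sketch for the test loader (illustrative values, not from
# the original gist). shuffle=False keeps batches in dataset order, which is
# convenient when aligning predictions with examples:
#
#     test_loader = get_test_loader(data_dir='./data', batch_size=64,
#                                   shuffle=False)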
# https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb#file-utils-py
import matplotlib.pyplot as plt
def plot_images(images, labels, preds=None):
    # images is expected to have shape (9, 1, H, W)
    assert len(images) == len(labels) == 9

    # Create figure with sub-plots.
    fig, axes = plt.subplots(3, 3)

    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i, 0, :, :], interpolation='spline16', cmap='gray')

        label = str(labels[i])
        if preds is None:
            xlabel = label
        else:
            pred = str(preds[i])
            xlabel = "True: {0}\nPred: {1}".format(label, pred)

        ax.set_xlabel(xlabel)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()
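# Example usage of plot_images (illustrative; assumes a train_loader built
# with get_train_valid_loader above):
#
#     images, labels = next(iter(train_loader))
#     plot_images(images[:9].numpy(), labels[:9])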
@zhiqihuang
Why not read the training set and use torch.utils.data.random_split?
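For reference, a minimal sketch of that alternative (illustrative, not code from the thread). One caveat: random_split returns two views of a single underlying dataset, so both splits share one transform; the gist builds two MNIST instances precisely so that only the train split can be augmented:

    import torch
    from torchvision import datasets, transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    full_train = datasets.MNIST(root='./data', train=True,
                                download=True, transform=transform)
    # 80/20 split of MNIST's 60,000 training images
    train_set, valid_set = torch.utils.data.random_split(
        full_train, [48000, 12000],
        generator=torch.Generator().manual_seed(1))
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64,
                                               shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=64)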

@brando90
@zhiqihuang share your code perhaps?

@brando90
@MattKleinsmith isn't:

    train_idx, valid_idx = indices[split_idx:], indices[:split_idx]

backwards?
