Skip to content

Instantly share code, notes, and snippets.

@srikarplus
Last active May 18, 2022 10:17
Show Gist options
  • Save srikarplus/8bdb5bedf0ca25e894e39ea78fce2f39 to your computer and use it in GitHub Desktop.
Save srikarplus/8bdb5bedf0ca25e894e39ea78fce2f39 to your computer and use it in GitHub Desktop.
Train and Validation Split for PyTorch torchvision Datasets
import torch
import numpy as np
from utils import plot_images
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True,
                           show_sample=False,
                           num_workers=4,
                           pin_memory=False):
    """
    Build train and validation loaders over a single ImageFolder dataset.

    The image folder at `data_dir` is opened twice (once per transform
    pipeline) and its sample indices are partitioned into two disjoint
    index lists, so no image appears in both loaders.

    Params
    ------
    - data_dir: path directory to the dataset (torchvision ImageFolder layout).
    - batch_size: how many samples per batch to load.
    - augment: whether to apply random rotation / resized crop / horizontal
      flip to the train split. The validation split is never augmented.
    - random_seed: seed used to shuffle the indices before splitting.
      Only consulted when `shuffle` is True.
    - valid_size: fraction of the dataset held out for validation.
      Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the indices before splitting. When False,
      the first `floor(valid_size * N)` samples in dataset order become the
      validation split (NOTE(review): ImageFolder lists samples class by
      class, so an unshuffled split is probably not class-balanced).
    - show_sample: accepted for backward compatibility; this implementation
      does not read it and plots nothing.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.

    Raises
    ------
    - ValueError: if `valid_size` is outside [0, 1].
    """
    # Raise explicitly rather than assert: assert statements are removed
    # under `python -O`, which would let a bad split size through unchecked.
    if not 0 <= valid_size <= 1:
        raise ValueError("[!] valid_size should be in the range [0, 1].")

    # Per-channel statistics; these match the widely used ImageNet values.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # define transforms
    valid_transform = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomRotation(30),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        # Without augmentation the train pipeline is exactly the
        # deterministic validation pipeline, so reuse it.
        train_transform = valid_transform

    # load the dataset: two views of the same folder, one per transform
    train_dataset = datasets.ImageFolder(
        root=data_dir, transform=train_transform,
    )
    valid_dataset = datasets.ImageFolder(
        root=data_dir, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # NOTE: this reseeds NumPy's *global* RNG, which also affects any
        # later np.random calls made by the caller.
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    # The head of the (possibly shuffled) index list is held out.
    train_idx, valid_idx = indices[split:], indices[:split]

    # SubsetRandomSampler restricts each loader to its own indices and
    # visits them in random order (for validation as well); that ordering
    # is not governed by `random_seed`.
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return (train_loader, valid_loader)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment