Skip to content

Instantly share code, notes, and snippets.

@aletheia
Created July 14, 2020 08:31
Show Gist options
  • Save aletheia/c76e8f0a5bec38272214ad7c6aee2f90 to your computer and use it in GitHub Desktop.
Save aletheia/c76e8f0a5bec38272214ad7c6aee2f90 to your computer and use it in GitHub Desktop.
def load_split_train_test(self, valid_size=.2):
    """Load the image folders and build training/validation/test data loaders.

    The images found under ``self.train_data_dir`` are shuffled and partitioned
    into a training subset and a validation subset. When ``self.test_data_dir``
    is set, a loader over that folder is built as well.

    Parameters:
        valid_size (float): fraction (between 0 and 1) of the training images
            reserved for validation
    Returns:
        (torch.utils.data.DataLoader): Training data loader
        (torch.utils.data.DataLoader): Validation data loader
        (torch.utils.data.DataLoader or None): Test data loader, or None when
            ``self.test_data_dir`` is None
    Raises:
        ValueError: if ``valid_size`` is outside the [0, 1] range
    """
    # A fraction outside [0, 1] would silently yield a nonsensical split
    # (negative or oversized slice bounds), so reject it up front.
    if not 0 <= valid_size <= 1:
        raise ValueError('valid_size must be between 0 and 1, got {!r}'.format(valid_size))

    num_workers = self.num_workers
    batch_size = self.batch_size

    # Per-channel mean/std used to normalize every image (training and test alike).
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    # Data augmentation for training: randomly mirror images left-to-right,
    # convert to tensors, then normalize.
    train_transforms = T.Compose([
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(norm_mean, norm_std),
    ])

    # Use ImageFolder to load data from the main folder. Images are contained in
    # subfolders whose name represents their label, e.g.
    # training
    # |--> 0
    # |    |--> image023.png
    # |    |--> image024.png
    # |    ...
    # |--> 1
    # |    |--> image032.png
    # |    |--> image0433.png
    # |    ...
    # ...
    train_data = datasets.ImageFolder(self.train_data_dir, transform=train_transforms)

    # Shuffle the image indices, then cut them into validation / training parts.
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    np.random.shuffle(indices)
    train_idx, val_idx = indices[split:], indices[:split]

    # Each sampler restricts its loader to one part of the shuffled indices.
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    # NOTE: both loaders wrap the same `train_data` object, so validation
    # batches also go through `train_transforms` (including the random flip).
    train_loader = torch.utils.data.DataLoader(
        train_data, sampler=train_sampler, batch_size=batch_size, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(
        train_data, sampler=val_sampler, batch_size=batch_size, num_workers=num_workers)

    # If a testing dataset is defined, build its data loader as well.
    test_loader = None
    if self.test_data_dir is not None:
        # No augmentation at test time: only tensor conversion and normalization.
        test_transforms = T.Compose([T.ToTensor(), T.Normalize(norm_mean, norm_std)])
        test_data = datasets.ImageFolder(self.test_data_dir, transform=test_transforms)
        # The loader must wrap `test_data`; wrapping `train_data` here would
        # evaluate on the (augmented) training images instead of the test folder.
        test_loader = torch.utils.data.DataLoader(
            test_data, batch_size=batch_size, num_workers=num_workers)

    return train_loader, val_loader, test_loader
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment