Skip to content

Instantly share code, notes, and snippets.

@Eranpaz
Created September 2, 2019 08:56
Show Gist options
  • Save Eranpaz/b26343bb976c527040f3c59ca57f2a78 to your computer and use it in GitHub Desktop.
Save Eranpaz/b26343bb976c527040f3c59ca57f2a78 to your computer and use it in GitHub Desktop.
def create_dataloader(self):
    """Build a DataLoader for each split listed in ``self.data_params['sets']``.

    Recognised split names and the sub-directory of
    ``self.data_params['data_path']`` each one is read from:

    * ``'train'`` -> ``training`` (shuffled)
    * ``'val'``   -> ``val`` (not shuffled)

    Any other name in ``sets`` is silently skipped.

    Each split is loaded with ``datasets.ImageFolder``, using the transform
    returned by ``self.create_transform``. That transform is configured from
    ``data_params['mean']``, ``data_params['std']`` and
    ``data_params['resize']``.

    The batch size of every loader is ``training_params['batch_size']``
    multiplied by ``exp_params['num_gpus']``, where the GPU count is treated
    as at least 1.

    Side effects:
        Prints a progress message and sets ``self.loaders`` to a dict that
        maps each recognised split name to its DataLoader.
    """
    print("creating data loaders")
    # split name -> (folder under data_path, whether to shuffle)
    split_config = {
        'train': ('training', True),
        'val': ('val', False),
    }
    loaders = {}
    for split in self.data_params['sets']:
        if split not in split_config:
            continue
        folder, shuffle = split_config[split]
        transform = self.create_transform(self.data_params['mean'], self.data_params['std'],
                                          new_size=self.data_params['resize'])
        dataset = datasets.ImageFolder(os.path.join(self.data_params['data_path'], folder),
                                       transform)
        # Batch size scales with the GPU count (presumably so that each GPU
        # still sees `batch_size` samples; confirm against the training loop).
        # Clamp to 1 so a CPU-only run (num_gpus == 0) does not produce
        # batch_size=0, which DataLoader rejects with a ValueError.
        batch_size = self.training_params['batch_size'] * max(self.exp_params['num_gpus'], 1)
        loaders[split] = torch.utils.data.DataLoader(dataset, batch_size, shuffle=shuffle)
    self.loaders = loaders
def create_transform(self, mean=None, std=None, new_size=None):
    """Build the image preprocessing pipeline used by the data loaders.

    Args:
        mean: Per-channel means passed to ``transforms.Normalize``.
            Defaults to ``[0, 0, 0]``, i.e. no shift.
        std: Per-channel standard deviations passed to
            ``transforms.Normalize``. Defaults to ``[1, 1, 1]``, i.e. no
            scaling.
        new_size: Optional size passed to ``transforms.Resize``. When it is
            falsy (``None``, ``0``, an empty sequence), no resize step is
            added.

    Returns:
        A ``transforms.Compose`` that optionally resizes the input, then
        applies ``ToTensor`` followed by ``Normalize``.
    """
    # Build fresh lists on every call instead of using mutable default
    # arguments. torchvision's Normalize stores the sequences it is given, so
    # shared default lists would be aliased across every pipeline.
    if mean is None:
        mean = [0, 0, 0]
    if std is None:
        std = [1, 1, 1]
    steps = [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
    if new_size:
        # Resize runs first, on the raw input image, before ToTensor.
        steps.insert(0, transforms.Resize(new_size))
    return transforms.Compose(steps)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment