Skip to content

Instantly share code, notes, and snippets.

View Eranpaz's full-sized avatar

Eran Paz Eranpaz

View GitHub Profile
def create_dataloader(self):
print("creating data loaders")
loaders = {}
for s in self.data_params['sets']:
if s == 'train':
tranform = self.create_transform(self.data_params['mean'], self.data_params['std'],
new_size=self.data_params['resize'])
dataset = datasets.ImageFolder(os.path.join(self.data_params['data_path'], 'training'), tranform)
loaders[s] = torch.utils.data.DataLoader(dataset,
self.training_params['batch_size'] * self.exp_params[