@KushajveerSingh
Created December 17, 2019 10:21
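import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder

# ImageNet (ILSVRC2012) folders laid out as <split>/<class_name>/<image>
train_dir = '../../../Data/ILSVRC2012/train'
val_dir = '../../../Data/ILSVRC2012/val'
size = 224          # final crop size fed to the network
batch_size = 32
num_workers = 8     # worker processes used for loading and decoding images

# Per-channel ImageNet mean/std used by torchvision's pretrained models
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(256),             # shorter side -> 256, aspect ratio kept
        transforms.CenterCrop(size),
        transforms.RandomHorizontalFlip(),  # the only augmentation applied
        transforms.ToTensor(),              # HWC uint8 [0, 255] -> CHW float [0, 1]
        transforms.Normalize(mean=mean, std=std),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]),
}

image_datasets = {
    'train': ImageFolder(train_dir, transform=data_transforms['train']),
    'val': ImageFolder(val_dir, transform=data_transforms['val']),
}

data_loader = {
    x: torch.utils.data.DataLoader(image_datasets[x],
                                   batch_size=batch_size,
                                   shuffle=(x == 'train'),  # no need to shuffle val
                                   num_workers=num_workers)
    for x in ['train', 'val']
}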
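The preprocessing pipeline reduces to a few lines of per-image arithmetic. The sketch below reproduces it in plain Python so the numbers can be checked without torch or the dataset. It assumes the common ImageNet recipe of resizing the shorter side to 256 before the 224 center crop. The helper names (`resize_shorter_side`, `center_crop_box`, `normalize_pixel`) are hypothetical and not torchvision APIs. The rounding mirrors torchvision's implementation as I understand it and may differ slightly between versions.

```python
# Hypothetical pure-Python helpers mirroring Resize(256) -> CenterCrop(224)
# -> ToTensor() -> Normalize(mean, std); rounding follows torchvision's
# implementation as understood here and may vary by version.

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def resize_shorter_side(width, height, size=256):
    """Scale so the shorter side equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """ToTensor scales 0-255 to [0, 1]; Normalize then applies (x - mean) / std."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD)]


if __name__ == '__main__':
    w, h = resize_shorter_side(500, 375)   # a typical landscape photo
    print((w, h))                          # (341, 256)
    print(center_crop_box(w, h))           # (58, 16, 282, 240)
    print(normalize_pixel([0, 0, 0]))      # black pixel -> large negative values
```

Without the resize step, `CenterCrop(224)` on a 500x375 image keeps only a small central patch rather than most of the scene. Resizing the shorter side first is what lets the crop cover most of the image.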