Skip to content

Instantly share code, notes, and snippets.

@omarsar
Created December 29, 2019 16:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save omarsar/19bc5d597d0f2a58d216ca6b23625085 to your computer and use it in GitHub Desktop.
Save omarsar/19bc5d597d0f2a58d216ca6b23625085 to your computer and use it in GitHub Desktop.
## root folder of the dataset on your gdrive; the 'train' and 'val'
## subfolders are resolved relative to this path further down
data_dir = (
    'gdrive/My Drive/DAIR RESOURCES/TF to PT/'
    'datasets/hymenoptera_data'
)
## custom transformer to flatten the image tensors
class ReshapeTransform:
    """Callable transform that reshapes a tensor to a fixed target shape.

    Meant to be used as a step in a transform pipeline, e.g.
    ``ReshapeTransform((-1,))`` flattens a tensor into a 1-D vector.

    Args:
        new_size: Target shape handed to ``torch.reshape``. A single
            dimension may be ``-1`` to have it inferred from the input.
    """

    def __init__(self, new_size):
        self.new_size = new_size

    def __call__(self, img):
        """Return ``img`` reshaped to ``self.new_size``.

        Args:
            img: Tensor to reshape.

        Returns:
            The result of ``torch.reshape(img, self.new_size)``.

        Raises:
            RuntimeError: Propagated from ``torch.reshape`` when the number
                of elements in ``img`` does not fit ``new_size``.
        """
        return torch.reshape(img, self.new_size)

    def __repr__(self):
        # Gives a readable entry when the enclosing pipeline is printed.
        return '{}(new_size={!r})'.format(type(self).__name__, self.new_size)
## per-split preprocessing pipelines: resize, center-crop, convert to a
## tensor, then flatten. Both splits run the same steps, but each split
## gets its own Compose instance. No Normalize step is applied here.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        ReshapeTransform((-1,)),  # flattens the data
    ])
    for split in ('train', 'val')
}
## load the corresponding folders: one ImageFolder per split, rooted at
## <data_dir>/<split> and using that split's transform pipeline
image_datasets = {
    split: datasets.ImageFolder(
        root=os.path.join(data_dir, split),
        transform=data_transforms[split],
    )
    for split in ('train', 'val')
}
## load the entire dataset; we are not using minibatches here.
## batch_size equals the split length, so each loader produces the whole
## split as a single batch.
## NOTE: despite the *_dataset names, both objects are DataLoaders.
train_dataset = torch.utils.data.DataLoader(
    dataset=image_datasets['train'],
    batch_size=len(image_datasets['train']),
    shuffle=True,
)
test_dataset = torch.utils.data.DataLoader(
    dataset=image_datasets['val'],
    batch_size=len(image_datasets['val']),
    shuffle=True,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment