Skip to content

Instantly share code, notes, and snippets.

@khalidmeister
Created June 23, 2020 14:41
Show Gist options
  • Save khalidmeister/719c3978cb4bc386ebd0af752ba6b5e9 to your computer and use it in GitHub Desktop.
Save khalidmeister/719c3978cb4bc386ebd0af752ba6b5e9 to your computer and use it in GitHub Desktop.
Transform Image Data For The Model
import copy
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
# Alias that stays bound to the torchvision module even after the script
# rebinds the name `transforms` to its dict of per-split pipelines.
from torchvision import transforms as tv_transforms
# Per-split preprocessing pipelines, keyed by dataset folder name.
# The training split gets random-resized-crop + horizontal-flip augmentation.
# The 'val' and 'test' splits share one deterministic resize + center-crop
# pipeline, so both evaluation splits are preprocessed identically.
#
# NOTE(review): the dict below is bound to the name `transforms`, which
# shadows the `torchvision.transforms` module imported at the top of the file.
# The name is kept because the dataset code reads `transforms[split]`.
# Build pipelines through `tv_transforms` (the module alias) instead, so this
# block does not depend on the module still being reachable as `transforms`.

# Standard ImageNet per-channel mean/std (RGB order), as expected by
# torchvision's ImageNet-pretrained models.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Deterministic pipeline shared by the evaluation splits.
eval_transform = tv_transforms.Compose([
    tv_transforms.Resize(256),
    tv_transforms.CenterCrop(224),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize(imagenet_mean, imagenet_std),
])

transforms = {
    'train': tv_transforms.Compose([
        tv_transforms.RandomResizedCrop(224),
        tv_transforms.RandomHorizontalFlip(),
        tv_transforms.ToTensor(),
        tv_transforms.Normalize(imagenet_mean, imagenet_std),
    ]),
    'val': eval_transform,
    'test': eval_transform,
}
# Build one ImageFolder dataset and one DataLoader per split.
# Expects data_dir/train, data_dir/val and data_dir/test, each laid out as
# ImageFolder requires (one sub-folder per class).
data_dir = 'data'
splits = ('train', 'val', 'test')
batch_size = 4
num_workers = 4

image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), transform=transforms[split])
    for split in splits
}

# Only the training split is shuffled. Shuffling the evaluation splits does not
# change aggregate metrics, but it makes their order nondeterministic. It also
# breaks the correspondence between batch order and
# `image_datasets[split].samples`, which per-image analysis relies on.
# NOTE(review): with num_workers > 0 on platforms that spawn worker processes
# (Windows, macOS), code that iterates these loaders at script top level must
# sit under an `if __name__ == '__main__':` guard.
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=num_workers,
    )
    for split in splits
}

# Number of images found in each split.
data_size = {split: len(image_datasets[split]) for split in splits}

# Class names, i.e. the class sub-folder names discovered by ImageFolder in
# the training split; their order defines the integer labels.
class_names = image_datasets['train'].classes

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment