Skip to content

Instantly share code, notes, and snippets.

@mirceast
Created May 22, 2019 14:00
Show Gist options
  • Save mirceast/77129d7cfdd1c3affe6bbbec3712ba4a to your computer and use it in GitHub Desktop.
Save mirceast/77129d7cfdd1c3affe6bbbec3712ba4a to your computer and use it in GitHub Desktop.
Transfer learning 1
import copy
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
# Preprocessing pipelines, keyed by dataset split.
# Training images are randomly cropped and flipped (augmentation) before being
# normalized; validation images are only resized, center-cropped and normalized.
train_steps = [
    # Random crop resized to 224x224 - ImageNet models were trained on 224x224 images
    transforms.RandomResizedCrop(224),
    # Random horizontal flip - increases train set variability
    transforms.RandomHorizontalFlip(),
    # Convert the image to a PyTorch tensor
    transforms.ToTensor(),
    # Per-channel mean / std normalization that ImageNet models expect
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Same normalization constants as the training pipeline
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
data_transforms = {
    'train': transforms.Compose(train_steps),
    'val': transforms.Compose(val_steps),
}
# Root folder of the dataset; it must contain one sub-folder per split
# ('train' and 'val'), each of which is read with ImageFolder.
data_dir = 'hymenoptera_data'

# For every split, build the dataset, wrap it in a loader and record its size.
image_datasets = {}
dataloaders = {}
dataset_sizes = {}
for split in ('train', 'val'):
    split_dataset = datasets.ImageFolder(
        os.path.join(data_dir, split),
        data_transforms[split],
    )
    image_datasets[split] = split_dataset
    # Both splits use shuffled batches of 4 images, loaded by 4 worker processes.
    dataloaders[split] = torch.utils.data.DataLoader(
        split_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4,
    )
    dataset_sizes[split] = len(split_dataset)

# Class labels as reported by the training dataset.
class_names = image_datasets['train'].classes

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment