Skip to content

Instantly share code, notes, and snippets.

@conormm
Last active October 8, 2018 19:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save conormm/2c179c59132e3bcde7af026509aa0ec8 to your computer and use it in GitHub Desktop.
Save conormm/2c179c59132e3bcde7af026509aa0ec8 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.datasets.folder import ImageFolder, default_loader
from torchvision.datasets.utils import check_integrity
from torchvision import transforms
from torchvision import models
import matplotlib.pyplot as plt
from src.utils_cm import ModelParameters
# Root folder holding one sub-directory per class; read by ImageFolder below.
images_dir = "data/sample"

# Training hyper-parameters.
NUM_EPOCHS, BATCH_SIZE = 3, 10
# Size (shorter side) images are resized to before the 224px crop.
IMG_SIZE = 250

# Per-channel (R, G, B) ImageNet mean / std, the normalisation torchvision's
# pretrained models expect their inputs to use.
normmean = [0.485, 0.456, 0.406]
normstd = [0.229, 0.224, 0.225]
def fine_tuning_model(model, n_classes=120, in_features=1000):
    """Freeze a pretrained network and put a trainable classification head on top.

    Every parameter of ``model`` is frozen with ``ModelParameters.freeze_all``.
    A fresh ``nn.Linear(in_features, n_classes)`` head is then chained after
    the model, so that only the head is trained.

    Merely assigning the head as an attribute (``model.ft_layer = ...``)
    registers its parameters but does not change ``model.forward``. The head
    would then never run: outputs would keep the backbone's width and the
    head would receive no gradients. The two are therefore wrapped in an
    ``nn.Sequential`` so the head is actually part of the forward pass.

    Args:
        model: pretrained network whose forward pass yields ``in_features``
            values per sample.
        n_classes: number of target classes produced by the new head.
        in_features: output width of ``model``. Defaults to 1000, the value
            previously hard-coded (the ImageNet class count of torchvision's
            pretrained classifiers).

    Returns:
        An ``nn.Sequential`` that runs ``model`` (as ``.backbone``) followed
        by the new head (as ``.ft_layer``).
    """
    ModelParameters.freeze_all(model.parameters())
    assert ModelParameters.all_frozen(model.parameters())

    ft_layer = nn.Linear(in_features, n_classes)
    # A freshly constructed Linear layer is trainable; sanity-check it.
    assert ft_layer.weight.requires_grad

    # Chain backbone -> head so the head participates in forward()/backward().
    wrapped = nn.Sequential()
    wrapped.add_module("backbone", model)
    wrapped.add_module("ft_layer", ft_layer)
    return wrapped
# Training-time preprocessing and augmentation:
# - resize so the shorter side is IMG_SIZE, then take a random 224x224 crop;
# - jitter brightness/contrast/saturation by up to 0.3;
# - mirror horizontally with probability 0.3;
# - convert to a tensor normalised with normmean/normstd.
train_trans = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomCrop(224),
    transforms.ColorJitter(.3, .3, .3),
    transforms.RandomHorizontalFlip(p=.3),
    transforms.ToTensor(),
    transforms.Normalize(normmean, normstd),
])
# Evaluation-time preprocessing: deterministic counterpart of train_trans
# (no random augmentation), producing inputs of the same size and scale.
val_trains = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    # CenterCrop requires a size; 224 matches RandomCrop(224) used for training.
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Normalize requires mean/std; use the same statistics as training.
    transforms.Normalize(normmean, normstd),
])
# Labelled dataset built from the per-class sub-folders of images_dir, with
# the training augmentations applied to every sample.
img_f = ImageFolder(root=images_dir, transform=train_trans)

# One class per sub-folder discovered by ImageFolder.
n_classes = len(img_f.classes)

# Shuffled mini-batches of (image, label) pairs used by the training loop.
ds = DataLoader(
    dataset=img_f,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
# Start from torchvision's ImageNet-pretrained VGG16.
VGG16 = models.vgg16(pretrained=True)
# Size the new head from the dataset rather than relying on the helper's
# default of 120, which would ignore the classes ImageFolder actually found.
VGG16 = fine_tuning_model(VGG16, n_classes=n_classes)

# Optimise only what get_trainable returns. Presumably these are the
# parameters that still require grad, i.e. the new head (confirm in
# src.utils_cm).
optim = torch.optim.Adam(
    ModelParameters.get_trainable(VGG16.parameters()),
    lr=0.001,
)
criterion = nn.CrossEntropyLoss()
# Switch the model to training mode (affects layers such as dropout).
VGG16.train()

for epoch in range(NUM_EPOCHS):
    print(f"Epoch number {epoch}")
    for ix, (X, y) in enumerate(ds):
        # NOTE(review): marking the input as requiring grad makes autograd
        # backpropagate through the entire frozen network on every step.
        # It looks like a workaround so that loss.backward() does not fail
        # when no trainable parameter takes part in the forward pass. Once
        # the new head is actually used in forward(), this line should be
        # removable — confirm.
        X.requires_grad_(True)
        optim.zero_grad()

        preds = VGG16(X)
        loss = criterion(preds, y)

        loss.backward()
        optim.step()
        # Per-batch loss (indentation was lost in the source; assumed to be
        # inside the batch loop).
        print(f"Loss: {loss.item()}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment