This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Transfer-learning setup: load an ImageNet-pretrained ResNet18, freeze all of
# its existing weights, then attach a new 2-way classification head.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of `weights=...`; kept as-is for compatibility with older installs.
model = torchvision.models.resnet18(pretrained=True)

# Freeze the pretrained backbone by excluding every existing parameter from
# gradient computation.
for weight in model.parameters():
    weight.requires_grad_(False)

# Replace the final fully connected layer. The new layer is created after the
# freeze above, so its parameters keep the default requires_grad=True and the
# head is the only trainable part of the network.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
# Send the model to the GPU
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_model(model, criterion, optimizer, scheduler, num_epochs=25): | |
since = time.time() | |
best_model_wts = copy.deepcopy(model.state_dict()) | |
best_acc = 0.0 | |
epoch_time = [] # we'll keep track of the time needed for each epoch | |
for epoch in range(num_epochs): | |
epoch_start = time.time() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Helper function for displaying images | |
def imshow(inp, title=None): | |
"""Imshow for Tensor.""" | |
inp = inp.numpy().transpose((1, 2, 0)) | |
mean = np.array([0.485, 0.456, 0.406]) | |
std = np.array([0.229, 0.224, 0.225]) | |
# Un-normalize the images | |
inp = std * inp + mean | |
# Clip just in case |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torchvision, time, os, copy | |
import torch.nn as nn | |
import torch.optim as optim | |
from torch.optim import lr_scheduler | |
# Data augmentation and normalization for training | |
# Just normalization for validation | |
data_transforms = { | |
'train': transforms.Compose([ | |
transforms.RandomResizedCrop(224), # ImageNet models were trained on 224x224 images |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Plot the per-iteration throughput values stored in `fps`.
fig, ax = plt.subplots()
ax.plot(fps)
# Label both axes in one call. In a notebook only the final expression of a
# cell is echoed, so the trailing semicolon here hides its return value.
ax.set(xlabel="Iteration", ylabel="FPS");
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

# Benchmark inference throughput: run the model on the same input repeatedly
# and record the instantaneous frames-per-second of each run in `fps`.
# NOTE(review): `model` and `image` must be defined earlier in the script;
# they are not visible in this snippet.
n_runs = 200
fps = np.zeros(n_runs)
use_cuda = torch.cuda.is_available()
with torch.no_grad():  # speed it up by not computing gradients since we don't need them for inference
    for i in range(n_runs):
        if use_cuda:
            # CUDA kernels are launched asynchronously; drain any queued work
            # so the timer starts from an idle device.
            torch.cuda.synchronize()
        # perf_counter is a monotonic, high-resolution clock; time.time() can
        # return identical values for a fast run and trigger 1 / 0.
        t0 = time.perf_counter()
        out = model(image)
        if use_cuda:
            # Wait for the forward pass to actually finish before stopping the
            # clock, otherwise only the kernel launch time is measured.
            torch.cuda.synchronize()
        fps[i] = 1 / (time.perf_counter() - t0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Put the torchvision directory at the front of the import path so it takes
# precedence over other entries. Set the TORCHVISION_PATH environment variable
# to point elsewhere; otherwise the original hard-coded location is used.
import os
import sys
_default_vision_path = '/home/mircea/python-envs/env/lib/python3.6/site-packages/vision'
# `or` also falls back when the variable is set but empty, which would
# otherwise insert '' (the current directory) into sys.path.
sys.path.insert(0, os.environ.get('TORCHVISION_PATH') or _default_vision_path)
# Choose an image to pass through the model
test_image = 'images/dog.jpg'
# Imports | |
import torch, json | |
import numpy as np |