This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.optim import Adam | |
# Check if gpu support is available | |
cuda_avail = torch.cuda.is_available() | |
# Create model, optimizer and loss function | |
model = SimpleNet(num_classes=10) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create a learning rate adjustment function that divides the learning rate by 10 every 30 epochs | |
def adjust_learning_rate(epoch): | |
lr = 0.001 | |
if epoch > 180: | |
lr = lr / 1000000 | |
elif epoch > 150: | |
lr = lr / 100000 | |
elif epoch > 120: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save_models(epoch):
    """Save the current weights of the module-level ``model`` as a checkpoint.

    The state dict is written with ``torch.save`` to
    ``cifar10model_<epoch>.model`` in the current working directory, and a
    confirmation line is printed.

    Args:
        epoch: Value interpolated into the checkpoint filename, presumably the
            epoch number that just finished (confirm against the caller).
    """
    # Only the state_dict (parameters/buffers) is saved, not the whole
    # module object, so loading requires constructing the model first.
    torch.save(model.state_dict(), "cifar10model_{}.model".format(epoch))
    print("Checkpoint saved")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def test(): | |
model.eval() | |
test_acc = 0.0 | |
for i, (images, labels) in enumerate(test_loader): | |
if cuda_avail: | |
images = Variable(images.cuda()) | |
labels = Variable(labels.cuda()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(num_epochs): | |
best_acc = 0.0 | |
for epoch in range(num_epochs): | |
model.train() | |
train_acc = 0.0 | |
train_loss = 0.0 | |
for i, (images, labels) in enumerate(train_loader): | |
# Move images and labels to gpu if available |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#Import needed packages | |
import torch | |
import torch.nn as nn | |
from torchvision.datasets import CIFAR10 | |
from torchvision.transforms import transforms | |
from torch.utils.data import DataLoader | |
from torch.optim import Adam | |
from torch.autograd import Variable | |
import numpy as np |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import needed packages | |
import torch | |
import torch.nn as nn | |
from torchvision.transforms import transforms | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from torch.autograd import Variable | |
from torchvision.models import squeezenet1_1 | |
import torch.functional as F | |
import requests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import needed packages | |
import torch | |
import torch.nn as nn | |
from torchvision.transforms import transforms | |
from torch.autograd import Variable | |
from torchvision.models import squeezenet1_1 | |
import requests | |
import shutil | |
from io import open | |
import os |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def predict_image(image_path): | |
print("Prediction in progress") | |
image = Image.open(image_path) | |
# Define transformations for the image (note that ImageNet models are trained with an image size of 224) | |
transformation = transforms.Compose([ | |
transforms.CenterCrop(224), | |
transforms.ToTensor(), | |
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == "__main__": | |
imagefile = "image.png" | |
imagepath = os.path.join(os.getcwd(), imagefile) | |
# Download the image if it doesn't exist | |
if not os.path.exists(imagepath): | |
data = requests.get( | |
"https://github.com/OlafenwaMoses/ImageAI/raw/master/images/3.jpg", stream=True) |