Skip to content

Instantly share code, notes, and snippets.

@greed2411
Created January 2, 2018 13:32
Show Gist options
  • Save greed2411/02a50712c8d33aa514ea9d08b0a8a3e8 to your computer and use it in GitHub Desktop.
Save greed2411/02a50712c8d33aa514ea9d08b0a8a3e8 to your computer and use it in GitHub Desktop.
Script that loads the PyTorch model, preprocesses an incoming image, and outputs the skin-cancer detection result.
import torch
import torchvision
from torchvision import datasets, transforms, models
import torch.nn as nn
from torch.autograd import Variable
from PIL import Image
import numpy as np
# >>> torch.__version__
# '0.2.0_4'
# >>> numpy.__version__
# '1.11.3'
# >>> PIL.__version__
# '4.0.0'
# Keep the PyTorch model in the same directory as this script and the images in a directory called 'validating'; change these paths as needed.
# Paths used by main(); change them to match where the weights and images live.
MODEL_PATH = './best_model.pth'
IMAGE_PATH = './validating/ISIC_0012151.jpg'


def load_model(weights_path=MODEL_PATH):
    """Build a ResNet-50 with a 2-class head and load fine-tuned weights on CPU.

    Args:
        weights_path: path to a state_dict saved with ``torch.save``.

    Returns:
        The model, switched to evaluation mode.
    """
    model = models.resnet50()
    # Replace the ImageNet head with the 2-way classifier the checkpoint was
    # trained with (2048 is the input width of ResNet-50's fc layer).
    model.fc = nn.Linear(2048, 2)
    # Returning each storage unchanged keeps every tensor on the CPU, whatever
    # device it was saved from; a {'cuda:0': 'cpu'} mapping only covers GPU 0.
    state_dict = torch.load(weights_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    # Inference must run in eval mode so BatchNorm uses its running statistics
    # rather than the statistics of this single-image batch.
    model.eval()
    return model


def build_preprocess():
    """Return the transform: resize to 256, center-crop 224, tensorize, normalize."""
    # ``Scale`` was renamed ``Resize`` in torchvision and later removed; use
    # whichever name this torchvision version provides.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale
    return transforms.Compose([
        resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std; presumably the values used
        # during training -- confirm against the training script.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])


def load_image(path):
    """Open the image at ``path`` and return it as a 3-channel RGB PIL image."""
    with Image.open(path) as img:
        # convert() reads the pixels into a new image, so the file can be
        # closed on exit; forcing RGB yields the 3 channels Normalize expects
        # even for grayscale, palette or RGBA inputs.
        return img.convert('RGB')


def label_from_scores(scores):
    """Map the model's two class scores to 'benign' or 'malignant'.

    Index 0 is melanoma (reported as 'malignant') and index 1 is keratosis
    (reported as 'benign'). Assumes ``scores`` holds the two scores of a
    single image; on a tie, argmax picks index 0.
    """
    return 'benign' if int(np.argmax(scores)) == 1 else 'malignant'


def main(image_path=IMAGE_PATH, weights_path=MODEL_PATH):
    """Classify one image, print the raw scores and the label, and return the label."""
    model = load_model(weights_path)
    preprocess = build_preprocess()
    # Add a leading batch dimension of size 1.
    input_image = preprocess(load_image(image_path)).unsqueeze_(0)
    prediction = model(Variable(input_image))
    # .data drops the autograd wrapper so the result can become a NumPy array.
    result = prediction.cpu().data.numpy()
    print(result)
    str_label = label_from_scores(result)
    print(str_label)
    return str_label


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment