Created
June 12, 2018 01:57
-
-
Save shabazpatel/202a21901e34d9a50c504e63fdd86422 to your computer and use it in GitHub Desktop.
Prediction using a ResNet-152 classifier trained on the ImageNet-11k dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import mxnet as mx

# Base URL of the pre-trained ResNet-152 (ImageNet-11k) model files.
path = 'http://data.mxnet.io/models/imagenet-11k/'

# Fetch each artifact only if there is no local copy yet. The existence
# checks use bare file names; presumably mx.test_utils.download saves to the
# basename of the URL by default (NOTE(review): confirm against MXNet docs).
for _local_name, _remote_suffix in (
    ('resnet-152-symbol.json', 'resnet-152/resnet-152-symbol.json'),
    ('resnet-152-0000.params', 'resnet-152/resnet-152-0000.params'),
    ('synset.txt', 'synset.txt'),
):
    if not os.path.exists(_local_name):
        mx.test_utils.download(path + _remote_suffix)
# Drop the loop temporaries so they do not linger as module attributes.
del _local_name, _remote_suffix

# Load epoch 0 of the 'resnet-152' checkpoint, i.e. the symbol JSON and the
# 0000 params file fetched above.
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)

# Inference-only module on the CPU; no label inputs are declared.
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
mod.bind(
    for_training=False,
    # A single 3x224x224 image per forward pass.
    data_shapes=[('data', (1, 3, 224, 224))],
    label_shapes=mod._label_shapes,
)
mod.set_params(arg_params, aux_params, allow_missing=True)

# One class name per line; predict() pairs labels[i] with output score i.
with open('synset.txt', 'r') as f:
    labels = [entry.rstrip() for entry in f]
from collections import namedtuple
import urllib

import numpy as np
import cv2

# Lightweight batch wrapper handed to mod.forward(); its single field
# ``data`` carries the list of input arrays for one forward pass.
Batch = namedtuple('Batch', 'data')
def get_image(url):
    """Download an image and return it as a (1, 3, 224, 224) uint8 array.

    Args:
        url: Address of an encoded image (any format ``cv2.imdecode`` can
            parse).

    Returns:
        A numpy array laid out as (batch, RGB, height, width) and resized to
        224x224. Returns ``None`` if the downloaded bytes cannot be decoded
        as an image.
    """
    from contextlib import closing

    # urlopen lives in urllib on Python 2 and in urllib.request on Python 3.
    try:
        from urllib.request import urlopen
    except ImportError:
        from urllib import urlopen

    # closing() guarantees the connection is released even if read() raises.
    with closing(urlopen(url)) as url_response:
        img_array = np.array(bytearray(url_response.read()), dtype=np.uint8)

    # IMREAD_COLOR always decodes to 3-channel 8-bit BGR. This keeps the
    # BGR->RGB conversion below valid for grayscale or 16-bit sources too.
    image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    # imdecode signals undecodable data with None. Check it before
    # cvtColor, which would otherwise raise on a None input.
    if image is None:
        return None
    img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    img = cv2.resize(img, (224, 224))
    # (height, width, RGB) -> (RGB, height, width), then add a batch axis.
    img = np.transpose(img, (2, 0, 1))
    return img[np.newaxis, :]
def predict(self, params):
    """Classify the image at ``params['image_url']`` and return the top 5 classes.

    ``self`` is not used by the body. NOTE(review): presumably the signature
    is dictated by whatever framework invokes this function; confirm with
    the caller.

    Args:
        params: Mapping with an ``'image_url'`` entry pointing at the image.

    Returns:
        ``{"result": [...]}``, a list of five dicts ordered from most to
        least probable. Each dict has:

        * ``"class"``: the label from ``synset.txt``.
        * ``"probability"``: the score as a string.
        * ``"porbability"``: the same value under the legacy misspelled
          key, kept so existing consumers keep working.

    Raises:
        ValueError: If the downloaded data cannot be decoded as an image.
    """
    url = params['image_url']
    img = get_image(url)
    if img is None:
        # Fail with a clear message instead of feeding None to mx.nd.array.
        raise ValueError('could not decode an image from %r' % (url,))

    # Run the network and take the first output as per-class probabilities.
    mod.forward(Batch([mx.nd.array(img)]))
    prob = np.squeeze(mod.get_outputs()[0].asnumpy())

    # Indices of the five highest-scoring classes, best first.
    top = np.argsort(prob)[::-1][:5]
    classes = []
    for i in top:
        score = str(prob[i])
        classes.append({
            'probability': score,
            'porbability': score,  # legacy misspelled key, kept for compatibility
            "class": labels[i],
        })
    return {"result": classes}
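# Example usage (predict takes an unused `self` plus a params dict):
# print(predict(None, {'image_url': 'http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg'}))
# print(predict(None, {'image_url': 'http://thenotoriouspug.com/wp-content/uploads/2015/01/Pug-Cookie-1920x1080-1024x576.jpg'}))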
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment