Skip to content

Instantly share code, notes, and snippets.

@shabazpatel
Created June 12, 2018 01:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shabazpatel/202a21901e34d9a50c504e63fdd86422 to your computer and use it in GitHub Desktop.
Save shabazpatel/202a21901e34d9a50c504e63fdd86422 to your computer and use it in GitHub Desktop.
Prediction using a 152-layer ResNet (ResNet-152) classifier trained on the ImageNet-11k dataset
import os
import mxnet as mx
# Base URL of the pretrained ImageNet-11k models published by MXNet.
path = 'http://data.mxnet.io/models/imagenet-11k/'

# Fetch the network definition, its weights and the class-name list into the
# working directory, skipping any file that is already present. The existence
# checks assume download() saves each file under the last component of its URL.
for remote_name in ('resnet-152/resnet-152-symbol.json',
                    'resnet-152/resnet-152-0000.params',
                    'synset.txt'):
    if not os.path.exists(os.path.basename(remote_name)):
        mx.test_utils.download(path + remote_name)

# Load epoch 0 of the 'resnet-152' prefix
# (resnet-152-symbol.json + resnet-152-0000.params).
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)

# Inference-only module on CPU with no label input and a batch of one
# 3x224x224 image.
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
# With label_names=None there are no label shapes, so pass None explicitly
# instead of reading the module's private ``_label_shapes`` attribute.
mod.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))],
         label_shapes=None)
mod.set_params(arg_params, aux_params, allow_missing=True)

# One class name per line; predict() indexes this list with the network's
# output positions.
with open('synset.txt', 'r') as f:
    labels = [line.rstrip() for line in f]
import cv2
import urllib
import numpy as np
# Lightweight batch container: predict() wraps the list of input arrays in it
# (exposed as ``.data``) before handing it to mod.forward.
from collections import namedtuple
Batch = namedtuple('Batch', 'data')
def get_image(url):
# download and show the image
url_response = urllib.urlopen((url))
img_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
image = cv2.imdecode(img_array, -1)
img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if img is None:
return None
# convert into format (batch, RGB, width, height)
img = cv2.resize(img, (224, 224))
img = np.swapaxes(img, 0, 2)
img = np.swapaxes(img, 1, 2)
img = img[np.newaxis, :]
return img
def predict(self, params):
url = params['image_url']
img = get_image(url)
# compute the predict probabilities
mod.forward(Batch([mx.nd.array(img)]))
prob = mod.get_outputs()[0].asnumpy()
# print the top-5
prob = np.squeeze(prob)
a = np.argsort(prob)[::-1]
classes = []
for i in a[0:5]:
classes.append({'porbability': str(prob[i]), "class": labels[i]})
return {"result":classes}
#print predict('http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg')
#print predict('http://thenotoriouspug.com/wp-content/uploads/2015/01/Pug-Cookie-1920x1080-1024x576.jpg')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment