|
#!/usr/bin/env python |
|
# ---------------------------------------------------------------------------- |
|
# Copyright 2015 Nervana Systems Inc. |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
# ---------------------------------------------------------------------------- |
|
|
|
from neon.backends import gen_backend |
|
from neon.util.persist import load_obj |
|
from neon.models.model import Model |
|
|
|
# import Image |
|
from PIL import Image |
|
import PIL.ImageOps as ImageOps |
|
import numpy as np |
|
import pandas as pd |
|
|
|
# Download the example image from
# https://s3-us-west-1.amazonaws.com/nervana-modelzoo/example_images/german_pointer.jpg
# and load it.
jpgfile = Image.open("../input/others/german_pointer.jpg")

# Alternative input: the first image listed in the StateFarm training set.
#
# import os
# imgs_df = pd.read_csv('../input/driver_imgs_list.csv')
# folder = '../input/train'
# images = [os.path.join(folder, r["classname"], r["img"])
#           for i, r in imgs_df.iterrows()]
# jpgfile = Image.open(images[0])

# Force 3-channel RGB before resizing. Grayscale ("L"), palette ("P"), RGBA
# or CMYK inputs would otherwise yield an array whose last axis is missing or
# is not 3, breaking the BGR flip and the per-channel mean subtraction below.
# For an image that is already RGB, convert() just returns a copy.
# Then crop/resize to 224 x 224.
im_rs = ImageOps.fit(jpgfile.convert('RGB'), (224, 224))

# Preprocess: (H, W, 3) RGB pixel array -> BGR channel order.
im = np.array(im_rs)
im = im[:, :, ::-1]

# Subtract the I1K per-channel means (see the macrobatch meta data file):
#   R_mean 104.412277
#   G_mean 119.213318
#   B_mean 126.806091
# The array is in B, G, R order at this point, so the vector below is
# (B_mean, G_mean, R_mean) as labelled in the meta file.
# NOTE(review): commonly quoted ImageNet means are roughly R~124, B~104, so the
# R/B labels above may be swapped; what matters is matching the training-time
# preprocessing -- confirm against the neon data loader.
im = im - np.array([126.806091, 119.213318, 104.412277])

# Reorder to (C, H, W) and flatten into a single feature column.
im = im.transpose((2, 0, 1)).flatten()
|
|
|
# Create a neon GPU backend that processes 32 images per minibatch.
be = gen_backend(backend='gpu', batch_size=32)

# Batches are laid out as (features, batch): one flattened image per column.
# The row count is taken from the preprocessed image rather than repeating
# 3*224*224, so it cannot drift from the preprocessing step. Columns beyond
# the first stay zero, padding the batch out to be.bsz images.
host_buf = np.zeros((im.size, be.bsz))
# Column 0 holds the image loaded above. Slice assignment already copies the
# data, so no explicit .copy() is needed.
host_buf[:, 0] = im

# Allocate a device buffer of the same shape and copy the host batch into it.
dev_buf = be.zeros(host_buf.shape)
dev_buf[:] = host_buf
|
|
|
# Load the pretrained neon AlexNet model, available from
# https://s3-us-west-1.amazonaws.com/nervana-modelzoo/alexnet/alexnet.p
# NOTE(review): load_obj presumably unpickles the file -- only load trusted
# model files.
model_dict = load_obj('../input/weights/alexnet.p')
model = Model(model_dict)

# Set up the model's buffers for the input shape stored with the weights
# (presumably (3, 224, 224) for this model -- confirm in model_dict).
model.initialize(model_dict['train_input_shape'])

# Forward pass, copy the scores back to the host, and keep column 0: the
# predictions for all 1000 classes for the single real image in the batch.
out = model.fprop(dev_buf).get()[:, 0]

# Indices of the five highest-scoring classes, best first.
catg = np.argsort(out)[::-1][:5]
|
|
|
""" |
|
# load the synsets from the ILSVCR 2012 dev kit |
|
# there may be a nicer way to load these: |
|
# may need to add the path to the devkit |
|
from scipy.io import loadmat |
|
meta = loadmat('ILSVRC2012_devkit_t12/data/meta.mat') |
|
#get the class names |
|
names = [meta['synsets'][ind][0][2][0] for ind in catg] |
|
print 'Top 5 classes for this image:' |
|
print names |
|
""" |
|
|
|
with open("neon/synsets.txt") as fl: |
|
labels = [s.rstrip() for s in fl.readlines()] |
|
print "label has {} items".format(len(labels)) |
|
|
|
names = [labels[ind] for ind in catg] |
|
print "top 5 index", catg |
|
print 'Top 5 classes for this image:' |
|
print names |
|
|
|
# for the record this dog is part German shorthair pointer |
|
|
|
""" |
|
label has 1860 items |
|
|
|
top 5 index [127 123 179 4 112] |
|
Top 5 classes for this image: |
|
['Border collie', 'collie', 'bluetick', 'English springer, English springer spaniel', 'Boston bull, Boston terrier'] |
|
|
|
top 5 index [783 953 781 865 346] |
|
Top 5 classes for this image: |
|
['football helmet', 'ballplayer, baseball player', 'maillot', 'military uniform', 'cornet, horn, trumpet, trump'] |
|
|
|
""" |