Skip to content

Instantly share code, notes, and snippets.

@Azadehkhojandi
Created February 27, 2018 01:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Azadehkhojandi/07f2bfe521d318c3ec956a8ff263ffe6 to your computer and use it in GitHub Desktop.
Save Azadehkhojandi/07f2bfe521d318c3ec956a8ff263ffe6 to your computer and use it in GitHub Desktop.
Using an exported Custom Vision TensorFlow model in Python
import sys
import os
import tensorflow as tf
import numpy as np
import scipy
from PIL import Image
from scipy import misc
# Model location and graph settings. The model file is resolved against the
# current working directory (not this script's directory).
cwd = os.getcwd()
modelFile = os.path.join(cwd, 'model.pb')
# Side length, in pixels, of the square image fed to the network.
network_input_size = 227
# Tensor names inside the exported graph: the output read as per-label scores,
# and the input placeholder the preprocessed image batch is fed into.
output_layer = 'loss:0'
input_node = 'Placeholder:0'
# Both are populated by load_graph(); declared here so they always exist.
graph_def = None
labels = None
def load_graph(filename, labels_filename=None):
    """Load the frozen model graph and its class labels into module globals.

    Parses the serialized GraphDef stored in ``filename``, imports it into the
    default TensorFlow graph, and reads one label per line from the labels
    file. Sets the module-level ``graph_def`` and ``labels``.

    Args:
        filename: path to the serialized GraphDef (``model.pb``).
        labels_filename: path to the labels file; defaults to
            ``labels.txt`` in the current working directory.
    """
    global graph_def, labels
    if labels_filename is None:
        labels_filename = os.path.join(cwd, 'labels.txt')
    # GFile replaces the deprecated FastGFile; the context manager closes it.
    with tf.gfile.GFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    # Read labels inside a context manager so the file handle is released.
    with tf.gfile.GFile(labels_filename) as f:
        labels = [line.rstrip() for line in f]
    print(labels)
def center_crop(img, cropx, cropy):
    """Return the central ``cropy`` x ``cropx`` window of an image array.

    Only the first two axes (height, width) are cropped, so both 2-D
    grayscale arrays and (height, width, channels) arrays are accepted.

    Args:
        img: array whose first two axes are height and width.
        cropx: width of the window, in pixels.
        cropy: height of the window, in pixels.

    Returns:
        A view of ``img`` restricted to the centered window.

    Raises:
        ValueError: if the window is larger than the image; otherwise the
            start offsets would go negative and slicing would silently
            return the wrong region.
    """
    y, x = img.shape[:2]
    if cropx > x or cropy > y:
        raise ValueError(
            'crop %dx%d is larger than image %dx%d' % (cropx, cropy, x, y))
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    return img[starty:starty + cropy, startx:startx + cropx]
def _load_resized(path):
    """Read an image file as an RGB uint8 array resized to the network input size."""
    # convert('RGB') guarantees exactly 3 channels: grayscale, RGBA or palette
    # images would otherwise break the per-channel processing below.
    # Bilinear resampling matches the default of the removed scipy.misc.imresize.
    with Image.open(path) as img:
        resized = img.convert('RGB').resize(
            (network_input_size, network_input_size), Image.BILINEAR)
    return np.asarray(resized)


def _normalize(image):
    """Turn an RGB (height, width, 3) array into a mean-subtracted BGR batch of one."""
    # Reverse the channel axis (RGB -> BGR), then subtract the per-channel
    # means: blue 104, green 117, red 124.
    bgr = image.astype(float)[:, :, ::-1] - np.array([104., 117., 124.])
    return np.expand_dims(bgr, axis=0)


def classifyImage(testImage):
    """Classify one image file and print the best label and the top-5 scores.

    Uses the module-level ``graph_def`` and ``labels``, which must be set by
    ``load_graph`` first. If center-cropping fails, prints
    ``<path> :-1 :-1`` and returns without classifying.

    Args:
        testImage: path to the image file.
    """
    tf.reset_default_graph()
    tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
        prob_tensor = sess.graph.get_tensor_by_name(output_layer)
        image = _load_resized(testImage)
        # Center crop (a no-op after the square resize, but kept so a different
        # resize target still yields a network-sized input).
        try:
            image = center_crop(image, network_input_size, network_input_size)
        except ValueError:
            print(testImage, ":-1", ":-1")
            return
        batch = _normalize(image)
        # Unpack the single row of the size-1 batch.
        predictions, = sess.run(prob_tensor, {input_node: batch})
    idx = np.argmax(predictions)
    print('-----------------------')
    print('Image classified as: ' + labels[idx])
    print('---------------------------')
    top_k = predictions.argsort()[-5:][::-1]
    for node_id in top_k:
        score = predictions[node_id]
        print('%s (score = %.5f)' % (labels[node_id], score))
# Script entry: load the model once at import time, then classify the image
# path given on the command line.
load_graph(modelFile)
if len(sys.argv) != 2:
    print('Cannot classify - Usage: classify.py file path.', file=sys.stderr)
else:
    # os.path.join keeps an absolute argument intact; a relative one is
    # resolved against the current working directory.
    classifyImage(os.path.join(cwd, sys.argv[1]))
@berndverst
Copy link

This was super helpful years later. It's still relevant today, as long as you replace `tf` with `tf.compat.v1`, since installing TensorFlow with pip now gives you TensorFlow 2 by default.

@berndverst
Copy link

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment