Created
February 27, 2018 01:25
-
-
Save Azadehkhojandi/07f2bfe521d318c3ec956a8ff263ffe6 to your computer and use it in GitHub Desktop.
Using an exported Custom Vision TensorFlow model in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import os | |
import tensorflow as tf | |
import numpy as np | |
import scipy | |
from PIL import Image | |
from scipy import misc | |
# Paths below are resolved against the directory the script is launched from.
cwd = os.getcwd()
modelFile = '/'.join([cwd, 'model.pb'])

# Side length (pixels) of the square crop fed to the network.
network_input_size = 227

# Tensor names inside the exported graph: the fed input and the scored output.
input_node = 'Placeholder:0'
output_layer = 'loss:0'

# Parsed GraphDef; populated by load_graph().
graph_def = None
def load_graph(filename, labels_filename=None):
    """Load a frozen TensorFlow graph and its class labels into module globals.

    Side effects:
      * sets the global ``graph_def`` to the ``tf.GraphDef`` parsed from
        *filename* and imports it into the current default graph;
      * sets the global ``labels`` to the lines of the labels file, with
        trailing whitespace stripped;
      * prints ``labels``.

    Args:
        filename: path to the serialized GraphDef (e.g. ``model.pb``).
        labels_filename: path to the labels text file, one label per line.
            Defaults to ``labels.txt`` in the current working directory,
            which was the only location supported previously.
    """
    global graph_def, labels
    if labels_filename is None:
        labels_filename = cwd + '/labels.txt'
    with tf.gfile.FastGFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    # Use a context manager so the labels file is closed deterministically
    # rather than whenever the garbage collector gets to it.
    with tf.gfile.GFile(labels_filename) as label_file:
        labels = [line.rstrip() for line in label_file]
    print(labels)
def center_crop(img, cropx, cropy):
    """Return the central ``cropy`` x ``cropx`` window of an image array.

    Works for 2-D (height x width) arrays as well as arrays with trailing
    axes such as channels (height x width x channels); trailing axes are
    kept unchanged.

    Args:
        img: array indexed as ``img[row, col, ...]``.
        cropx: crop width, in columns.
        cropy: crop height, in rows.

    Returns:
        A slice of *img* with shape ``(cropy, cropx) + img.shape[2:]``.

    Raises:
        ValueError: if the crop is negative or larger than the image. Without
            this check an oversize crop yields a negative start index, which
            slices from the end of the array and silently returns garbage.
    """
    y, x = img.shape[:2]
    if not (0 <= cropx <= x and 0 <= cropy <= y):
        raise ValueError(
            'crop %dx%d does not fit in image %dx%d' % (cropx, cropy, x, y))
    startx = x // 2 - (cropx // 2)
    starty = y // 2 - (cropy // 2)
    return img[starty:starty + cropy, startx:startx + cropx]
def _prepare_image(path):
    """Load *path* as a 1 x H x W x 3 float batch in BGR order, means subtracted.

    Returns None after printing ``<path> :-1 :-1`` if the center crop fails.
    """
    # scipy.misc.imread/imresize were removed from SciPy, so PIL is used
    # instead. convert('RGB') also turns grayscale, RGBA and palette images
    # into 3 channels, which the BGR step below requires.
    with Image.open(path) as pil_image:
        rgb = pil_image.convert('RGB')
    # Bilinear matches scipy.misc.imresize's default interpolation. PIL takes
    # (width, height), but the target is square so the order is moot.
    resized = np.asarray(rgb.resize((227, 227), Image.BILINEAR))
    try:
        cropped = center_crop(resized, network_input_size, network_input_size)
    except ValueError:
        print(path, ":-1", ":-1")
        return None
    pixels = cropped.astype(float)
    # RGB -> BGR, then subtract the per-channel means (B=104, G=117, R=124).
    bgr = pixels[:, :, ::-1] - np.array([104., 117., 124.])
    return np.expand_dims(bgr, axis=0)


def _report_predictions(predictions):
    """Print the best label followed by the top-5 labels and their scores."""
    idx = np.argmax(predictions)
    print('-----------------------')
    print('Image classified as: ' + labels[idx])
    print('---------------------------')
    top_k = predictions.argsort()[-5:][::-1]
    for node_id in top_k:
        print('%s (score = %.5f)' % (labels[node_id], predictions[node_id]))


def classifyImage(testImage):
    """Classify one image file with the loaded graph and print the results.

    Rebuilds the default graph from the global ``graph_def`` (set by
    ``load_graph``), preprocesses the image, evaluates ``output_layer`` with
    the image fed to ``input_node``, and prints labels from ``labels``.

    Args:
        testImage: path to the image file.

    Returns:
        None. If the center crop fails, prints ``<testImage> :-1 :-1`` and
        returns early.
    """
    tf.reset_default_graph()
    tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
        prob_tensor = sess.graph.get_tensor_by_name(output_layer)
        batch = _prepare_image(testImage)
        if batch is None:
            return
        # The result carries a leading batch axis of size 1; unpack its row.
        predictions, = sess.run(prob_tensor, {input_node: batch})
    _report_predictions(predictions)
# Script entry: load the model and labels, then classify the image named on
# the command line (relative paths resolve against the working directory).
load_graph(modelFile)
if len(sys.argv) != 2:
    print('Cannot classify - Usage: classify.py file path.')
else:
    # os.path.join leaves an absolute argument intact instead of gluing it
    # onto cwd (which produced paths like "/cwd//abs/img.jpg").
    classifyImage(os.path.join(cwd, sys.argv[1]))
I'm also at @microsoft. I used most of your code here: https://github.com/berndverst/AzureCustomVision-TensorFlowImageClassification
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This was super helpful years later. It's still relevant today, as long as you replace `tf` with `tf.compat.v1`, given that pip installs TensorFlow 2 by default now.