Skip to content

Instantly share code, notes, and snippets.

@zmonoid
Last active November 21, 2017 11:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zmonoid/8d1fc625b3ca43073be276e494f50ac1 to your computer and use it in GitHub Desktop.
Save zmonoid/8d1fc625b3ca43073be276e494f50ac1 to your computer and use it in GitHub Desktop.
Read Single Tensorflow Image
"""
Put this file in the slim folder, e.g. models/research/slim/classify_single_image.py.
Save the checkpoint files as slim/checkpoints/*.ckpt.
Save the class-id-to-label text file as slim/imagenet1000_clsid_to_human.txt (download it from https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a).
Put the test images (*.jpg) in the slim folder.
Then run the script from the slim folder.
"""
import tensorflow as tf
import numpy as np
from nets.nets_factory import get_network_fn
from preprocessing.preprocessing_factory import get_preprocessing
import matplotlib.pyplot as plt
class Classifier:
    """Single-image classifier built from a TF-Slim network and checkpoint.

    Attributes:
        name: Slim model name handed to the network and preprocessing
            factories (e.g. ``'resnet_v2_50'``, ``'inception_v3'``,
            ``'vgg_16'``).
        name_scope: Scope string used by callers to collect the model's
            variables for restoring.
        num_classes: Width of the logits layer: 1000 for ``vgg*`` models,
            1001 for everything else.
        weights: Path of the checkpoint file to restore.
        image: Placeholder for one float32 image (set by ``build_graph``).
        label: int32 placeholder (set by ``build_graph``; not used by the
            inference graph itself).
        input_size: Default input side length of the network (set by
            ``build_graph``).
        logits: Network output tensor (set by ``build_graph``).
        label_idx: Predicted class-id tensor (set by ``build_graph``).
    """

    def __init__(self, name='resnet_v2_50', weights=None):
        """Derive scope, class count and checkpoint path from the model name.

        Args:
            name: Slim model name. Defaults to ``'resnet_v2_50'``, which is
                what this class previously hard-coded.
            weights: Checkpoint path. When ``None`` the path
                ``./checkpoints/<name>.ckpt`` is used.
        """
        self.name = name
        if name.startswith('inception'):
            # 'inception_v3' -> 'Inception V3' -> 'InceptionV3'
            self.name_scope = name.replace('_', ' ').title().replace(' ', '')
        else:
            self.name_scope = name
        self.num_classes = 1000 if name.startswith('vgg') else 1001
        if weights is None:
            weights = './checkpoints/' + name + '.ckpt'
        self.weights = weights

    def build_graph(self):
        """Create the placeholders, preprocessing and network in the default graph.

        Sets ``image``, ``label``, ``input_size``, ``logits`` and
        ``label_idx`` on the instance.
        """
        with tf.name_scope('input'):
            self.image = tf.placeholder(tf.float32, [None, None, 3], name='image')
            self.label = tf.placeholder(tf.int32, shape=[None])

        # Inference only: request the evaluation variants explicitly instead
        # of relying on the factories' defaults.
        network_fn = get_network_fn(
            self.name, num_classes=self.num_classes, is_training=False)
        self.input_size = network_fn.default_image_size
        process_image_fn = get_preprocessing(self.name, is_training=False)

        processed_image = process_image_fn(
            self.image, self.input_size, self.input_size)
        # The network expects a batch dimension; feed a batch of one.
        processed_image = tf.expand_dims(processed_image, 0)
        is_vgg = self.name.startswith('vgg')
        if not is_vgg:
            # NOTE(review): extra 1/255 rescale for non-VGG models; presumably
            # compensates for feeding 0-255 pixel values -- confirm against
            # the preprocessing function selected for this model.
            processed_image /= 255.0
        print(processed_image)

        self.logits, _ = network_fn(processed_image)
        self.label_idx = tf.argmax(self.logits, 1)
        if not is_vgg:
            # Non-VGG models are built with 1001 outputs; shifting by one
            # presumably skips an extra leading class so the ids line up with
            # a 1000-entry label table -- confirm against the checkpoint.
            self.label_idx -= 1
def read_images(image_size):
    """Load every ``*.jpg`` in the current directory as a square RGB array.

    Args:
        image_size: Side length, in pixels, of the square each image is
            cropped and resized to.

    Returns:
        An iterable of numpy arrays of shape
        ``(image_size, image_size, 3)``, one per file, in sorted file-name
        order.
    """
    from PIL import ImageOps, Image
    import glob

    def read_and_resize_image(image_path):
        # The context manager closes the underlying file once the pixel data
        # has been copied out by convert().
        with Image.open(image_path) as raw:
            # Force three channels: grayscale, CMYK or RGBA files would
            # otherwise yield arrays that do not fit a [H, W, 3] input.
            rgb = raw.convert('RGB')
        fitted = ImageOps.fit(rgb, (image_size, image_size))
        return np.array(fitted)

    # Sorted so results come out in a reproducible order.
    image_paths = sorted(glob.glob('*.jpg'))
    return map(read_and_resize_image, image_paths)
def read_label_table():
    """Load the class-id-to-label table from the current directory.

    Reads ``imagenet1000_clsid_to_human.txt``, which is expected to contain
    a single Python literal (the referenced file is a dict literal).

    Returns:
        The object described by the literal in the file.

    Raises:
        ValueError, SyntaxError: If the file content is not a plain Python
            literal.
    """
    import ast

    with open('imagenet1000_clsid_to_human.txt', 'r') as f:
        content = f.read()
    # SECURITY: this used to be eval(content), which would execute arbitrary
    # code placed in the label file. literal_eval only accepts literals.
    return ast.literal_eval(content)
def main():
    """Build the classifier, restore its weights and label every image.

    For each image returned by ``read_images`` the predicted class id is
    printed, followed by the matching entry of the label table.
    """
    classifier = Classifier()
    classifier.build_graph()

    # Restore only the variables collected under the model's own scope.
    model_variables = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES,
        scope=classifier.name_scope,
    )
    restorer = tf.train.Saver(model_variables)

    images = read_images(classifier.input_size)
    table = read_label_table()

    with tf.Session() as sess:
        restorer.restore(sess, classifier.weights)
        for pixels in images:
            # One image per run; the result is a length-1 batch of ids.
            predicted = sess.run(
                classifier.label_idx,
                feed_dict={classifier.image: pixels},
            )
            class_id = predicted[0]
            print(class_id)
            print(table[class_id])
# Run the classification only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment