Last active
November 21, 2017 11:41
-
-
Save zmonoid/8d1fc625b3ca43073be276e494f50ac1 to your computer and use it in GitHub Desktop.
Read Single Tensorflow Image
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Classify single images with a pretrained TF-Slim model.

Setup:
1. Put this file in the slim folder, e.g. models/research/slim/classify_single_image.py.
2. Save the checkpoint files under slim/checkpoints/*.ckpt.
3. Save the class-id-to-label text file as slim/imagenet1000_clsid_to_human.txt
   (download it from https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a).
4. Put the test images (*.jpg) in the slim folder.
5. Run it!
""" | |
import tensorflow as tf | |
import numpy as np | |
from nets.nets_factory import get_network_fn | |
from preprocessing.preprocessing_factory import get_preprocessing | |
import matplotlib.pyplot as plt | |
class Classifier:
    """Single-image classifier built from a TF-Slim network definition.

    Attributes:
        name: Slim model name passed to ``get_network_fn`` and
            ``get_preprocessing``.
        name_scope: Scope string derived from ``name``; for inception models
            it is the CamelCase form (e.g. 'inception_v3' -> 'InceptionV3').
        num_classes: 1000 for VGG models, 1001 for everything else.
        weights: Checkpoint path, './checkpoints/<name>.ckpt'.
    """

    def __init__(self, name='resnet_v2_50'):
        """Derive scope, class count and checkpoint path from the model name.

        Args:
            name: Slim model name. Defaults to 'resnet_v2_50', the value that
                was previously hard-coded.
        """
        self.name = name
        if name.startswith('inception'):
            # 'inception_resnet_v2' -> 'InceptionResnetV2'.
            # Presumably this matches the variable scope slim uses for the
            # inception networks - confirm against the net definitions.
            self.name_scope = name.replace('_', ' ').title().replace(' ', '')
        else:
            self.name_scope = name
        # VGG models are built with 1000 outputs, all others with 1001;
        # build_graph() shifts the predicted index down by one for the latter.
        self.num_classes = 1000 if name.startswith('vgg') else 1001
        self.weights = './checkpoints/' + name + '.ckpt'

    def build_graph(self):
        """Create placeholders, preprocessing, the network and the argmax op.

        Sets ``self.image``, ``self.label``, ``self.input_size``,
        ``self.logits`` and ``self.label_idx`` on the instance.
        """
        is_vgg = self.name.startswith('vgg')

        with tf.name_scope('input'):
            # One image at a time: height x width x 3 channels, float32.
            self.image = tf.placeholder(tf.float32, [None, None, 3], name='image')
            self.label = tf.placeholder(tf.int32, shape=[None])

        network_fn = get_network_fn(self.name, num_classes=self.num_classes)
        self.input_size = network_fn.default_image_size

        process_image_fn = get_preprocessing(self.name)
        processed_image = process_image_fn(self.image, self.input_size, self.input_size)
        # The network is fed a batch, so add a leading batch axis of size 1.
        processed_image = tf.expand_dims(processed_image, 0)
        if not is_vgg:
            # NOTE(review): non-VGG inputs are divided by 255 *after*
            # preprocessing; kept exactly as in the original - confirm this
            # matches the value range the preprocessing function expects.
            processed_image /= 255.0
        # Prints the symbolic tensor (name/shape/dtype), not pixel values.
        print(processed_image)

        self.logits, _ = network_fn(processed_image)
        self.label_idx = tf.argmax(self.logits, 1)
        if not is_vgg:
            # Shift the 1001-way prediction down by one so it indexes a
            # 1000-entry label table (index 0 presumably is a background
            # class - confirm); a prediction of 0 therefore becomes -1.
            self.label_idx -= 1
def read_images(image_size):
    """Load every ``*.jpg`` in the current directory as a square RGB array.

    Args:
        image_size: Side length, in pixels, of the square output images.

    Returns:
        An iterable of numpy arrays, one per image, in sorted filename order.
        The arrays are produced by ``map``, so on Python 3 they are loaded
        lazily as the result is iterated.
    """
    from PIL import ImageOps, Image
    import glob

    def read_and_resize_image(image_path):
        # `with` closes the underlying file once the pixels have been copied.
        with Image.open(image_path) as source:
            # Force three channels: grayscale, CMYK or RGBA JPEGs would
            # otherwise yield arrays whose last axis is not 3.
            rgb = source.convert('RGB')
            fitted = ImageOps.fit(rgb, (image_size, image_size))
        return np.array(fitted)

    # glob order is filesystem-dependent; sort for reproducible output.
    image_paths = sorted(glob.glob('*.jpg'))
    return map(read_and_resize_image, image_paths)
def read_label_table():
    """Load the class-id-to-label table from the current directory.

    Reads 'imagenet1000_clsid_to_human.txt', which is expected to contain a
    single Python literal (e.g. a dict mapping class ids to label strings).

    Returns:
        The object described by the literal in the file.

    Raises:
        ValueError, SyntaxError: If the file content is not a valid Python
            literal.
    """
    # Local import keeps this block self-contained.
    import ast

    with open('imagenet1000_clsid_to_human.txt', 'r') as f:
        content = f.read()
    # literal_eval only accepts literals, so a tampered label file cannot
    # execute arbitrary code the way eval() would.
    return ast.literal_eval(content)
def main():
    """Classify every image returned by read_images() and print the results.

    Builds the classifier graph, restores the checkpoint at ``model.weights``,
    then for each image prints the predicted class id followed by its label.
    """
    model = Classifier()
    model.build_graph()

    # Restore only the variables collected under the model's scope.
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=model.name_scope)
    saver = tf.train.Saver(var_list)

    images = read_images(model.input_size)
    table = read_label_table()

    with tf.Session() as sess:
        saver.restore(sess, model.weights)
        for image in images:
            label_idx = model.label_idx.eval(feed_dict={model.image: image}, session=sess)
            class_id = label_idx[0]
            print(class_id)
            try:
                label = table[class_id]
            except KeyError:
                # build_graph() shifts non-VGG predictions by -1, so the id
                # can be -1, which has no entry in the label table.
                label = '<no label for class id %d>' % class_id
            print(label)
if __name__ == "__main__":
    # Run the classification demo only when executed as a script.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment