Skip to content

Instantly share code, notes, and snippets.

@zmonoid
Last active November 21, 2017 14:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zmonoid/f4572c5b8ac9e6b149b1285248316703 to your computer and use it in GitHub Desktop.
Save zmonoid/f4572c5b8ac9e6b149b1285248316703 to your computer and use it in GitHub Desktop.
Evaluate a classifier on ImageNet TFRecords
import tensorflow as tf
import numpy as np
from nets.nets_factory import get_network_fn
from preprocessing.preprocessing_factory import get_preprocessing
import glob
class Config:
    """Static settings for evaluating one TF-slim checkpoint on ImageNet shards."""

    # Network to evaluate; the checkpoint path and class count derive from it.
    name = 'resnet_v2_50'
    weights = './checkpoints/%s.ckpt' % name
    # VGG uses 1000 outputs; every other network uses 1001
    # (presumably an extra background class at index 0 -- confirm per checkpoint).
    if name.startswith('vgg'):
        num_classes = 1000
    else:
        num_classes = 1001

    # Input pipeline settings.
    record_pattern = '/home/bzhou/Dataset/imagenet/tfrecords/validation*'
    batch_size = 32
    num_epoch = 1

    # Network constructor; used here to read the input resolution it expects.
    network_fn = get_network_fn(name, num_classes)
    input_image_size = network_fn.default_image_size
class Classifier:
    """Build an evaluation graph for a TF-slim classification network.

    Attributes set by ``__init__``:
        config: the settings object passed in.
        name: network name, e.g. ``'resnet_v2_50'``.
        name_scope: variable scope holding the network's weights (used to
            select variables to restore).
        weights: checkpoint path.

    Attributes set by ``build_graph``:
        image, label: input placeholders.
        logits: network output.
        label_idx: predicted class index, aligned with 0-based labels.
        equals: boolean tensor, True where prediction matches ``label``.
    """

    def __init__(self, config):
        self.config = config
        self.name = config.name
        if self.name.startswith('inception'):
            # 'inception_resnet_v2' -> 'InceptionResnetV2': Inception models
            # use CamelCase variable scopes rather than the factory name.
            self.name_scope = self.name.replace('_', ' ').title().replace(' ', '')
        else:
            self.name_scope = self.name
        self.weights = config.weights

    def build_graph(self):
        """Create placeholders, the network, and prediction/accuracy tensors."""
        size = self.config.input_image_size
        with tf.name_scope('input'):
            self.image = tf.placeholder(tf.float32, [None, size, size, 3], name='image')
            self.label = tf.placeholder(tf.int64, shape=[None])
        network_fn = get_network_fn(self.name, self.config.num_classes)
        self.logits, _ = network_fn(self.image)
        self.label_idx = tf.argmax(self.logits, 1)
        # A 1001-way output reserves index 0 (presumably background), so shift
        # predictions down by one. Deriving this from num_classes rather than the
        # network name keeps it consistent with whatever width the net was built with.
        if self.config.num_classes == 1001:
            self.label_idx -= 1
        self.equals = tf.equal(self.label, self.label_idx)
def make_iterator(config):
    """Build an initializable iterator of preprocessed (image, label) batches.

    Args:
        config: object providing ``record_pattern``, ``name``,
            ``input_image_size``, ``batch_size`` and ``num_epoch``.

    Returns:
        A ``tf.data`` initializable iterator; run its ``initializer`` before
        calling ``get_next()``.

    Raises:
        ValueError: if ``config.record_pattern`` matches no files.
    """
    # glob order is arbitrary; sort so shard order is reproducible.
    filenames = sorted(glob.glob(config.record_pattern))
    if not filenames:
        # Otherwise the empty dataset only shows up later as an opaque
        # OutOfRangeError on the first get_next().
        raise ValueError('No TFRecord files match %r' % config.record_pattern)
    dataset = tf.data.TFRecordDataset(filenames)
    preprocess_fcn = get_preprocessing(config.name)

    def record_parser(record):
        """Decode one serialized example into a preprocessed image and its label."""
        features = {
            "image/encoded": tf.FixedLenFeature((), tf.string, default_value=""),
            "image/class/label": tf.FixedLenFeature((), tf.int64, default_value=0),
        }
        parsed_features = tf.parse_single_example(record, features)
        label = parsed_features["image/class/label"]
        # channels=3 expands grayscale JPEGs to RGB and gives a static rank-3
        # shape, replacing the earlier squeeze/cond/stack workaround.
        image = tf.image.decode_jpeg(parsed_features["image/encoded"], channels=3)
        # No extra "/ 255" scaling: TF-slim's eval preprocessing already returns
        # normalized floats (inception-style: [-1, 1]; vgg-style: mean-subtracted),
        # so dividing again would shrink inputs by 255x.
        image = preprocess_fcn(image, config.input_image_size, config.input_image_size)
        return image, label

    dataset = dataset.map(record_parser)
    dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(config.batch_size)
    dataset = dataset.repeat(config.num_epoch)
    return dataset.make_initializable_iterator()
def main(num_batches=10):
    """Restore the configured checkpoint and report top-1 accuracy.

    Prints per-batch accuracy, predictions and labels, then the accuracy
    accumulated over all evaluated batches.

    Args:
        num_batches: maximum number of batches to evaluate (default 10).
            Stops early, without error, if the dataset runs out first.

    Raises:
        ValueError: if no variables exist under the model's name scope, so
            nothing could be restored from the checkpoint.
    """
    config = Config()
    itr = make_iterator(config)
    data = itr.get_next()
    model = Classifier(config)
    model.build_graph()
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=model.name_scope)
    if not var_list:
        raise ValueError('No variables under scope %r; cannot restore %s'
                         % (model.name_scope, model.weights))
    saver = tf.train.Saver(var_list)
    num_correct = 0
    num_seen = 0
    with tf.Session() as sess:
        sess.run(itr.initializer)
        saver.restore(sess, model.weights)
        for _ in range(num_batches):
            try:
                images, labels = sess.run(data)
            except tf.errors.OutOfRangeError:
                break  # fewer than num_batches batches available
            # Shift record labels down by one to align with model.label_idx
            # (presumably the records store 1-based labels -- confirm).
            labels -= 1
            equals, label_idx = sess.run(
                [model.equals, model.label_idx],
                feed_dict={model.image: images, model.label: labels})
            num_correct += int(np.sum(equals))
            num_seen += len(equals)
            print(np.mean(equals))
            print(label_idx)
            print('...')
            print(labels)
            print('-------')
    if num_seen:
        print('overall accuracy: %.4f (%d/%d)'
              % (float(num_correct) / num_seen, num_correct, num_seen))


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment