Last active
November 21, 2017 14:34
-
-
Save zmonoid/f4572c5b8ac9e6b149b1285248316703 to your computer and use it in GitHub Desktop.
Evaluate ImageNet TFRecord
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
from nets.nets_factory import get_network_fn | |
from preprocessing.preprocessing_factory import get_preprocessing | |
import glob | |
class Config:
    """Static settings for one evaluation run.

    All values are class attributes, so they are computed once when the class
    body executes (including the call to ``get_network_fn``).
    """

    # Model identifier handed to the network and preprocessing factories.
    name = 'resnet_v2_50'

    # Input pipeline settings.
    record_pattern = '/home/bzhou/Dataset/imagenet/tfrecords/validation*'
    batch_size = 32
    num_epoch = 1

    # Checkpoint path is derived from the model name.
    weights = './checkpoints/{}.ckpt'.format(name)

    # Names starting with 'vgg' get 1000 output classes; every other name 1001.
    if name.startswith('vgg'):
        num_classes = 1000
    else:
        num_classes = 1001

    # The network function carries the default input resolution for the model.
    network_fn = get_network_fn(name, num_classes)
    input_image_size = network_fn.default_image_size
class Classifier:
    """Wraps a network from ``get_network_fn`` with input placeholders and a
    per-example correctness op."""

    def __init__(self, config):
        """Store the settings needed to build and restore the model.

        Args:
            config: object providing ``name`` and ``weights``; it is kept on
                the instance because ``build_graph`` reads further attributes
                (``input_image_size``, ``num_classes``) from it.
        """
        self.config = config
        self.name = config.name
        self.weights = config.weights

        # Inception names are converted to CamelCase for the variable scope,
        # e.g. 'inception_v3' -> 'InceptionV3'; other names are used as-is.
        scope = self.name
        if scope.startswith('inception'):
            spaced = scope.replace('_', ' ')
            scope = spaced.title().replace(' ', '')
        self.name_scope = scope

    def build_graph(self):
        """Create placeholders, the network, and the prediction/equality ops.

        Sets ``self.image``, ``self.label``, ``self.logits``,
        ``self.label_idx`` and ``self.equals``.
        """
        side = self.config.input_image_size
        with tf.name_scope('input'):
            self.image = tf.placeholder(tf.float32, [None, side, side, 3], name='image')
            self.label = tf.placeholder(tf.int64, shape=[None])

        build_network = get_network_fn(self.name, self.config.num_classes)
        self.logits, _ = build_network(self.image)

        predicted = tf.argmax(self.logits, 1)
        if not self.name.startswith('vgg'):
            # Non-VGG models are built with one extra class (see
            # Config.num_classes), so shift predictions down by one.
            predicted -= 1
        self.label_idx = predicted

        self.equals = tf.equal(self.label, self.label_idx)
def make_iterator(config):
    """Build an initializable iterator over (image, label) batches.

    Args:
        config: object providing ``record_pattern`` (glob for the TFRecord
            files), ``name`` (model name; also selects the preprocessing
            function), ``input_image_size``, ``batch_size`` and ``num_epoch``.

    Returns:
        The iterator produced by ``Dataset.make_initializable_iterator()``.
        Its ``initializer`` must be run before fetching from ``get_next()``.

    Raises:
        IOError: if no file matches ``config.record_pattern``.
    """
    # Sorted so the file order does not depend on directory listing order.
    filenames = sorted(glob.glob(config.record_pattern))
    if not filenames:
        # Fail early with a clear message instead of handing an empty file
        # list to TFRecordDataset.
        raise IOError(
            'No TFRecord files match pattern: {!r}'.format(config.record_pattern))

    dataset = tf.data.TFRecordDataset(filenames)
    preprocess_fcn = get_preprocessing(config.name)

    def record_parser(record):
        """Decode one serialized example into (preprocessed image, label)."""
        features = {
            "image/encoded": tf.FixedLenFeature((), tf.string, default_value=""),
            "image/class/label": tf.FixedLenFeature((), tf.int64, default_value=0),
        }
        parsed_features = tf.parse_single_example(record, features)
        label = parsed_features["image/class/label"]
        # channels=3 makes the decoder emit RGB directly, so grayscale JPEGs
        # need no manual squeeze/stack (which also failed on images that had
        # a height or width of 1).
        image = tf.image.decode_jpeg(parsed_features["image/encoded"], channels=3)
        image = preprocess_fcn(image, config.input_image_size, config.input_image_size)
        if not config.name.startswith('vgg'):
            # NOTE(review): extra rescale applied for every non-VGG model on
            # top of preprocess_fcn -- confirm it matches what the checkpoint
            # expects.
            image /= 255.0
        return image, label

    dataset = dataset.map(record_parser)
    dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(config.batch_size)
    dataset = dataset.repeat(config.num_epoch)
    return dataset.make_initializable_iterator()
def main(num_batches=10):
    """Restore the checkpoint and print per-batch accuracy.

    Args:
        num_batches: maximum number of batches to evaluate. Evaluation stops
            earlier if the input iterator runs out of data.

    For every batch this prints the mean of ``model.equals``, the predicted
    indices, a ``'...'`` separator, the shifted labels, and ``'-------'``.
    """
    config = Config()
    itr = make_iterator(config)
    data = itr.get_next()

    model = Classifier(config)
    model.build_graph()

    # Restore only the variables that live under the model's scope.
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=model.name_scope)
    saver = tf.train.Saver(var_list)

    with tf.Session() as sess:
        sess.run(itr.initializer)
        saver.restore(sess, model.weights)
        for _ in range(num_batches):
            try:
                images, labels = sess.run(data)
            except tf.errors.OutOfRangeError:
                # Input exhausted before num_batches were read; stop cleanly.
                break
            # Shift labels down by one so they line up with model.label_idx
            # (presumably the record labels are 1-based -- TODO confirm).
            labels -= 1
            equals, label_idx = sess.run(
                [model.equals, model.label_idx],
                feed_dict={model.image: images, model.label: labels})
            print(np.mean(equals))
            print(label_idx)
            print('...')
            print(labels)
            print('-------')
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment