Created
April 28, 2017 07:07
-
-
Save PavlosMelissinos/9daa295d11af87848c3ea0778696eddd to your computer and use it in GitHub Desktop.
im2txt inference modified
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
r"""Generate captions for images using default beam search parameters.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import math | |
import os | |
import tensorflow as tf | |
from im2txt import configuration | |
from im2txt import inference_wrapper | |
from im2txt.inference_utils import caption_generator | |
from im2txt.inference_utils import vocabulary | |
from PIL import Image as PILImage | |
# from keras.preprocessing.image import img_to_array | |
# from keras.preprocessing import image as k_image | |
import keras | |
# Command-line flags for the script. All three are strings that default to
# the empty string.
FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string(
    "checkpoint_path", "",
    "Model checkpoint file or directory containing a model checkpoint file.")
tf.flags.DEFINE_string(
    "vocab_file", "", "Text file containing the vocabulary.")
tf.flags.DEFINE_string(
    "input_files", "",
    "File pattern or comma-separated list of file patterns of image files.")

# Emit INFO-level TensorFlow log messages (e.g. the file count logged in main).
tf.logging.set_verbosity(tf.logging.INFO)
def load_image(filename):
    """Read an image file into a pixel array.

    Args:
      filename: Path of the image file; it is opened with PIL.

    Returns:
      The array produced by `keras.preprocessing.image.img_to_array` for the
      opened image.
    """
    from keras.preprocessing.image import img_to_array

    # Use the image as a context manager so the underlying file is closed as
    # soon as the pixels have been copied into the array, instead of relying
    # on PIL's lazy loading or garbage collection to release the handle.
    with PILImage.open(filename) as pil_image:
        return img_to_array(pil_image)
def encode_image(filename):
    """Build a JPEG-encoding op for an image file inside a private graph.

    Args:
      filename: Path of the image file; it is opened with PIL.

    Returns:
      The `tf.Tensor` returned by `tf.image.encode_jpeg`. The tensor belongs
      to a `tf.Graph` created inside this function, so it must be evaluated
      in a session bound to that graph (`tensor.graph`) before the encoded
      bytes can be used, e.g. as a feed value for another graph.
    """
    from keras.preprocessing.image import img_to_array

    # Close the file deterministically once the pixels are in memory.
    with PILImage.open(filename) as pil_image:
        pixels = img_to_array(pil_image)

    encode_graph = tf.Graph()
    with encode_graph.as_default():
        # Keep the op under the "g2" name scope, as before.
        with encode_graph.name_scope("g2"):
            # encode_jpeg is documented for uint8 images, while img_to_array
            # returns a floating-point array by default (NOTE(review): confirm
            # for the Keras version in use), so make the conversion explicit
            # rather than relying on an implicit cast during op construction.
            return tf.image.encode_jpeg(pixels.astype("uint8"))
def main(_):
    """Caption every image matched by --input_files and print the results.

    Builds the inference graph from the default `ModelConfig` and
    --checkpoint_path, loads the vocabulary from --vocab_file, expands the
    comma-separated file patterns in --input_files, and runs beam search with
    the default parameters on each matching image.

    Args:
      _: Unused; the argument passed in by `tf.app.run`.
    """
    # Build the inference graph.
    g = tf.Graph()
    with g.as_default():
        model = inference_wrapper.InferenceWrapper()
        restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
                                                   FLAGS.checkpoint_path)
    g.finalize()

    # Create the vocabulary.
    vocab = vocabulary.Vocabulary(FLAGS.vocab_file)

    filenames = []
    for file_pattern in FLAGS.input_files.split(","):
        filenames.extend(tf.gfile.Glob(file_pattern))
    tf.logging.info("Running caption generation on %d files matching %s",
                    len(filenames), FLAGS.input_files)

    with tf.Session(graph=g) as sess:
        # Load the model from checkpoint.
        restore_fn(sess)

        # Prepare the caption generator. Here we are implicitly using the
        # default beam search parameters. See caption_generator.py for a
        # description of the available beam search parameters.
        generator = caption_generator.CaptionGenerator(model, vocab)

        for filename in filenames:
            image = encode_image(filename)
            # encode_image returns a tf.Tensor that lives in its own graph.
            # Session.run rejects tf.Tensor objects as feed values, so
            # evaluate the tensor in a session on its graph first and hand
            # the resulting encoded bytes to the caption generator.
            if isinstance(image, tf.Tensor):
                with tf.Session(graph=image.graph) as encode_sess:
                    image = encode_sess.run(image)

            captions = generator.beam_search(sess, image)
            print("Captions for image %s:" % os.path.basename(filename))
            for i, caption in enumerate(captions):
                # Ignore begin and end words.
                sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]]
                sentence = " ".join(sentence)
                print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))


if __name__ == "__main__":
    tf.app.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Some clarification:
Lines 32-35, 50-63 and 95-96 are new; the rest is identical to the code in the repo.
Line 95 converts a filename to a tf.Tensor, whereas line 96 converts it to a numpy array