Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
# Converting the OpenImages Inception v3 model to a DeepDetect .pb
### Builds the Inception v3 graph with TF-Slim and loads the OpenImages checkpoint weights into it
import math
import sys
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets import inception
from tensorflow.python.framework import ops
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import supervisor
from tensorflow.python.framework import graph_io
slim = tf.contrib.slim
g = tf.Graph()
with g.as_default():
    # Create the expected InputImage placeholder with a fixed 4-D shape
    # (1, 299, 299, 3); this is the input node of the exported graph.
    input_image = tf.placeholder(tf.float32, shape=(1, 299, 299, 3), name="InputImage")

    ## Pre-process image
    # Based on the classify.py included in the OpenImages repo: central crop
    # (which expects a 3-D tensor), resize back to 299x299, then rescale.
    # Observed: with this pre-processing the model seems to always give the same
    # labels/confidences regardless of the input image. Without it, the output
    # doesn't quite match what's expected but is more sane (see example at end).
    # NOTE(review): the rescale maps [0, 255] -> [-1, 1], so it presumably assumes
    # raw 0-255 pixel values are fed in — confirm against what DeepDetect feeds.
    image = tf.squeeze(input_image, [0])  # drop the batch dim; central_crop wants 3-D
    image = tf.image.central_crop(image, central_fraction=0.875)
    image = tf.expand_dims(image, 0)  # restore the batch dim for resize_bilinear
    image = tf.image.resize_bilinear(image, [299, 299], align_corners=False)
    image = tf.multiply(image, 1.0 / 127.5)
    processed_image = tf.subtract(image, 1.0)
    # End of pre-processing image based on classify.py

    with slim.arg_scope(inception.inception_v3_arg_scope()):
        # To skip the pre-processing above, pass `input_image` here instead of
        # `processed_image` (`image` is already cropped, resized and scaled).
        logits, end_points = inception.inception_v3(
            processed_image, num_classes=6012, is_training=False)

    # Add the sigmoid multi-label predictions tensor at the end. Its name is the
    # node passed to freeze_graph as --output_node_names multi_predictions.
    predictions = end_points['multi_predictions'] = tf.nn.sigmoid(
        logits, name='multi_predictions')

    saver = tf_saver.Saver()
    # Use the session as a context manager so its resources are released once
    # the graph has been written.
    with tf.Session() as sess:
        # Load the model.ckpt weights into the graph. Possibly not necessary at
        # this stage since freeze_graph reads the checkpoint itself, but it
        # confirms the checkpoint matches the graph's variables.
        saver.restore(sess, '/path/to/model.ckpt')
        # Save the graph def as a text-format protobuf in the current directory.
        graph_io.write_graph(sess.graph, './', 'input_graph4.pb')
### Use TensorFlow's freeze_graph tool to combine the graph def and checkpoint weights into a single binary .pb.
### This requires you to have built the freeze_graph app with bazel first
# $ bazel-bin/tensorflow/python/tools/freeze_graph --input_graph /path/to/input_graph4.pb --input_checkpoint /path/to/model.ckpt --output_node_names multi_predictions --output_graph /path/to/output/frozen4.pb
### OPTIONAL: To view this in tensorboard, use the following
import os
import os.path
import tensorflow as tf
from tensorflow.python.platform import gfile
LOG_DIR = '/tmp/graphdeflogdir'
if not os.path.exists(LOG_DIR):
    os.makedirs(LOG_DIR)

with tf.Session() as sess:
    model_filename = '/path/to/frozen4.pb'
    # Read the frozen binary GraphDef produced by freeze_graph.
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import into the session's (default) graph; name='' keeps the original
    # node names instead of prefixing them with "import/".
    _ = tf.import_graph_def(graph_def, name='')
    # Pass the Graph object, not the raw GraphDef: FileWriter's GraphDef argument
    # is deprecated, and this way the graph just imported is what gets written.
    writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    writer.close()
### Then run `tensorboard --logdir /tmp/graphdeflogdir` and go to http://127.0.0.1:6006 and look at the graph tab
## Example output without image pre-processing, compare to https://github.com/openimages/dataset/issues/3#issuecomment-259552328
"predictions": [
{
"uri": "https://camo.githubusercontent.com/91dc7653d8e86f6c14b674fb9b250ed86f6e74f1/68747470733a2f2f6661726d352e737461746963666c69636b722e636f6d2f343036362f343338363530353233365f383439333038396335645f622e6a7067",
"classes": [
{
"prob": 0.8925855159759521,
"cat": "animal"
},
{
"prob": 0.8515647649765015,
"cat": "fauna"
},
{
"prob": 0.8000080585479736,
"cat": "mammal"
},
{
"prob": 0.7988731265068054,
"cat": "wildlife"
},
{
"prob": 0.7174997925758362,
"cat": "vertebrate"
},
{
"prob": 0.32620665431022644,
"cat": "grass"
},
{
"prob": 0.2770194113254547,
"cat": "pasture"
},
{
"prob": 0.27196070551872253,
"cat": "prairie"
},
{
"prob": 0.25592267513275146,
"cat": "grassland"
},
{
"prob": 0.16941961646080017,
"last": true,
"cat": "meadow"
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment