A simple example of saving a TensorFlow model and preparing it for use on Android
# Create a simple TF Graph
# By Omid Alemi - Jan 2017
# Works with TF <r1.0

import tensorflow as tf

I = tf.placeholder(tf.float32, shape=[None, 3], name='I')  # input
W = tf.Variable(tf.zeros_initializer(shape=[3, 2]), dtype=tf.float32, name='W')  # weights
b = tf.Variable(tf.zeros_initializer(shape=[2]), dtype=tf.float32, name='b')  # biases
O = tf.nn.relu(tf.matmul(I, W) + b, name='O')  # activation / output

saver = tf.train.Saver()
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)

    # save the graph definition (structure only, no variable values)
    tf.train.write_graph(sess.graph_def, '.', 'hellotensor.pbtxt')

    # normally you would do some training here;
    # we will just assign something to W and b
    sess.run(tf.assign(W, [[1, 2], [4, 5], [7, 8]]))
    sess.run(tf.assign(b, [1, 1]))

    # save a checkpoint file, which will store the above assignments
    saver.save(sess, 'hellotensor.ckpt')
# Create a simple TF Graph
# By Omid Alemi - Jan 2017
# Works with TF r1.0

import tensorflow as tf

I = tf.placeholder(tf.float32, shape=[None, 3], name='I')  # input
W = tf.Variable(tf.zeros(shape=[3, 2]), dtype=tf.float32, name='W')  # weights
b = tf.Variable(tf.zeros(shape=[2]), dtype=tf.float32, name='b')  # biases
O = tf.nn.relu(tf.matmul(I, W) + b, name='O')  # activation / output

saver = tf.train.Saver()
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)

    # save the graph definition (structure only, no variable values)
    tf.train.write_graph(sess.graph_def, '.', 'tfdroid.pbtxt')

    # normally you would do some training here,
    # but for now we will just assign something to W and b
    sess.run(tf.assign(W, [[1, 2], [4, 5], [7, 8]]))
    sess.run(tf.assign(b, [1, 1]))

    # save a checkpoint file, which will store the above assignments
    saver.save(sess, 'tfdroid.ckpt')
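
To sanity-check what was saved, the checkpoint can be restored into a fresh session and the output node evaluated. This is a minimal sketch (not part of the original gist), assuming the graph-building code above has already run in the same process:

# Sanity check (sketch): restore the checkpoint and run the output node.
# Assumes I, O, and saver from the script above are still in scope.
with tf.Session() as sess:
    saver.restore(sess, 'tfdroid.ckpt')
    result = sess.run(O, feed_dict={I: [[1.0, 1.0, 1.0]]})
    print(result)  # relu([1+4+7+1, 2+5+8+1]) -> [[13. 16.]]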
# Preparing a TF model for usage in Android
# By Omid Alemi - Jan 2017
# Works with TF <r1.0

import sys
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

MODEL_NAME = 'hellotensor'

# Freeze the graph
input_graph_path = MODEL_NAME + '.pbtxt'
checkpoint_path = './' + MODEL_NAME + '.ckpt'
input_saver_def_path = ""
input_binary = False
output_node_names = "O"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = 'frozen_' + MODEL_NAME + '.pb'
output_optimized_graph_name = 'optimized_' + MODEL_NAME + '.pb'
clear_devices = True

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path, output_node_names,
                          restore_op_name, filename_tensor_name,
                          output_frozen_graph_name, clear_devices, "")

# Optimize for inference
input_graph_def = tf.GraphDef()
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:  # binary mode: .pb is a binary protobuf
    data = f.read()
    input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    ["I"],  # an array of the input node(s)
    ["O"],  # an array of the output node(s)
    tf.float32.as_datatype_enum)

# Save the optimized graph (binary mode again)
with tf.gfile.FastGFile(output_optimized_graph_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())

# tf.train.write_graph(output_graph_def, './', output_optimized_graph_name)
# Preparing a TF model for usage in Android
# By Omid Alemi - Jan 2017
# Works with TF r1.0

import sys
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

MODEL_NAME = 'tfdroid'

# Freeze the graph
input_graph_path = MODEL_NAME + '.pbtxt'
checkpoint_path = './' + MODEL_NAME + '.ckpt'
input_saver_def_path = ""
input_binary = False
output_node_names = "O"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = 'frozen_' + MODEL_NAME + '.pb'
output_optimized_graph_name = 'optimized_' + MODEL_NAME + '.pb'
clear_devices = True

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path, output_node_names,
                          restore_op_name, filename_tensor_name,
                          output_frozen_graph_name, clear_devices, "")

# Optimize for inference
input_graph_def = tf.GraphDef()
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:  # binary mode: .pb is a binary protobuf
    data = f.read()
    input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    ["I"],  # an array of the input node(s)
    ["O"],  # an array of the output node(s)
    tf.float32.as_datatype_enum)

# Save the optimized graph (binary mode again)
with tf.gfile.FastGFile(output_optimized_graph_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())

# tf.train.write_graph(output_graph_def, './', output_optimized_graph_name)
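
Before shipping the .pb file to Android, it is worth confirming that the optimized graph is self-contained and runs without a checkpoint. A minimal sketch (not part of the original gist), assuming TF r1.0 and the file names used above:

# Verify (sketch): load optimized_tfdroid.pb into a clean graph and run it.
import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.Open('optimized_tfdroid.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')  # name='' keeps the original node names

with tf.Session(graph=graph) as sess:
    I = graph.get_tensor_by_name('I:0')
    O = graph.get_tensor_by_name('O:0')
    print(sess.run(O, feed_dict={I: [[1.0, 1.0, 1.0]]}))  # expect [[13. 16.]]

If this fails here, it will also fail on Android, so it is a cheap first debugging step.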
PanagiotisPtr commented Jul 28, 2017

The script "prep_model_tf1.py" has an issue when running it. It gives the following error message:
'utf-8' codec can't decode byte 0x80 in position 98: invalid start byte
The cause of this error is that the file needs to be read in binary mode.

So just change from "r" (read):
with tf.gfile.Open(output_frozen_graph_name, "r") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

to "rb" (read binary):
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

Pratiush commented Dec 20, 2017

I got an issue in prep_model_tf1.py; when running it, it says: KeyError: "The following input nodes were not found: {'I'}\n"

I fixed it by printing the graph with print(input_graph_def) and then using the first node's name as input_node_names. In my case ['I_1'] worked fine.
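
If the node names are unclear, they can all be printed from the parsed GraphDef. A short sketch, assuming the frozen graph has already been read into input_graph_def as in the script above:

# List every node name in the graph to find the actual input/output names.
for node in input_graph_def.node:
    print(node.name)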


deepaksuresh commented Mar 5, 2018

How can I convert a SavedModel that was trained and saved with tf.estimator to .pb format?
I can load the SavedModel using predictor = tf.contrib.predictor.from_saved_model(saved_model_dir) and perform inference on it. I'd like to use this model on Android, which requires the model to be in .pb format.
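
One possible approach (a sketch only, not verified against tf.estimator exports; the SavedModel directory and output node names below are placeholders) is to load the SavedModel into a TF 1.x session and fold its variables into constants:

# Sketch: freeze a SavedModel into a single .pb (TF 1.x APIs).
import tensorflow as tf

saved_model_dir = './saved_model'  # placeholder: your export directory
output_node_names = ['O']          # placeholder: your model's output node(s)

with tf.Session(graph=tf.Graph()) as sess:
    # load the SavedModel (tf.estimator exports use the serving tag)
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], saved_model_dir)
    # replace variables with constants so the graph is self-contained
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names)

with tf.gfile.GFile('frozen_model.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())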


divyanshujhawar commented Mar 17, 2018

@deepaksuresh Did you find a way to do it?


nagasairatnakar commented Apr 21, 2018

I tried to execute the example line by line. Everything went fine and I was able to generate the optimized .pb file. But when I ported it to Android, it gives no output and no errors either. What should I do?

