Skip to content

Instantly share code, notes, and snippets.

@omimo
Last active September 26, 2023 08:37
Show Gist options
  • Star 38 You must be signed in to star a gist
  • Fork 17 You must be signed in to fork a gist
  • Save omimo/5d393ed5b64d2ca0c591e4da04af6009 to your computer and use it in GitHub Desktop.
Save omimo/5d393ed5b64d2ca0c591e4da04af6009 to your computer and use it in GitHub Desktop.
A simple example of saving a TensorFlow model and preparing it for use on Android
# Create a simple TF Graph, give its variables values, and save it.
# By Omid Alemi - Jan 2017
# Works with TF <r1.0
#
# Writes to the current directory:
#   hellotensor.pbtxt - the graph definition (text format)
#   hellotensor.ckpt  - a checkpoint with the values of W and b
# The preparation script later looks the nodes up by the exact names
# 'I' (input) and 'O' (output).
import tensorflow as tf

# Build into a fresh graph instead of the process-wide default graph. If this
# code runs twice in one interpreter (e.g. a notebook), the default graph
# already holds a node named 'I', so TF would uniquify the new nodes to
# 'I_1', 'O_1', ... and the freeze/optimize step could no longer find 'I'.
with tf.Graph().as_default():
    I = tf.placeholder(tf.float32, shape=[None, 3], name='I')  # input
    W = tf.Variable(tf.zeros_initializer(shape=[3, 2]), dtype=tf.float32, name='W')  # weights
    b = tf.Variable(tf.zeros_initializer(shape=[2]), dtype=tf.float32, name='b')  # biases
    O = tf.nn.relu(tf.matmul(I, W) + b, name='O')  # activation / output

    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)

        # Save the graph now, before the one-off assign ops below are added,
        # so the exported graph does not contain them.
        tf.train.write_graph(sess.graph_def, '.', 'hellotensor.pbtxt')

        # Normally you would do some training here;
        # we will just assign fixed values to W and b.
        sess.run(tf.assign(W, [[1, 2], [4, 5], [7, 8]]))
        sess.run(tf.assign(b, [1, 1]))

        # Save a checkpoint file, which stores the values assigned above.
        # The './' prefix matches the checkpoint_path used by the prep script.
        saver.save(sess, './hellotensor.ckpt')
# Create a simple TF Graph, give its variables values, and save it.
# By Omid Alemi - Jan 2017
# Works with TF r1.0
#
# Writes to the current directory:
#   tfdroid.pbtxt - the graph definition (text format)
#   tfdroid.ckpt  - a checkpoint with the values of W and b
# The preparation script later looks the nodes up by the exact names
# 'I' (input) and 'O' (output).
import tensorflow as tf

# Build into a fresh graph instead of the process-wide default graph. If this
# code runs twice in one interpreter (e.g. a notebook), the default graph
# already holds a node named 'I', so TF would uniquify the new nodes to
# 'I_1', 'O_1', ... and the freeze/optimize step could no longer find 'I'.
with tf.Graph().as_default():
    I = tf.placeholder(tf.float32, shape=[None, 3], name='I')  # input
    W = tf.Variable(tf.zeros(shape=[3, 2]), dtype=tf.float32, name='W')  # weights
    b = tf.Variable(tf.zeros(shape=[2]), dtype=tf.float32, name='b')  # biases
    O = tf.nn.relu(tf.matmul(I, W) + b, name='O')  # activation / output

    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)

        # Save the graph now, before the one-off assign ops below are added,
        # so the exported graph does not contain them.
        tf.train.write_graph(sess.graph_def, '.', 'tfdroid.pbtxt')

        # Normally you would do some training here,
        # but for now we will just assign fixed values to W and b.
        sess.run(tf.assign(W, [[1, 2], [4, 5], [7, 8]]))
        sess.run(tf.assign(b, [1, 1]))

        # Save a checkpoint file, which stores the values assigned above.
        # The './' prefix matches the checkpoint_path used by the prep script.
        saver.save(sess, './tfdroid.ckpt')
# Preparing a TF model for usage in Android
# By Omid Alemi - Jan 2017
# Works with TF <r1.0
#
# Reads hellotensor.pbtxt and the hellotensor.ckpt checkpoint, then:
#   1. freezes the graph (checkpoint values become constants)
#      -> frozen_hellotensor.pb
#   2. optimizes the frozen graph for inference
#      -> optimized_hellotensor.pb
import sys
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

MODEL_NAME = 'hellotensor'

# Freeze the graph
input_graph_path = MODEL_NAME + '.pbtxt'
checkpoint_path = './' + MODEL_NAME + '.ckpt'
input_saver_def_path = ""
input_binary = False  # the .pbtxt graph was written in text format
# Must match the node names in the saved graph exactly. If TF renamed them
# (e.g. 'I_1'), print the GraphDef to find the real names.
input_node_names = ["I"]
output_node_names = "O"  # comma-separated string, as freeze_graph expects
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = 'frozen_' + MODEL_NAME + '.pb'
output_optimized_graph_name = 'optimized_' + MODEL_NAME + '.pb'
clear_devices = True

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path, output_node_names,
                          restore_op_name, filename_tensor_name,
                          output_frozen_graph_name, clear_devices, "")

# Optimize for inference
input_graph_def = tf.GraphDef()
# The frozen graph is a binary protobuf: open it in binary mode so read()
# yields bytes for ParseFromString.
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    input_node_names,              # a list of the input node(s)
    output_node_names.split(","),  # a list of the output node(s)
    tf.float32.as_datatype_enum)

# Save the optimized graph. Binary mode plus a `with` block guarantees the
# serialized bytes are flushed and the file is closed.
with tf.gfile.FastGFile(output_optimized_graph_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())
# Equivalent alternative:
# tf.train.write_graph(output_graph_def, '.', output_optimized_graph_name, as_text=False)
# Preparing a TF model for usage in Android
# By Omid Alemi - Jan 2017
# Works with TF r1.0
#
# Reads tfdroid.pbtxt and the tfdroid.ckpt checkpoint, then:
#   1. freezes the graph (checkpoint values become constants)
#      -> frozen_tfdroid.pb
#   2. optimizes the frozen graph for inference
#      -> optimized_tfdroid.pb
import sys
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

MODEL_NAME = 'tfdroid'

# Freeze the graph
input_graph_path = MODEL_NAME + '.pbtxt'
checkpoint_path = './' + MODEL_NAME + '.ckpt'
input_saver_def_path = ""
input_binary = False  # the .pbtxt graph was written in text format
# Must match the node names in the saved graph exactly. If TF renamed them
# (e.g. 'I_1'), print the GraphDef to find the real names.
input_node_names = ["I"]
output_node_names = "O"  # comma-separated string, as freeze_graph expects
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = 'frozen_' + MODEL_NAME + '.pb'
output_optimized_graph_name = 'optimized_' + MODEL_NAME + '.pb'
clear_devices = True

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path, output_node_names,
                          restore_op_name, filename_tensor_name,
                          output_frozen_graph_name, clear_devices, "")

# Optimize for inference
input_graph_def = tf.GraphDef()
# The frozen graph is a binary protobuf: open it in binary mode so read()
# yields bytes for ParseFromString.
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    input_node_names,              # a list of the input node(s)
    output_node_names.split(","),  # a list of the output node(s)
    tf.float32.as_datatype_enum)

# Save the optimized graph. Binary mode plus a `with` block guarantees the
# serialized bytes are flushed and the file is closed.
with tf.gfile.FastGFile(output_optimized_graph_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())
# Equivalent alternative:
# tf.train.write_graph(output_graph_def, '.', output_optimized_graph_name, as_text=False)
@Pratiush
Copy link

I ran into an issue with prep_model_tf1.py: running it fails with KeyError: "The following input nodes were not found: {'I'}\n"

I fixed it by printing the graph with print(input_graph_def)
and then using the first node's name as the input node name. In my case ['I_1'] worked.

@deepaksuresh
Copy link

How can I convert a SavedModel that was trained and saved with tf.estimator to .pb format?
I can load the SavedModel using predictor = tf.contrib.predictor.from_saved_model(saved_model_dir) and run inference with it. I'd like to use this model on Android, which requires the model to be in .pb format.

@divyanshujhawar
Copy link

@deepaksuresh Did you find a way to do it?

@nagasairatnakar
Copy link

When I tried to execute the example line by line, everything went fine and I was able to generate the optimized .pb file.
But when I ported it to Android, it gives no output and no errors either.
What should I do?

@yashwantptl7
Copy link

How do I load this .pb file in my Python code for prediction? Is there a step-by-step guide for this? I checked many articles, but they don't explain it clearly and are too confusing.

@anujonthemove
Copy link

@yashwantptl7
This reply is probably a bit late, but it might come in handy for others looking for answers in this thread.
I am assuming that you got a '.pb' file after freezing your TensorFlow model.

Here's how you can load a frozen model and use it for prediction:

def load_frozen_graph(frozen_graph):
    """Load a frozen (binary) GraphDef file into a new tf.Graph.

    Args:
        frozen_graph: path to the frozen-graph .pb file.

    Returns:
        A new tf.Graph containing the imported nodes. Every imported name
        is placed under the 'prefix' scope, e.g. 'I:0' -> 'prefix/I:0'.
    """
    # The file is a binary protobuf, so read it as bytes.
    with tf.gfile.GFile(frozen_graph, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # Adding a prefix here helps to distinctly mark the tensor names.
        tf.import_graph_def(graph_def, name='prefix')

    return graph
def predict_from_frozen_graph(frozen_graph_path, X_test,
                              input_tensor_name='prefix/input:0',
                              output_tensor_name='prefix/output:0',
                              print_ops=True):
    """Run a frozen graph on X_test and return the output tensor's value.

    Args:
        frozen_graph_path: path to the frozen-graph .pb file.
        X_test: value fed to the input tensor.
        input_tensor_name: name of the input tensor in the loaded graph.
            load_frozen_graph imports under 'prefix', so for the models built
            in this gist (input node 'I') this would be 'prefix/I:0'.
        output_tensor_name: name of the output tensor in the loaded graph
            ('prefix/O:0' for the models built in this gist).
        print_ops: if True, print every op name in the loaded graph, which
            helps when looking up the right tensor names.

    Returns:
        The result of evaluating the output tensor with X_test fed in.
    """
    graph = load_frozen_graph(frozen_graph_path)
    if print_ops:
        for op in graph.get_operations():
            print(op.name)

    x = graph.get_tensor_by_name(input_tensor_name)
    y = graph.get_tensor_by_name(output_tensor_name)

    with tf.Session(graph=graph) as sess:
        return sess.run(y, feed_dict={x: X_test})

# Example call. X_test (the input batch to feed) is not defined in this
# snippet and must be provided by the caller; 'frozen_model.pb' is an
# example path to a frozen graph file.
y_pred_prime = predict_from_frozen_graph('frozen_model.pb', X_test)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment