A simple example of saving a TensorFlow model and preparing it for use on Android
# Create a simple TF Graph
# By Omid Alemi - Jan 2017
# Works with TF <r1.0

import tensorflow as tf

I = tf.placeholder(tf.float32, shape=[None, 3], name='I')  # input
W = tf.Variable(tf.zeros_initializer(shape=[3, 2]), dtype=tf.float32, name='W')  # weights
b = tf.Variable(tf.zeros_initializer(shape=[2]), dtype=tf.float32, name='b')  # biases
O = tf.nn.relu(tf.matmul(I, W) + b, name='O')  # activation / output

saver = tf.train.Saver()
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)

    # save the graph definition as a text protobuf
    tf.train.write_graph(sess.graph_def, '.', 'hellotensor.pbtxt')

    # normally you would do some training here
    # we will just assign something to W
    sess.run(tf.assign(W, [[1, 2], [4, 5], [7, 8]]))
    sess.run(tf.assign(b, [1, 1]))

    # save a checkpoint file, which will store the above assignment
    saver.save(sess, 'hellotensor.ckpt')

# Create a simple TF Graph
# By Omid Alemi - Jan 2017
# Works with TF r1.0

import tensorflow as tf

I = tf.placeholder(tf.float32, shape=[None, 3], name='I')  # input
W = tf.Variable(tf.zeros(shape=[3, 2]), dtype=tf.float32, name='W')  # weights
b = tf.Variable(tf.zeros(shape=[2]), dtype=tf.float32, name='b')  # biases
O = tf.nn.relu(tf.matmul(I, W) + b, name='O')  # activation / output

saver = tf.train.Saver()
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)

    # save the graph definition as a text protobuf
    tf.train.write_graph(sess.graph_def, '.', 'tfdroid.pbtxt')

    # normally you would do some training here
    # but for now we will just assign something to W
    sess.run(tf.assign(W, [[1, 2], [4, 5], [7, 8]]))
    sess.run(tf.assign(b, [1, 1]))

    # save a checkpoint file, which will store the above assignment
    saver.save(sess, 'tfdroid.ckpt')

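Before freezing, a quick sanity check can confirm the checkpoint really holds the assigned weights. The following is a minimal sketch that is not part of the original gist; it assumes the script above has already been run in the current directory, so tfdroid.ckpt.meta exists.

import tensorflow as tf

with tf.Session() as sess:
    # rebuild the graph from the .meta file, then load W and b from the checkpoint
    saver = tf.train.import_meta_graph('tfdroid.ckpt.meta')
    saver.restore(sess, 'tfdroid.ckpt')
    # with W = [[1,2],[4,5],[7,8]] and b = [1,1], the input [1,2,3]
    # gives relu([1,2,3] . W + b) = relu([30,36] + [1,1]) = [31, 37]
    print(sess.run('O:0', feed_dict={'I:0': [[1, 2, 3]]}))
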
# Preparing a TF model for usage in Android
# By Omid Alemi - Jan 2017
# Works with TF <r1.0

import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

MODEL_NAME = 'hellotensor'

# Freeze the graph
input_graph_path = MODEL_NAME + '.pbtxt'
checkpoint_path = './' + MODEL_NAME + '.ckpt'
input_saver_def_path = ""
input_binary = False
output_node_names = "O"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = 'frozen_' + MODEL_NAME + '.pb'
output_optimized_graph_name = 'optimized_' + MODEL_NAME + '.pb'
clear_devices = True

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path, output_node_names,
                          restore_op_name, filename_tensor_name,
                          output_frozen_graph_name, clear_devices, "")

# Optimize for inference
input_graph_def = tf.GraphDef()
# the frozen graph is a binary protobuf, so it must be read in binary mode
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    ["I"],  # an array of the input node(s)
    ["O"],  # an array of output nodes
    tf.float32.as_datatype_enum)

# Save the optimized graph (binary mode again)
with tf.gfile.FastGFile(output_optimized_graph_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())

# tf.train.write_graph(output_graph_def, './', output_optimized_graph_name)

# Preparing a TF model for usage in Android
# By Omid Alemi - Jan 2017
# Works with TF r1.0

import tensorflow as tf
from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib

MODEL_NAME = 'tfdroid'

# Freeze the graph
input_graph_path = MODEL_NAME + '.pbtxt'
checkpoint_path = './' + MODEL_NAME + '.ckpt'
input_saver_def_path = ""
input_binary = False
output_node_names = "O"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_frozen_graph_name = 'frozen_' + MODEL_NAME + '.pb'
output_optimized_graph_name = 'optimized_' + MODEL_NAME + '.pb'
clear_devices = True

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                          input_binary, checkpoint_path, output_node_names,
                          restore_op_name, filename_tensor_name,
                          output_frozen_graph_name, clear_devices, "")

# Optimize for inference
input_graph_def = tf.GraphDef()
# the frozen graph is a binary protobuf, so it must be read in binary mode
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
    input_graph_def,
    ["I"],  # an array of the input node(s)
    ["O"],  # an array of output nodes
    tf.float32.as_datatype_enum)

# Save the optimized graph (binary mode again)
with tf.gfile.FastGFile(output_optimized_graph_name, "wb") as f:
    f.write(output_graph_def.SerializeToString())

# tf.train.write_graph(output_graph_def, './', output_optimized_graph_name)
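
Once the optimized graph has been written, it is worth verifying it on the desktop before shipping it to Android. The following is a minimal sketch, assuming optimized_tfdroid.pb was produced by the script above; with the weights assigned earlier, the input [1,2,3] should yield [31, 37].

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.Open('optimized_tfdroid.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    # empty name prefix, so tensors keep their original names 'I:0' and 'O:0'
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    print(sess.run('O:0', feed_dict={'I:0': [[1, 2, 3]]}))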
PanagiotisPtr commented Jul 28, 2017

The script "prep_model_tf1.py" has an issue when running it. It give the following error message:
'utf-8' codec can't decode byte 0x80 in position 98: invalid start byte
What cause this error is the fact than the file needs to be read in Binary.

So just change from "r" (read):
with tf.gfile.Open(output_frozen_graph_name, "r") as f:
data = f.read()
input_graph_def.

To "rb" (read binary):
with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
data = f.read()
input_graph_def.

Pratiush commented Dec 20, 2017

I got an issue in prep_model_tf1.py; when running it, it says: KeyError: "The following input nodes were not found: {'I'}\n"

I fixed it by actually printing the graph with print(input_graph_def)
and then used the first node's name as input_node_names. In my case ['I_1'] worked fine.
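
A node name like 'I_1' usually means the graph-building script was run twice in the same session, so TensorFlow deduplicated the placeholder name. A small sketch for listing the node names in a frozen graph (assuming frozen_tfdroid.pb is in the current directory):

import tensorflow as tf

graph_def = tf.GraphDef()
with tf.gfile.Open('frozen_tfdroid.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# print every node name so the right values can be passed to
# optimize_for_inference's input/output node arguments
for node in graph_def.node:
    print(node.name)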

deepaksuresh commented Mar 5, 2018

How can I convert a SavedModel, one that was trained and saved with tf.estimator, to .pb format?
I can load the SavedModel using predictor = tf.contrib.predictor.from_saved_model(saved_model_dir) and perform inference on it. I'd like to use this model on Android, which requires the model to be in .pb format.
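
One hedged sketch for this, assuming TF 1.x: load the SavedModel into a session and fold its variables into constants with convert_variables_to_constants. The saved_model_dir path and the output node name 'O' below are placeholders; substitute your export directory and your model's real output node(s).

import tensorflow as tf
from tensorflow.python.framework import graph_util

saved_model_dir = './saved_model'  # placeholder: your estimator's export dir
output_node_names = ['O']          # placeholder: your model's output node(s)

with tf.Session(graph=tf.Graph()) as sess:
    # load the SavedModel (estimator exports use the 'serve' tag)
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], saved_model_dir)
    # replace variables with constants so everything lives in one GraphDef
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names)

with tf.gfile.GFile('frozen_saved_model.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())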

divyanshujhawar commented Mar 17, 2018

@deepaksuresh Did you find a way to do it?

nagasairatnakar commented Apr 21, 2018

I tried to execute the example line by line. Everything went fine and I was able to generate the optimized .pb file.
But when I ported it to Android, it gives no output and also no errors.
What should I do?

yashwantptl7 commented Sep 20, 2018

How do I load this .pb file in my Python code for prediction? Is there any step-by-step guide for this? I checked many articles but they don't explain it clearly and are too confusing.

anujonthemove commented May 4, 2019

@yashwantptl7
I think it's a bit too late for a reply but it might come in handy for others looking for some answers in this thread.
I am assuming that you got a '.pb' extension file after freezing your tensorflow model.

Here's how you can load a frozen model and use it for prediction:

import tensorflow as tf

def load_frozen_graph(frozen_graph):
    # read the binary protobuf and parse it into a GraphDef
    with tf.gfile.GFile(frozen_graph, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # adding a prefix here helps to distinctly mark the tensor names
        tf.import_graph_def(graph_def, name='prefix')

    return graph

def predict_from_frozen_graph(frozen_graph_path, X_test):
    graph = load_frozen_graph(frozen_graph_path)

    # print the operation names to find the input/output tensors
    for op in graph.get_operations():
        print(op.name)

    x = graph.get_tensor_by_name('prefix/input:0')
    y = graph.get_tensor_by_name('prefix/output:0')

    with tf.Session(graph=graph) as sess:
        y_pred_prime = sess.run(y, feed_dict={x: X_test})

    return y_pred_prime

y_pred_prime = predict_from_frozen_graph('frozen_model.pb', X_test)
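
Note that the tensor names depend on how the graph was built: for the model in this gist, the input and output would be 'prefix/I:0' and 'prefix/O:0' rather than 'prefix/input:0' and 'prefix/output:0'.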