Skip to content

Instantly share code, notes, and snippets.

@applenob
Last active April 26, 2023 18:37
Show Gist options
  • Save applenob/977f7627345c4b83149752e2f1c88a50 to your computer and use it in GitHub Desktop.
Save applenob/977f7627345c4b83149752e2f1c88a50 to your computer and use it in GitHub Desktop.
Load a TensorFlow model from a frozen protocol buffer (.pb) file.
# coding=utf-8
import tensorflow as tf
def get_session():
    """Create and return a new ``tf.Session``.

    The session is configured with ``gpu_options.allow_growth = True`` so
    that GPU memory is claimed on demand rather than all at once.
    """
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    session = tf.Session(config=session_config)
    return session
def _drop_attrs(node, *attr_names):
    """Delete each of *attr_names* from ``node.attr`` if it is present."""
    for attr_name in attr_names:
        if attr_name in node.attr:
            del node.attr[attr_name]


def _strip_ref_ops(graph_def):
    """Rewrite reference/assignment ops of *graph_def* in place as value ops.

    This works around import errors such as
    "Input 0 of node X was passed float from Y:0 incompatible with expected
    float_ref." by replacing ``RefSwitch``/``AssignSub``/``AssignAdd``/
    ``Assign`` nodes with ``Switch``/``Sub``/``Add``/``Identity``.
    """
    for node in graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Inputs whose name contains 'moving_' are redirected to their
            # '/read' output (presumably batch-norm moving statistics --
            # verify against the model that produced the graph).
            for index, input_name in enumerate(node.input):
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            _drop_attrs(node, 'use_locking')
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            _drop_attrs(node, 'use_locking')
        elif node.op == 'Assign':
            node.op = 'Identity'
            _drop_attrs(node, 'use_locking', 'validate_shape')
            if len(node.input) == 2:
                # input0: ref: Should be from a Variable node. May be
                #         uninitialized.
                # input1: value: The value to be assigned to the variable.
                # Identity takes a single input, so keep only the value.
                node.input[0] = node.input[1]
                del node.input[1]


def load_frozen_graph(frozen_graph_filename, name=""):
    """Load a frozen graph from a protocol buffer file into a new graph.

    Args:
        frozen_graph_filename: Path of the serialized ``GraphDef`` (.pb) file.
        name: Prefix prepended to every imported op name. Defaults to ``""``
            (no prefix) so ops keep the names they had when the graph was
            frozen, e.g. ``graph.get_operation_by_name("input")``. Pass
            ``"import"`` to get TensorFlow's default prefixing.

    Returns:
        A new ``tf.Graph`` containing the imported (and ref-op patched) nodes.
    """
    # Read the protobuf file from disk and parse it into a GraphDef.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Patch ops that would otherwise fail to import with a float_ref error.
    _strip_ref_ops(graph_def)

    # Import the graph_def into a fresh Graph. Without an explicit ``name``
    # TensorFlow prefixes every op with "import/"; since everything is loaded
    # into a new graph no prefix is needed, so it is disabled by default.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=name)
    return graph
def load_graph_session_from_pb(pb_file, print_op=False):
    """Build a graph from a frozen .pb file and open a session on it.

    Args:
        pb_file: Path of the frozen protocol buffer file, forwarded to
            ``load_frozen_graph``.
        print_op: When true, print the name of every operation in the
            loaded graph.

    Returns:
        A ``(graph, session)`` tuple; the session is created while the
        loaded graph is the default graph.
    """
    loaded_graph = load_frozen_graph(pb_file)

    if print_op:
        for operation in loaded_graph.get_operations():
            print(operation.name)

    # Create the session with the loaded graph installed as the default.
    with loaded_graph.as_default():
        session = get_session()

    return loaded_graph, session
# Load the frozen model once at import time and look up the first output
# tensor of the ops named "input" and "output"; predict_func below reads
# these module-level names.
# NOTE(review): the lookups assume the ops were imported without a name
# prefix -- confirm against how load_frozen_graph imports the graph.
graph, sess = load_graph_session_from_pb("elmo.pb")
input, output = (
    graph.get_operation_by_name(op_name).outputs[0]
    for op_name in ("input", "output")
)
def predict_func(sess, one_batch):
    """Run the model on a single batch.

    Feeds ``one_batch`` to the module-level ``input`` tensor and fetches the
    module-level ``output`` tensor.

    Args:
        sess: Session-like object exposing ``run(fetches, feed_dict=...)``.
        one_batch: Value fed for the ``input`` tensor.

    Returns:
        Whatever ``sess.run`` returns for the one-element fetch list
        ``[output]``.
    """
    return sess.run([output], feed_dict={input: one_batch})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment