Skip to content

Instantly share code, notes, and snippets.

@peci1
Created June 1, 2018 20:25
Show Gist options
  • Save peci1/80cf0dd79986db83b4c99d0714ddf2ff to your computer and use it in GitHub Desktop.
Save peci1/80cf0dd79986db83b4c99d0714ddf2ff to your computer and use it in GitHub Desktop.
TensorFlow executor for running models exported from TF 1.4+ on TF 1.3 and older (Python 3)
import tensorflow as tf
import numpy as np
class PbModelExecutor:
    """Run a serialized TensorFlow ``GraphDef`` (``.pb``) in a TF 1.x session.

    The graph is patched with :meth:`fix_graph_def` before it is imported,
    so that a model exported by a newer TensorFlow can be loaded by an
    older one.  The executor owns a session and queue-runner threads; call
    :meth:`close` when done, or use the instance as a context manager::

        with PbModelExecutor('model.pb', [None, 224, 224, 3]) as executor:
            result = executor.execute(batch)
    """

    def __init__(self, model_path, input_shape, input_name='input', output_name='output:0'):
        """Load, patch and import the graph, then start the queue runners.

        Args:
            model_path: Path to the binary serialized ``GraphDef`` file.
            input_shape: Shape of the float32 placeholder that feeds the
                model.
            input_name: Name inside the imported graph that is remapped
                to the placeholder.
            output_name: Name of the element to fetch from the imported
                graph.

        Raises:
            OSError: If ``model_path`` cannot be opened.
        """
        # Read and patch the graph before any session exists, so that a
        # missing or corrupt model file cannot leave an open session behind.
        graph_def = tf.GraphDef()
        with open(model_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        self.fix_graph_def(graph_def)

        self.session = tf.Session()
        try:
            self.input = tf.placeholder(tf.float32, shape=input_shape)
            # Imported nodes are placed under the 'output' name scope.
            self.out = tf.import_graph_def(
                graph_def,
                input_map={input_name: self.input},
                return_elements=[output_name],
                name='output',
            )
            self.coord = tf.train.Coordinator()
            self.threads = tf.train.start_queue_runners(sess=self.session, coord=self.coord)
        except Exception:
            # Do not leak the session when building the graph fails.
            self.session.close()
            raise

    def __enter__(self):
        """Return the executor itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release all resources on leaving the ``with`` block."""
        self.close()

    @staticmethod
    def fix_graph_def(graph_def):
        """Rewrite nodes of ``graph_def`` in place.

        * ``RefSwitch`` nodes become ``Switch``; every input of such a node
          whose name contains ``'moving_'`` gets ``'/read'`` appended.
        * ``AssignSub`` nodes become ``Sub`` and lose ``use_locking``.
        * The ``dilations`` and ``index_type`` attributes are removed from
          every node that carries them.

        NOTE(review): presumably these are the ops/attributes that an older
        TensorFlow runtime rejects in graphs written by a newer one --
        confirm against the targeted TF versions.

        Args:
            graph_def: Object exposing ``node``; each node has ``op``,
                an indexable ``input`` and a mapping-like ``attr``.
        """
        for node in graph_def.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                for index, input_name in enumerate(node.input):
                    if 'moving_' in input_name:
                        node.input[index] = input_name + '/read'
            elif node.op == 'AssignSub':
                node.op = 'Sub'
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']

            # Applied to every node regardless of its op type.
            for attr_name in ('dilations', 'index_type'):
                if attr_name in node.attr:
                    del node.attr[attr_name]

    def execute(self, input):
        """Feed ``input`` to the placeholder and return the fetched output.

        Args:
            input: Value fed to the input placeholder.

        Returns:
            The single element fetched from the imported graph.
        """
        output, = self.session.run(self.out, feed_dict={self.input: input})
        return output

    def close(self):
        """Stop the queue-runner threads and close the session.

        The session is closed even if stopping or joining the threads
        raises; that exception is then propagated to the caller.
        """
        try:
            self.coord.request_stop()
            self.coord.join(self.threads)
        finally:
            self.session.close()
import tensorflow as tf
from tensorflow import graph_util as gu
# Prerequisite: create the model and session first. The statements below
# expect `session` and `self.graph` to already exist in the enclosing scope.
export_dir = "."
SAVE_NODES = ['input', 'output']

# Step 1: convert the graph's variables to constants, keeping SAVE_NODES.
output_graph_def = gu.convert_variables_to_constants(
    session, self.graph.as_graph_def(), SAVE_NODES)
# Step 2: remove training nodes, with SAVE_NODES passed as protected nodes.
output_graph_def = gu.remove_training_nodes(output_graph_def, SAVE_NODES)
# Step 3: extract the sub-graph for SAVE_NODES.
output_graph_def = gu.extract_sub_graph(output_graph_def, SAVE_NODES)

# Write the result as a binary protobuf to <export_dir>/model.pb.
tf.train.write_graph(output_graph_def, export_dir, 'model.pb', as_text=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment