Created June 1, 2018 20:25
-
-
Save peci1/80cf0dd79986db83b4c99d0714ddf2ff to your computer and use it in GitHub Desktop.
TensorFlow executor for running models exported from TF 1.4+ on TF 1.3 and earlier (Python 3)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
class PbModelExecutor:
    """Run a frozen TensorFlow GraphDef (``.pb`` file) in its own graph and session.

    The GraphDef is patched by ``fix_graph_def`` before import so that a model
    saved by TF 1.4+ can be loaded by a TF 1.3-or-older runtime.  Use the
    executor as a context manager, or call ``close()`` when finished.

    Args:
        model_path: Path to a binary-serialized ``GraphDef`` file.
        input_shape: Shape passed to ``tf.placeholder`` for the model input.
        input_name: Name of the GraphDef node that the placeholder replaces.
        output_name: Tensor name in the GraphDef to fetch (e.g. ``'output:0'``).
        input_dtype: Placeholder dtype; ``None`` (the default) means
            ``tf.float32``, matching the previous hard-coded behavior.
    """

    # Stateful assign ops rewritten to their stateless arithmetic equivalents.
    _ASSIGN_TO_PLAIN_OP = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    # Attributes deleted from every node; presumably attrs that the older
    # runtime does not recognize (TODO confirm per target TF version).
    _UNSUPPORTED_ATTRS = ('dilations', 'index_type')

    def __init__(self, model_path, input_shape, input_name='input',
                 output_name='output:0', input_dtype=None):
        # Read and patch the model before creating any TF session, so a
        # missing or corrupt file raises without leaving a session open.
        graph_def = tf.GraphDef()
        with open(model_path, "rb") as f:
            graph_def.ParseFromString(f.read())
        self.fix_graph_def(graph_def)

        # A dedicated graph keeps multiple executors (and unrelated code that
        # uses the default graph) from accumulating ops in one shared graph.
        self.graph = tf.Graph()
        with self.graph.as_default():
            dtype = tf.float32 if input_dtype is None else input_dtype
            self.input = tf.placeholder(dtype, shape=input_shape)
            # Imported nodes are prefixed with 'output/'; input_map keys and
            # return_elements use the GraphDef's own (unprefixed) names.
            self.out = tf.import_graph_def(
                graph_def,
                input_map={input_name: self.input},
                return_elements=[output_name],
                name='output',
            )
            self.session = tf.Session(graph=self.graph)
            self.coord = tf.train.Coordinator()
            self.threads = tf.train.start_queue_runners(
                sess=self.session, coord=self.coord)

    @staticmethod
    def _read_input_name(name):
        """Return ``name`` redirected to its ``/read`` node, keeping any ``:port``.

        >>> PbModelExecutor._read_input_name('bn/moving_mean:0')
        'bn/moving_mean/read:0'
        """
        base, sep, port = name.rpartition(':')
        if not (sep and port.isdigit()):
            base, port = name, ''
        # Skip names already pointing at '/read' so the rewrite is idempotent.
        if not base.endswith('/read'):
            base += '/read'
        return base + ':' + port if port else base

    @staticmethod
    def fix_graph_def(graph_def):
        """Patch ``graph_def`` in place so an older TensorFlow can import it.

        * ``RefSwitch`` becomes ``Switch``; each of its inputs whose name
          contains ``'moving_'`` is redirected to ``<name>/read`` (a ``:port``
          suffix is preserved).  Assumes such a ``/read`` node exists in the
          graph -- TODO confirm for models not built from ``tf.Variable``.
        * ``AssignSub`` / ``AssignAdd`` become ``Sub`` / ``Add`` and lose
          their ``use_locking`` attribute.
        * ``dilations`` and ``index_type`` attributes are deleted from every
          node.

        Calling it again on an already-patched graph changes nothing.
        """
        for node in graph_def.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                # Snapshot the inputs: entries are overwritten while iterating.
                for index, name in enumerate(list(node.input)):
                    if 'moving_' in name:
                        node.input[index] = PbModelExecutor._read_input_name(name)
            elif node.op in PbModelExecutor._ASSIGN_TO_PLAIN_OP:
                node.op = PbModelExecutor._ASSIGN_TO_PLAIN_OP[node.op]
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']
            for attr_name in PbModelExecutor._UNSUPPORTED_ATTRS:
                if attr_name in node.attr:
                    del node.attr[attr_name]

    def execute(self, input):
        """Feed ``input`` to the model and return the single fetched output.

        The parameter name shadows the builtin ``input`` but is kept so that
        keyword callers (``execute(input=...)``) keep working.
        """
        output, = self.session.run(self.out, feed_dict={self.input: input})
        return output

    def close(self):
        """Stop and join the queue-runner threads, then close the session."""
        self.coord.request_stop()
        self.coord.join(self.threads)
        self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the session; never suppress the exception.
        self.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow import graph_util as gu | |
# create the model and session
# Export a trained model as a frozen, inference-only GraphDef ('model.pb').
# NOTE: `session` must already exist at this point -- a tf.Session holding
# the trained model. This snippet does not create it.
export_dir = "."
# Node names kept in the exported graph. PbModelExecutor's defaults
# (input_name='input', output_name='output:0') expect exactly these names.
SAVE_NODES = ['input', 'output']

# 1) Replace variables with constants holding their current session values.
#    The GraphDef comes from the session's own graph (`self.graph` is not
#    defined at module level).
frozen_graph_def = gu.convert_variables_to_constants(
    session, session.graph.as_graph_def(), SAVE_NODES)
# 2) Drop nodes only needed for training, never removing SAVE_NODES.
inference_graph_def = gu.remove_training_nodes(frozen_graph_def, SAVE_NODES)
# 3) Prune everything SAVE_NODES does not depend on.
output_graph_def = gu.extract_sub_graph(inference_graph_def, SAVE_NODES)
tf.train.write_graph(output_graph_def, export_dir, 'model.pb', as_text=False)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.