Skip to content

Instantly share code, notes, and snippets.

@riga
Created November 27, 2019 13:52
Show Gist options
  • Save riga/439f6905857ee95c794e307ec9c9828e to your computer and use it in GitHub Desktop.
Save riga/439f6905857ee95c794e307ec9c9828e to your computer and use it in GitHub Desktop.
Patch a TensorFlow graph
# coding: utf-8
import os
import tensorflow as tf
def read_constant_graph(graph_path, create_session=True, as_text=False):
    """Load a serialized GraphDef into a fresh ``tf.Graph``.

    :param graph_path: Path of the graph file. It is opened through ``tf.gfile``
        for both the text and the binary format.
    :param create_session: If true, also create a ``tf.Session`` bound to the new graph.
    :param as_text: If true, parse the file as a text-format protobuf; otherwise
        parse it as a binary protobuf.
    :return: ``(graph, session)`` when *create_session* is true, otherwise ``graph``.
    """
    graph_def = tf.GraphDef()
    if as_text:
        from google.protobuf import text_format
        with tf.gfile.GFile(graph_path, "r") as f:
            text_format.Merge(f.read(), graph_def)
    else:
        # public replacement for the deprecated, private
        # tensorflow.python.platform.gfile.FastGFile
        with tf.gfile.GFile(graph_path, "rb") as f:
            graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        # name="" imports the nodes without an extra name prefix, so operations can
        # be looked up by the names stored in the file
        tf.import_graph_def(graph_def, name="")

    if create_session:
        return graph, tf.Session(graph=graph)
    return graph
def write_constant_graph(session, output_names, graph_path, **kwargs):
    """Convert the variables of *session*'s graph to constants and write the result.

    :param session: Session whose graph (and current variable values) is exported.
    :param output_names: Names of the output nodes passed to
        ``tf.graph_util.convert_variables_to_constants``.
    :param graph_path: Destination file; missing parent directories are created and
        an existing file at this path is removed before writing.
    :param kwargs: Forwarded to ``tf.train.write_graph``; ``as_text`` defaults to False.
    """
    write_kwargs = {"as_text": False}
    write_kwargs.update(kwargs)

    frozen_def = tf.graph_util.convert_variables_to_constants(
        session,
        session.graph.as_graph_def(),
        output_names,
    )

    target = os.path.normpath(os.path.abspath(graph_path))
    target_dir = os.path.dirname(target)
    target_file = os.path.basename(target)

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    # remove any previous file so the new graph always replaces it
    if os.path.exists(target):
        os.remove(target)

    tf.train.write_graph(frozen_def, target_dir, target_file, **write_kwargs)
# Load the model and locate the two output operations that are patched below.
graph, sess = read_constant_graph("padding10_fullModel.pb")
id_op = graph.get_operation_by_name("pid_output/Softmax")
er_op = graph.get_operation_by_name("enreg_output/BiasAdd")

# Build every new op explicitly inside the loaded graph. Previously the "patch" scope
# was opened on the global default graph, where it had no effect on ops derived from
# id_op/er_op, and any op not built from those tensors would have landed in the
# wrong graph.
with graph.as_default():
    with tf.name_scope("patch"):
        # rescale the regression output: x * e_std + e_mean
        # NOTE(review): looks like the inverse of a target standardization used in
        # training -- confirm the constants against the training setup
        e_mean = 213.90352475881576
        e_std = 108.05413626100672
        e_rescaled = er_op.outputs[0] * e_std + e_mean

        # Re-map the classifier output onto 8 columns:
        #   0-2 <- input cols 0-2, 3 <- zeros, 4 <- input col 3, 5-7 <- zeros.
        # Class meanings are only suggested by the variable names -- confirm.
        id_t = id_op.outputs[0]
        id_ph_el_mu = id_t[:, 0:3]
        # zeros_like keeps the dynamic batch dimension without relying on "* 0."
        # (which would turn NaN/inf entries into NaN)
        id_pi0 = tf.zeros_like(id_t[:, 0:1])
        id_ch = id_t[:, 3:4]
        id_nh_am_un = tf.zeros_like(id_t[:, 0:3])
        id_concat = tf.concat([id_ph_el_mu, id_pi0, id_ch, id_nh_am_un], axis=1)

    # Named outside the "patch" scope so the final node names are exactly the ones
    # passed to write_constant_graph below.
    tf.identity(e_rescaled, name="output/regressed_energy")
    tf.identity(id_concat, name="output/id_probabilities")

# Freeze the patched graph, write it to patched.pb and release the session.
write_constant_graph(sess, ["output/id_probabilities", "output/regressed_energy"], "patched.pb")
sess.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment