Skip to content

Instantly share code, notes, and snippets.

@bryant1410
Created June 14, 2018 18:06
Show Gist options
  • Save bryant1410/6c64b53484ad1fedadadea8a703b2d40 to your computer and use it in GitHub Desktop.
Save bryant1410/6c64b53484ad1fedadadea8a703b2d40 to your computer and use it in GitHub Desktop.
Load a protobuf (`.pb`) model into TensorBoard
# import tensorflow as tf
# from tensorflow.python.platform import gfile
# with tf.Session() as sess:
# model_filename ='aaa.pb'
# with gfile.FastGFile(model_filename, 'rb') as f:
# graph_def = tf.GraphDef()
# graph_def.ParseFromString(f.read())
# g_in = tf.import_graph_def(graph_def)
# LOGDIR='logsst2'
# train_writer = tf.summary.FileWriter(LOGDIR)
# train_writer.add_graph(sess.graph)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.summary import summary
def import_to_tensorboard(model_dir, log_dir):
    """View an imported protobuf model (`.pb` file) as a graph in TensorBoard.

    Args:
        model_dir: Path to the serialized `GraphDef` protobuf (`.pb`) file to
            visualize. Despite the name, this is a file path, not a directory.
        log_dir: Directory the TensorBoard event file is written to.

    Usage:
        Call this function with your model location and desired log directory,
        then launch TensorBoard pointing at the log directory to view the
        imported `.pb` model as a graph.
    """
    # Entering the session makes its fresh graph the default graph, so the
    # import below lands in that graph rather than the global one.
    with session.Session(graph=ops.Graph()) as sess:
        with gfile.FastGFile(model_dir, "rb") as f:
            graph_def = graph_pb2.GraphDef()
            graph_def.ParseFromString(f.read())
        importer.import_graph_def(graph_def)

        pb_visual_writer = summary.FileWriter(log_dir)
        try:
            pb_visual_writer.add_graph(sess.graph)
        finally:
            # Close (which flushes) the writer so the event file is complete on
            # disk; otherwise pending events can be lost if the process exits.
            pb_visual_writer.close()
        print("Model Imported. Visualize by running: "
              "tensorboard --logdir={}".format(log_dir))
def import_graph(model_dir):
    """Load a protobuf model (`.pb` file) into a new graph and return that graph.

    Args:
        model_dir: Path to the serialized `GraphDef` protobuf (`.pb`) file.
            Despite the name, this is a file path, not a directory.

    Returns:
        The newly created graph containing the imported nodes.
    """
    with gfile.FastGFile(model_dir, "rb") as f:
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(f.read())

    graph = ops.Graph()
    # Import into a graph we own and return that exact object, instead of
    # relying on whichever graph happens to be the process-wide default at
    # the point of return. No session is needed just to build a graph.
    with graph.as_default():
        importer.import_graph_def(graph_def)
    return graph
def save_weights(sess, output_path, conv_var_names=None, conv_transpose_var_names=None):
    """Write the values of all trainable variables back to back into one binary file.

    Each variable's value is evaluated with `sess` and appended to
    `output_path` as raw bytes via `ndarray.tofile` (no header, separators,
    shape or dtype information). Variable names are printed to stdout in the
    order they are written.

    Args:
        sess: Session used to evaluate the variables.
        output_path: Path of the single file to (over)write.
        conv_var_names: Names of variables to transpose with
            `perm=[3, 0, 1, 2]` (axis 3 moved to the front) before writing.
            Presumably convolution kernels being converted to another
            layout -- TODO confirm the target format.
        conv_transpose_var_names: Names of variables to transpose with
            `perm=[3, 1, 0, 2]` before writing.
    """
    conv_var_names = conv_var_names or []
    conv_transpose_var_names = conv_transpose_var_names or []
    print("Variable order:")
    # Binary mode: tofile() emits raw bytes, so a text-mode handle is wrong.
    with open(output_path, 'wb') as file_:
        for var in tf.trainable_variables():
            print(var.name)
            if var.name in conv_var_names:
                var = tf.transpose(var, perm=[3, 0, 1, 2])
            elif var.name in conv_transpose_var_names:
                var = tf.transpose(var, perm=[3, 1, 0, 2])
            value = sess.run(var)
            # noinspection PyTypeChecker
            value.tofile(file_)
def main(unused_args):
    """Entry point invoked by `app.run`: export the model named by the CLI flags.

    Args:
        unused_args: Remaining argv passed by `app.run`; ignored because the
            flags were already parsed into the module-level `FLAGS`.
    """
    model_path, log_path = FLAGS.model_dir, FLAGS.log_dir
    import_to_tensorboard(model_path, log_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Both flags are mandatory, so no default values are given: argparse
    # never uses a default for a required option.
    parser.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help="The location of the protobuf ('pb') model to visualize.")
    parser.add_argument(
        "--log_dir",
        type=str,
        required=True,
        help="The location for the Tensorboard log to begin visualization from.")
    # FLAGS is read by main(); flags argparse does not recognise are
    # forwarded to app.run together with the program name.
    FLAGS, unparsed = parser.parse_known_args()
    app.run(main=main, argv=[sys.argv[0]] + unparsed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment