@jubjamie
Created March 31, 2017 10:01
Load .pb into TensorBoard
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename ='PATH_TO_PB.pb'
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        g_in = tf.import_graph_def(graph_def)
LOGDIR='/logs/tests/1/'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
@Simo-H

Simo-H commented Dec 27, 2017

Thank you very much for this gist, it was just what I needed.
I had a problem where the file was not written to disk; adding these lines at the end of the file fixed it:

train_writer.flush()
train_writer.close()

Also, I had to change LOGDIR to LOGDIR='logs/tests/1/' (removing the leading '/').
I hope this helps anyone who has the same problem.
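The leading '/' matters because it makes the path absolute: '/logs/tests/1/' points at a logs folder at the filesystem root, which an ordinary user may not be allowed to create, while 'logs/tests/1/' is resolved against the current working directory. A small stdlib sketch (the helper name `describe_logdir` is hypothetical) that shows where each LOGDIR would end up:

```python
import os


def describe_logdir(logdir):
    """Return (is_absolute, resolved_path) for a LOGDIR string.

    Hypothetical helper: a relative path is resolved against the current
    working directory, an absolute one is only normalised.
    """
    return os.path.isabs(logdir), os.path.abspath(logdir)


for logdir in ("/logs/tests/1/", "logs/tests/1/"):
    is_abs, resolved = describe_logdir(logdir)
    print(f"{logdir!r}: absolute={is_abs}, resolves to {resolved}")
```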

@avielas

avielas commented May 12, 2018

Thanks! It was helpful.

@sifat62

sifat62 commented May 25, 2018

I tried this code but got the following error in Jupyter. Any suggestions?


DecodeError Traceback (most recent call last)
in ()
5 with gfile.FastGFile(model_filename, 'rb') as f:
6 graph_def = tf.GraphDef()
----> 7 graph_def.ParseFromString(f.read())
8 g_in = tf.import_graph_def(graph_def)
9 LOGDIR='/logs'

DecodeError: Error parsing message
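`DecodeError: Error parsing message` means the bytes in the file are not a valid serialized GraphDef. One way to narrow it down without TensorFlow is to decode the first protobuf tag by hand: in the protobuf wire format every field starts with a varint key equal to `(field_number << 3) | wire_type`. In a GraphDef, field 1 is the repeated `node` list (length-delimited, wire type 2), so a frozen graph usually starts with byte 0x0A; in a SavedModel (`saved_model.pb`), field 1 is `saved_model_schema_version` (varint, wire type 0), so it usually starts with 0x08. This is a heuristic sketch, not a validator: it assumes fields were written in field-number order (common, but not guaranteed), and the helper names are hypothetical.

```python
def read_varint(data, pos=0):
    """Decode a base-128 varint starting at data[pos]; return (value, next_pos)."""
    result = 0
    shift = 0
    while True:
        if pos >= len(data):
            raise ValueError("truncated varint")
        byte = data[pos]
        pos += 1
        result |= (byte & 0x7F) << shift  # low 7 bits carry the payload
        if not byte & 0x80:  # high bit clear means this was the last byte
            return result, pos
        shift += 7


def first_field(data):
    """Return (field_number, wire_type) of the first protobuf field in data."""
    key, _ = read_varint(data)
    return key >> 3, key & 0x07


def guess_pb_kind(data):
    """Heuristic guess at what kind of .pb file `data` holds (hypothetical helper)."""
    if not data:
        return "empty file"
    field, wire_type = first_field(data)
    if (field, wire_type) == (1, 2):
        return "probably a GraphDef (frozen graph)"
    if (field, wire_type) == (1, 0):
        return "probably a SavedModel (saved_model.pb)"
    return f"unknown (first field {field}, wire type {wire_type})"


# Usage: with open("PATH_TO_PB.pb", "rb") as f: print(guess_pb_kind(f.read(16)))
```

If the guess says SavedModel, parsing the file as a `GraphDef` will fail in exactly this way.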

@andrewginns

Perfect! This was super helpful

@xyzdcgan

xyzdcgan commented Jan 1, 2019

When I load my frozen.pb file with it, I get this error:
Traceback (most recent call last):
File "convert.py", line 30, in
g_in = tf.import_graph_def(graph_def)
File "/home/ios/.local/lib/python2.7/site-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
return func(*args, **kwargs)
File "/home/ios/.local/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 671, in import_graph_def
node, 'Input tensor %r %s' % (input_name, te)))
ValueError: graph_def is invalid at node u'Genc/Conv/BatchNorm/AssignMovingAvg': Input tensor 'Genc/Conv/BatchNorm/moving_mean:0' Cannot convert a tensor of type float32 to an input of type float32_ref.
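This error usually means the frozen graph still contains training-only ops. `AssignMovingAvg` is the batch-norm update of `moving_mean` (in TF 1.x it is typically an `AssignSub` op), which expects a variable reference (`float32_ref`), but freezing turned the variable into a plain `float32` constant. The cleaner fix is usually to build the graph in inference mode (batch norm's training flag off) before freezing. A commonly shared workaround instead rewrites the in-place update ops into their non-assigning counterparts before importing. The sketch below is that workaround under the assumption that only `AssignSub`/`AssignAdd` nodes are involved; the function name is hypothetical, and it works on anything shaped like a GraphDef (a `.node` list whose entries have `.name`, `.op` and `.attr`).

```python
# Map each in-place update op to the plain arithmetic op that computes
# the same value (ref - value, ref + value) without writing to a variable.
REF_OP_REPLACEMENTS = {"AssignSub": "Sub", "AssignAdd": "Add"}


def strip_ref_updates(graph_def):
    """Rewrite AssignSub/AssignAdd nodes in place; return the renamed node names.

    Hypothetical helper for a GraphDef-like object. The rewritten nodes no
    longer update anything, which is only acceptable for inference.
    """
    changed = []
    for node in graph_def.node:
        new_op = REF_OP_REPLACEMENTS.get(node.op)
        if new_op is None:
            continue
        node.op = new_op
        # Sub/Add have no 'use_locking' attribute; leaving it would break the import.
        if "use_locking" in node.attr:
            del node.attr["use_locking"]
        changed.append(node.name)
    return changed
```

With TensorFlow, this would be called on the parsed `graph_def` right after `graph_def.ParseFromString(f.read())` and before `tf.import_graph_def(graph_def)`.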

And if I run it through this other script:

import tensorflow as tf
import sys
from tensorflow.python.platform import gfile

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

with tf.Session() as sess:
    model_filename ='saved_model.pb'
    with gfile.FastGFile(model_filename, 'rb') as f:

        data = compat.as_bytes(f.read())
        sm = saved_model_pb2.SavedModel()
        sm.ParseFromString(data)
        #print(sm)
        if 1 != len(sm.meta_graphs):
            print('More than one graph found. Not sure which to write')
            sys.exit(1)

        #graph_def = tf.GraphDef()
        #graph_def.ParseFromString(sm.meta_graphs[0])
        g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)

LOGDIR='YOUR_LOG_LOCATION'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)

it gave me this error:

google.protobuf.message.DecodeError: Error parsing message

Any solution for this?

I converted the checkpoint file to a .pb file using the TensorFlow frozen_python file.


ghost commented Feb 14, 2019

Use tf.gfile.GFile instead of gfile.FastGFile, and add:
train_writer.flush()
train_writer.close()

@AI-ML-Enthusiast

@jubjamie @Santosh7vasa

I ran the code and it saved a file named "events.out.tfevents.1574405033.DESKTOP-TOJGNRH". Now how can I load it into TensorBoard and view the graph?
Thanks
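TensorBoard reads event files from a directory, so you pass the folder that contains the `events.out.tfevents...` file to `--logdir` (for example `tensorboard --logdir logs/tests/1/`), open the URL it prints (by default http://localhost:6006) and switch to the Graphs tab. If you are not sure where the file was written, the small stdlib sketch below (the helper name `find_event_dirs` is hypothetical) lists every directory under a root that holds event files and prints the matching command:

```python
import glob
import os


def find_event_dirs(root):
    """Return the sorted directories under `root` that contain TensorBoard
    event files (files named events.out.tfevents.*)."""
    # glob.escape guards against special characters in `root`; "**" with
    # recursive=True also matches `root` itself.
    pattern = os.path.join(glob.escape(root), "**", "events.out.tfevents.*")
    files = glob.glob(pattern, recursive=True)
    return sorted({os.path.dirname(path) for path in files})


if __name__ == "__main__":
    for logdir in find_event_dirs("."):
        # TensorBoard takes the directory, not the event file itself.
        print(f"tensorboard --logdir {logdir}")
```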

@cristiandapp

cristiandapp commented Mar 14, 2021

Hi, I used it in Colab:

import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename ='/content/saved_model.pb'
    with tf.gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        train_writerd.flush()
        train_writerd.close()
        graph_def.ParseFromString(f.read())
        g_in = tf.import_graph_def(graph_def)
LOGDIR='/content'
train_writerd = tf.summary.FileWriter(LOGDIR)
train_writerd.add_graph(sess.graph)

At the end I got this error: NameError: name 'train_writerd' is not defined

@AfifaIshtiaq

I couldn't see any graph:
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.compat.v1.Session() as sess:
    model_filename ='inception_v1_inference.pb'
    with tf.io.gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
        g_in = tf.import_graph_def(graph_def)
LOGDIR='logs/tests/1/'
train_writer = tf.compat.v1.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
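One likely reason for an empty Graphs tab is that the writer is never flushed or closed, as noted earlier in this thread. Putting the thread's fixes together (GFile, flush/close, a relative LOGDIR), a minimal TF 2.x sketch could look like the following. It is an assumption-laden sketch rather than the gist author's code: it builds its own `tf.Graph` instead of using a Session, relies on `tf.compat.v1` APIs, assumes the .pb holds a frozen GraphDef (not a SavedModel's saved_model.pb), and the function name is hypothetical.

```python
def write_pb_to_tensorboard(pb_path, logdir):
    """Import a frozen GraphDef (.pb) and write it as a TensorBoard graph.

    Hypothetical helper for TF 2.x using tf.compat.v1 APIs.
    """
    import tensorflow as tf  # imported lazily so this module loads without TensorFlow

    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original node names instead of prefixing "import/".
        tf.compat.v1.import_graph_def(graph_def, name="")
        # v1 FileWriter is not usable under eager execution, so it is
        # created inside the graph context.
        writer = tf.compat.v1.summary.FileWriter(logdir, graph=graph)
        # Flush pending events, then close so the event file is complete on disk.
        writer.flush()
        writer.close()
```

For example, `write_pb_to_tensorboard("inception_v1_inference.pb", "logs/tests/1/")` followed by `tensorboard --logdir logs/tests/1/`.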

