Last active March 5, 2021 11:48
Steps to export a Keras model .h5 file to a .pb file that TensorBoard can use to display the network graph.
# How to use this:
#
# Step 1. Save your Keras model to an .h5 file that includes the network
#         architecture, not only the weights. You can do this by adding the
#         line marked (**) below your model definition.
#         For example, in my case:
#
#             inception_v3 = keras.applications.inception_v3.InceptionV3(
#                 include_top=True,
#                 weights=None,
#                 input_tensor=None,
#                 input_shape=img_shape,
#                 pooling=None,
#                 classes=num_classes)
#
#             inception_v3.load_weights(pretrained_model, by_name=True)
#             # (**) needed this to export model architecture.
#             inception_v3.save('inception_v3_saved_model_mnedic.h5')
#
# Step 2. Convert the .h5 file to a .pb file with keras_to_tensorflow.py from
#         https://github.com/amir-abdi/keras_to_tensorflow
#             $ python3 keras_to_tensorflow.py -input_model_file inception_v3_saved_model_mnedic.h5
#         or
#             $ python keras_to_tensorflow.py -input_model_file inception_v3_saved_model_mnedic.h5
#         This should generate the file:
#             inception_v3_saved_model_mnedic.h5.pb
# Step 3. Run this script: $ python3 tensorboard_graph.py
# Step 4. Run: $ tensorboard --logdir=path/to/log-directory/
#         e.g. $ tensorboard --logdir=logs/tests/1/   (must match LOGDIR below)
# Step 5. Open the address that TensorBoard prints in your browser.
#         You should see the graph.
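# Note: this script uses the TensorFlow 1.x API (tf.Session, tf.GraphDef,
# tf.summary.FileWriter).
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    # Path to the .pb file generated in Step 2.
    model_filename = '/home/mnedic/int_reload/inception_v3_saved_model_mnedic.h5.pb'
    # gfile.GFile replaces the deprecated gfile.FastGFile.
    with gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Import the graph into the session's default graph.
        tf.import_graph_def(graph_def)

# Write the imported graph as an event file that TensorBoard can read.
LOGDIR = 'logs/tests/1/'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
train_writer.flush()
train_writer.close()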
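The script above targets TensorFlow 1.x. On TensorFlow 2.x, `tf.Session`, `tf.GraphDef` and `tf.summary.FileWriter` are only available under `tf.compat.v1`, so the same steps could be sketched roughly as follows. This is a minimal, untested-on-your-setup sketch, not the author's method: the function names `pb_path_for` and `write_graph_for_tensorboard` are hypothetical, and it assumes the converter from Step 2 names its output by appending `.pb` to the input file name.

```python
def pb_path_for(h5_path):
    """Return the .pb path expected for a given .h5 file.

    Assumption: the converter appends '.pb' to the input file name,
    as described in Step 2 above.
    """
    return h5_path + '.pb'


def write_graph_for_tensorboard(pb_path, logdir):
    """Load a GraphDef from pb_path and write it to logdir for TensorBoard."""
    # Imported here so pb_path_for() stays usable without TensorFlow installed.
    import tensorflow as tf

    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    # Work inside an explicit tf.Graph: the TF1-style FileWriter does not
    # support eager execution, which is the default in TensorFlow 2.x.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def)
        writer = tf.compat.v1.summary.FileWriter(logdir, graph=graph)
        writer.close()  # close() also flushes pending events to disk
```

A call such as `write_graph_for_tensorboard(pb_path_for('inception_v3_saved_model_mnedic.h5'), 'logs/tests/1/')` would then stand in for Step 3, after which Steps 4 and 5 stay the same.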