@iandewancker
Created August 16, 2018 14:44
Convert a Keras model to a .pb file
import tensorflow

# load the Keras model architecture from disk
model_name = "GVC_IncepvtionV3_epoch_6_vanilla_vgg_chute_date_2018_07_27"
with open(model_name + '.json', 'r') as json_file:
    model_json = json_file.read()
model = tensorflow.keras.models.model_from_json(model_json)
# load weights into new model
model.load_weights(model_name+".h5")
print("Loaded model from disk")
from tensorflow.python.framework.graph_util import convert_variables_to_constants
# access the graph of the Keras session (get_session() is a TensorFlow 1.x API)
graph = tensorflow.keras.backend.get_session().graph
# retrieve the protobuf graph definition
input_graph_def = graph.as_graph_def()
output_node_names = "inception_resnet_v2_input,output/Softmax"
# to look up node names in the graph, filter them like this:
# [n.name for n in graph.as_graph_def().node if "up_sampling2d_42" in n.name]
# TensorFlow built-in helper to export variables to constants
output_graph_def = convert_variables_to_constants(
    sess=tensorflow.keras.backend.get_session(),
    input_graph_def=input_graph_def,  # GraphDef object holding the network
    output_node_names=output_node_names.split(",")
)
# write the frozen graph as a binary protobuf file
tensorflow.train.write_graph(output_graph_def,
                             "./",
                             model_name + "_graph.pb",
                             as_text=False)
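
A quick way to check the exported file is to load it back and look up the input and output tensors by name. The sketch below is not part of the original gist. The helper names `to_tensor_names`, `load_frozen_graph`, and `get_io_tensors` are hypothetical, and it assumes a TensorFlow version that provides `tf.compat.v1.GraphDef` (1.13 or later). It also assumes each node exposes its result as output index 0, which is the usual case for placeholders and softmax ops.

```python
def to_tensor_names(output_node_names):
    """Turn a comma-separated string of node names into tensor names.

    A tensor name is the node name plus an output index; index 0 is assumed here.
    """
    return [name.strip() + ":0"
            for name in output_node_names.split(",")
            if name.strip()]


def load_frozen_graph(pb_path):
    """Parse a binary GraphDef file and import it into a fresh tf.Graph."""
    # Imported lazily so to_tensor_names() works without TensorFlow installed.
    import tensorflow as tf

    graph_def = tf.compat.v1.GraphDef()
    with open(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        # name="" keeps the original node names instead of prefixing "import/".
        tf.import_graph_def(graph_def, name="")
    return graph


def get_io_tensors(graph, output_node_names):
    """Look up the tensors for the node names used when freezing the graph."""
    return [graph.get_tensor_by_name(name)
            for name in to_tensor_names(output_node_names)]


# Example usage (hypothetical; requires the .pb file written above):
# graph = load_frozen_graph(model_name + "_graph.pb")
# x, y = get_io_tensors(graph, "inception_resnet_v2_input,output/Softmax")
```

If `get_tensor_by_name` raises a `KeyError`, the node names passed to the freezing step most likely do not match the graph, and the name-listing comment above is the place to start looking.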