Skip to content

Instantly share code, notes, and snippets.

@jiankaiwang
Created September 18, 2019 09:12
Show Gist options
  • Save jiankaiwang/ff63e786162225121040b4090bc015d6 to your computer and use it in GitHub Desktop.
Save jiankaiwang/ff63e786162225121040b4090bc015d6 to your computer and use it in GitHub Desktop.
Example of wrapping a frozen model with explicitly named input and output tensors.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: jiankaiwang (https://jiankaiwang.no-ip.biz/)
@version:
Tensorflow: 1.x (developed >= 1.13.2)
@description:
Example of model specifications to inputs and outputs.
@dependency:
OperateFrozenModel (TF1_FrozenModel.py, https://gist.github.com/jiankaiwang/24cc1bc8b38ce72bba73f7fb326f7b9e)
@changelog (main):
2019-04: initial commit
2019-09: released on gist.github.com
"""
import os

import numpy as np
import tensorflow as tf

import OperateFrozenModel
# In[]
# pb_path: the trained frozen graph to wrap with new input/output tensors.
# merged_pb_path: where the merged graph is saved and later reloaded from.
# Either path can be overridden via an environment variable so the script
# runs on other machines without editing; the defaults are the original
# author's local paths.
pb_path = os.environ.get(
    "TF_SPEC_PB_PATH", "/Users/jiankaiwang/Desktop/output_graph.pb")
merged_pb_path = os.environ.get(
    "TF_SPEC_MERGED_PB_PATH", "/Users/jiankaiwang/Desktop/merged_graph.pb")
# In[]
# Wrap the trained graph with a new, explicitly named input placeholder and
# output identity op, then export the combined graph as a frozen model.
tf.reset_default_graph()
# merged_graph: a fresh graph that holds the re-wrapped model.
merged_graph = tf.Graph()
with merged_graph.as_default():
    # Define a specific input.
    # The shape [None, 224, 224, 3] depends on your model's input.
    inputs = tf.placeholder(tf.float32, [None, 224, 224, 3], "input")

    # Import the trained model graph. input_map routes the original graph's
    # "Placeholder:0" to the new placeholder above, and return_elements
    # fetches its "final_result:0" tensor (unpacked from a one-item list).
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as fin:
        graph_def.ParseFromString(fin.read())
    # name="" imports the ops without a name prefix.
    (graph_outputs,) = tf.import_graph_def(
        graph_def,
        input_map={"Placeholder:0": inputs},
        return_elements=["final_result:0"],
        name="")

    # Define a specific output name for the model's result.
    outputs = tf.identity(graph_outputs, name="output")

    # Export as a frozen model, keeping only what "output" depends on.
    # The session is bound to merged_graph explicitly instead of relying on
    # whichever graph happens to be the ambient default.
    with tf.Session(graph=merged_graph) as sess:
        state, graph = OperateFrozenModel.save_sess_into_frozen_model(
            sess, ["output"], merged_pb_path)
        print(state, graph)
# In[]
# Reload the merged frozen model and run it on one random sample to check
# that the new "input:0" / "output:0" tensors are wired up end to end.
# np.random.randn returns float64; cast explicitly to float32 to match the
# dtype of the "input" placeholder defined when the graph was built.
sampled = np.random.randn(1, 224, 224, 3).astype(np.float32)
_, merged_graph = OperateFrozenModel.load_frozen_model(merged_pb_path)
with merged_graph.as_default():
    inputs = merged_graph.get_tensor_by_name("input:0")
    outputs = merged_graph.get_tensor_by_name("output:0")
    # Bind the session to the reloaded graph explicitly.
    with tf.Session(graph=merged_graph) as sess:
        print(sess.run(outputs, feed_dict={inputs: sampled}))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment