Last active
February 24, 2019 20:52
-
-
Save thepulkitagarwal/36bc45aa43ae43f83baa5111a89be73e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This was created with @warptime's help. Thank you!
import os
import os.path as osp

import tensorflow as tf
from keras import backend as K
from keras.models import load_model
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import graph_util
def freeze_keras_model(path_to_model, saved_model_path,
                       output_fld='tensorflow_model/', nb_classes=None,
                       prefix_output_node_names_of_final_network='output_node'):
    """Freeze a saved Keras model into a constant TensorFlow graph (.pb).

    Loads the model at ``path_to_model`` and wraps each selected output in a
    ``tf.identity`` node named ``<prefix><i>``, so it can be fetched by name.
    It then folds all variables into constants and writes the binary GraphDef
    to ``<output_fld>/<saved_model_path>.pb``.

    Args:
        path_to_model: Path of the saved Keras model, passed to ``load_model``.
        saved_model_path: Base name of the output graph file; ``'.pb'`` is
            appended to it.
        output_fld: Directory the graph is written into (created if missing).
        nb_classes: Number of model outputs to export as named output nodes.
            ``None`` exports every output of the model.
        prefix_output_node_names_of_final_network: Prefix of the exported
            output node names (``prefix + '0'``, ``prefix + '1'``, ...).

    Returns:
        The path of the written ``.pb`` file.

    Raises:
        ValueError: if ``nb_classes`` exceeds the number of model outputs.
    """
    # Fix the learning phase to inference BEFORE the model graph is built.
    # If it is set after load_model, the graph keeps the train/test switches
    # driven by the learning-phase placeholder, and the frozen graph then
    # needs that placeholder fed at inference time.
    K.set_learning_phase(0)
    model = load_model(path_to_model)

    # model.outputs is always a list of output tensors. model.output is a bare
    # tensor for a single-output model, so model.output[0] would slice the
    # batch axis instead of selecting the first output.
    outputs = model.outputs
    if nb_classes is None:
        nb_classes = len(outputs)
    elif nb_classes > len(outputs):
        raise ValueError('nb_classes=%d but the model has only %d output(s)'
                         % (nb_classes, len(outputs)))

    pred_node_names = [prefix_output_node_names_of_final_network + str(i)
                       for i in range(nb_classes)]
    # Named identity ops give the frozen graph stable, predictable output names.
    for tensor, node_name in zip(outputs, pred_node_names):
        tf.identity(tensor, name=node_name)
    print('output nodes names are: ', pred_node_names)

    sess = K.get_session()
    os.makedirs(output_fld, exist_ok=True)
    output_graph_name = saved_model_path + '.pb'
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_node_names)
    graph_io.write_graph(constant_graph, output_fld, output_graph_name,
                         as_text=False)
    output_path = osp.join(output_fld, output_graph_name)
    print('saved the constant graph (ready for inference) at: ', output_path)
    return output_path


if __name__ == '__main__':
    import sys

    if len(sys.argv) != 3:
        sys.exit('usage: %s <path_to_model> <saved_model_path>' % sys.argv[0])
    freeze_keras_model(sys.argv[1], sys.argv[2])
I used this freeze code, but when I try to use my .pb in an Android application I get this error: " input must be 4-dimensional[80,3,1]"
Can you help me?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The code has to be modified if you actually want to run it: some of the required imports have been left out.