Skip to content

Instantly share code, notes, and snippets.

@Mainvooid
Last active May 11, 2018 01:17
Show Gist options
  • Save Mainvooid/fc1793747bea94c9d74515c831254f1d to your computer and use it in GitHub Desktop.
Save Mainvooid/fc1793747bea94c9d74515c831254f1d to your computer and use it in GitHub Desktop.
tensorflow 读写pb文件#TenorFlow #Python
# read
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
output_graph_path = './flower_model_save.pb' #'./flower_model_save.pbtxt'
with open(output_graph_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
......
# write
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,["output"])
with tf.gfile.FastGFile("./flower_model/flower_model_save.pb", mode='wb') as f:
f.write(constant_graph.SerializeToString())
#or 只是保存了模型的结构,并不保存训练完毕的参数值
tf.train.write_graph(graph_def, pb_file_path, 'flower_model_save.pb', as_text=False)
tf.train.write_graph(graph_def, pb_file_path, 'flower_model_save.pbtxt', as_text=True)
#or
builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
# 构造模型保存的内容,指定要保存的 session,特定的 tag,
# 输入输出信息字典,额外的信息
builder.add_meta_graph_and_variables(sess,['cpu_server_1'])
builder.save() # 保存 PB 模型
#保存好以后到saved_model_dir目录下,会有一个saved_model.pb文件以及variables文件夹。
#顾名思义,variables保存所有变量,saved_model.pb用于保存模型结构等信息。
#这种方法对应的导入模型的方法:
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ['cpu_1'], pb_file_path+'savemodel')
sess.run(tf.global_variables_initializer())
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
op = sess.graph.get_tensor_by_name('op_to_store:0')
ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
print(ret)
# 只需要指定要恢复模型的 session,模型的 tag,模型的保存路径即可,使用起来更加简单
#refer to https://zhuanlan.zhihu.com/p/32887066
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment