Last active
March 1, 2019 10:19
-
-
Save hewumars/b2f43c904b7e296547857444493c3a6f to your computer and use it in GitHub Desktop.
1.Keras H5转TF pb文件
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import load_model | |
import tensorflow as tf | |
import os | |
from keras import backend as K | |
# Path settings: location of the Keras .h5 weights and the name of the
# exported TensorFlow graph file.
input_path = 'weights/'
weight_file = '****.h5'
weight_file_path = os.path.join(input_path, weight_file)
# Drop the trailing 3 characters (".h5") and append the ".pb" suffix.
output_graph_name = '{}{}'.format(weight_file[:-3], '.pb')
def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
    """Freeze a loaded Keras model into a binary TensorFlow GraphDef file.

    Every model output is wrapped in an identity op named ``out_prefix``
    followed by its 1-based index (``output_1``, ``output_2``, ...). Those
    names are passed as the output nodes when the session graph's variables
    are converted to constants, and the resulting graph is written to
    ``output_dir/model_name`` in binary form.

    Args:
        h5_model: Loaded Keras model; its ``outputs`` list is read.
        output_dir: Directory the graph is written to; created if missing.
        model_name: File name of the written graph inside ``output_dir``.
        out_prefix: Name prefix for the identity ops added per output.
        log_tensorboard: When true, the written graph file is also imported
            into TensorBoard logs under ``output_dir``.
    """
    from tensorflow.python.framework import graph_util, graph_io

    # `os.path` is used directly: the `osp` alias was never imported here.
    # makedirs (rather than mkdir) also creates missing parent directories.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    out_nodes = []
    # Iterate `outputs` (always a list) instead of indexing `output`, which is
    # a bare tensor for single-output models and would be sliced by `[i]`.
    for index, tensor in enumerate(h5_model.outputs, 1):
        node_name = out_prefix + str(index)
        out_nodes.append(node_name)
        tf.identity(tensor, name=node_name)

    sess = K.get_session()
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)

    if log_tensorboard:
        from tensorflow.python.tools import import_pb_to_tensorboard
        import_pb_to_tensorboard.import_to_tensorboard(
            os.path.join(output_dir, model_name), output_dir)
# Output directory for the converted graph (under the current working dir).
# `os.path` is used directly because the `osp` alias is never imported.
output_dir = os.path.join(os.getcwd(), "trans_model")
# Load the Keras model from the .h5 file and export it as a .pb graph.
h5_model = load_model(weight_file_path)
h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
print('model saved')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import argparse | |
from datetime import datetime | |
def time_tensorflow_run(session, target, info_string, num_batches=None, num_steps_burn_in=10):
    """Time repeated ``session.run(target)`` calls and print mean/std per batch.

    Args:
        session: The TensorFlow session to run the computation under.
        target: The operation (or fetch) to benchmark.
        info_string: Name of the test, included in the summary line.
        num_batches: Number of timed iterations. Defaults to
            ``FLAGS.num_batches`` (a module-level flags object defined
            outside this function).
        num_steps_burn_in: Number of warm-up iterations that are run but not
            timed. Per the original author, the first iterations are skewed
            by GPU memory loading and cache effects, so they are skipped.

    Raises:
        ValueError: If the number of timed batches is not positive.
    """
    import math  # local import: the file's import block does not provide it

    if num_batches is None:
        num_batches = FLAGS.num_batches
    if num_batches <= 0:
        # Fail before doing any work instead of dividing by zero afterwards.
        raise ValueError('num_batches must be positive, got %r' % (num_batches,))

    total_duration = 0.0          # sum of timed durations
    total_duration_squared = 0.0  # sum of squared durations, for the variance

    for i in range(num_batches + num_steps_burn_in):
        start_time = time.time()
        _ = session.run(target)
        duration = time.time() - start_time
        if i >= num_steps_burn_in:
            if not i % 10:
                print('%s: step %d, duration = %.3f' %
                      (datetime.now(), i - num_steps_burn_in, duration))
            total_duration += duration
            total_duration_squared += duration * duration

    mn = total_duration / num_batches  # mean time per batch
    vr = total_duration_squared / num_batches - mn * mn
    # E[x^2] - E[x]^2 can come out slightly negative from float rounding when
    # durations are nearly identical; clamp so sqrt does not raise.
    sd = math.sqrt(max(vr, 0.0))  # standard deviation
    print('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %
          (datetime.now(), info_string, num_batches, mn, sd))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment