Created
August 23, 2019 09:53
[load_model] load a ckpt or frozen_pb #tf #dl
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_model(model, input_map=None):
    """Load a TensorFlow model into the default graph and restore its variables.

    ``model`` may be either:

    * a path to a frozen-graph protobuf file, which is parsed as a
      ``GraphDef`` and imported straight into the default graph; or
    * a directory holding a checkpoint, whose metagraph is imported and
      whose variables are restored into the current default session.

    Args:
        model: Path to a frozen graph file or to a checkpoint directory.
            A leading ``~`` is expanded.
        input_map: Optional dict mapping tensor names in the saved graph to
            tensors in the current graph (for example, to splice a
            preprocessing pipeline in front of the model's input).
            ``None`` imports the graph unchanged.

    Raises:
        ValueError: If ``model`` is a directory with no checkpoint
            (propagated from ``get_model_filenames``).
        RuntimeError: If the checkpoint has variables to restore but no
            default session is active (wrap the call in
            ``with sess.as_default():``).
    """
    model_exp = os.path.expanduser(model)

    if os.path.isfile(model_exp):
        # A single file is treated as a frozen GraphDef; there is no
        # checkpoint to restore, so no session is required here.
        print('Model filename: %s' % model_exp)
        with tf.gfile.GFile(model_exp, 'rb') as f:
            serialized = f.read()
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(serialized)
        # name='' keeps the saved op names instead of adding an "import/" prefix.
        tf.import_graph_def(graph_def, input_map=input_map, name='')
        return

    print('Model directory: %s' % model_exp)
    meta_file, ckpt_file = get_model_filenames(model_exp)
    print('Metagraph file: %s' % meta_file)
    print('Checkpoint file: %s' % ckpt_file)

    saver = tf.train.import_meta_graph(
        os.path.join(model_exp, meta_file), input_map=input_map)
    if saver is None:
        # import_meta_graph returns None when the metagraph contains no
        # variables; the graph is imported and there is nothing to restore.
        return

    default_sess = tf.get_default_session()
    if default_sess is None:
        raise RuntimeError(
            'load_model needs a default session to restore checkpoint '
            'variables; call it inside "with sess.as_default():".')
    saver.restore(default_sess, os.path.join(model_exp, ckpt_file))
def get_model_filenames(model_dir):
    """Return ``(meta_file, ckpt_file)`` basenames for the checkpoint in *model_dir*.

    The checkpoint prefix comes from the directory's checkpoint state
    (``tf.train.get_checkpoint_state``). The metagraph name is that prefix
    with ``.meta`` appended. Only basenames are returned, so callers re-join
    them with ``model_dir``.

    Raises:
        ValueError: If no checkpoint state is found, or it records no
            model checkpoint path.
    """
    state = tf.train.get_checkpoint_state(model_dir)
    if not (state and state.model_checkpoint_path):
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
    prefix = os.path.basename(state.model_checkpoint_path)
    return prefix + '.meta', prefix
# Preprocessing front-end spliced in front of the saved model's input.
# It takes a single rank-3 uint8 image, converts it to float32 via
# tf.image.convert_image_dtype, resizes it to 256x256, and adds a leading
# batch axis of size 1.
images_placeholder = tf.placeholder(dtype=tf.uint8, shape=[None, None, None], name='images_ph')
image = tf.expand_dims(
    tf.image.resize_images(
        tf.image.convert_image_dtype(images_placeholder, dtype=tf.float32),
        (256, 256),
    ),
    axis=0,
)

# Rewire the saved graph's "sources_ph" tensor (presumably its original
# input placeholder; confirm against the exported model) to the
# preprocessed image above.
input_map = {"sources_ph": image}

# Pass input_map=None instead if no preprocessing needs to be spliced in.
load_model(CHECKPOINT_DIR, input_map=input_map)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment