Created
August 23, 2019 09:53
-
-
Save BenZstory/3ef2d6e59dc8ff133708c8b6122738b1 to your computer and use it in GitHub Desktop.
[load_model] load a ckpt or frozen_pb #tf #dl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_model(model, input_map=None):
    """Load a TensorFlow model into the default graph (and default session).

    ``model`` is interpreted in one of two ways:

    * a path to a file: it is parsed as a serialized ``GraphDef`` (frozen
      graph) and imported with an empty name scope;
    * anything else: it is treated as a model directory containing a
      metagraph (``.meta``) and a checkpoint, whose variables are restored
      into the current default session.

    Args:
        model: Path to a frozen-graph file or to a model directory. A leading
            ``~`` is expanded.
        input_map: Optional mapping forwarded unchanged to
            ``tf.import_graph_def`` / ``tf.train.import_meta_graph``.

    Raises:
        ValueError: If ``model`` is a directory without a usable checkpoint
            (raised by ``get_model_filenames``).
        RuntimeError: If a checkpoint has to be restored but no default
            session is active.
    """
    model_exp = os.path.expanduser(model)

    if os.path.isfile(model_exp):
        print('Model filename: %s' % model_exp)
        # GFile is the supported replacement for the deprecated FastGFile.
        with tf.gfile.GFile(model_exp, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, input_map=input_map, name='')
        return

    print('Model directory: %s' % model_exp)
    meta_file, ckpt_file = get_model_filenames(model_exp)
    print('Metagraph file: %s' % meta_file)
    print('Checkpoint file: %s' % ckpt_file)

    # Fail early with a clear message instead of handing None to
    # saver.restore(); checking before the import also avoids leaving a
    # half-loaded graph behind.
    default_sess = tf.get_default_session()
    if default_sess is None:
        raise RuntimeError(
            'load_model() needs an active default session to restore the '
            'checkpoint; call it inside "with tf.Session().as_default():" '
            'or a "with tf.Session() as sess:" block.')

    saver = tf.train.import_meta_graph(
        os.path.join(model_exp, meta_file), input_map=input_map)
    saver.restore(default_sess, os.path.join(model_exp, ckpt_file))
def get_model_filenames(model_dir):
    """Return ``(meta_file, ckpt_file)`` for the checkpoint in ``model_dir``.

    Both values are base names (no directory part): ``ckpt_file`` is the base
    name of the checkpoint path recorded in the directory's checkpoint state,
    and ``meta_file`` is that name with ``.meta`` appended.

    Raises:
        ValueError: If no checkpoint state, or no checkpoint path, is found
            for ``model_dir``.
    """
    state = tf.train.get_checkpoint_state(model_dir)

    # Guard clause: bail out when there is nothing to load.
    if not (state and state.model_checkpoint_path):
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)

    checkpoint_name = os.path.basename(state.model_checkpoint_path)
    return checkpoint_name + '.meta', checkpoint_name
# Example: put a small preprocessing pipeline in front of the loaded model.
# The placeholder is rank 3 with every dimension unknown -- presumably a
# single height x width x channels uint8 image (TODO confirm with the caller).
images_placeholder = tf.placeholder(tf.uint8, shape=(None, None, None), name='images_ph')

# uint8 -> float32, resize to 256x256, then add a leading dimension at axis 0.
image = tf.image.convert_image_dtype(images_placeholder, tf.float32)
image = tf.image.resize_images(image, size=[256, 256])
image = tf.expand_dims(image, 0)

# Substitute the preprocessed tensor for the graph input named "sources_ph".
input_map = {"sources_ph": image}

# input_map may be left as None when no extra preprocessing is needed.
load_model(CHECKPOINT_DIR, input_map=input_map)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment