Skip to content

Instantly share code, notes, and snippets.

@BenZstory
Created August 23, 2019 09:53
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BenZstory/3ef2d6e59dc8ff133708c8b6122738b1 to your computer and use it in GitHub Desktop.
[load_model] Load a TensorFlow checkpoint directory or a frozen .pb graph #tf #dl
def load_model(model, input_map=None, sess=None):
    """Load a TensorFlow model into the default graph and restore its variables.

    ``model`` may be either:

    * a path to an existing file, treated as a serialized ``GraphDef``
      protobuf (frozen graph) and imported into the default graph; or
    * a directory containing a checkpoint, whose metagraph is imported and
      whose variables are restored into a session.

    Args:
        model: Path to a frozen-graph ``.pb`` file or to a checkpoint
            directory. A leading ``~`` is expanded.
        input_map: Optional dict mapping tensor names in the loaded graph to
            existing tensors that should replace them (e.g. to splice in a
            preprocessing sub-graph). Passed unchanged to
            ``tf.import_graph_def`` / ``tf.train.import_meta_graph``.
        sess: Session into which checkpoint variables are restored. Defaults
            to ``tf.get_default_session()``. Not used for a frozen-graph file,
            because that branch only imports the ``GraphDef``.

    Raises:
        ValueError: If ``model`` is a checkpoint directory and no session was
            given and none is the default. ``get_model_filenames`` also raises
            ``ValueError`` when no checkpoint is found.
    """
    model_exp = os.path.expanduser(model)
    if os.path.isfile(model_exp):
        print('Model filename: %s' % model_exp)
        with tf.gfile.FastGFile(model_exp, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # The file can be closed before the graph is imported.
        tf.import_graph_def(graph_def, input_map=input_map, name='')
        return

    print('Model directory: %s' % model_exp)
    meta_file, ckpt_file = get_model_filenames(model_exp)
    print('Metagraph file: %s' % meta_file)
    print('Checkpoint file: %s' % ckpt_file)
    saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
    if saver is None:
        # import_meta_graph returns None when the metagraph contains no
        # variables, so there is nothing to restore.
        return
    if sess is None:
        sess = tf.get_default_session()
    if sess is None:
        # Without this check, Saver.restore would fail with an opaque
        # AttributeError on None.
        raise ValueError(
            'No session to restore variables into: pass `sess` or call '
            'load_model inside a `with tf.Session().as_default():` block')
    saver.restore(sess, os.path.join(model_exp, ckpt_file))
def get_model_filenames(model_dir):
    """Return ``(meta_file, ckpt_file)`` basenames for the checkpoint in *model_dir*.

    The checkpoint state is read with ``tf.train.get_checkpoint_state``, and
    its ``model_checkpoint_path`` is reduced to a basename. The caller is
    therefore expected to join both names back onto ``model_dir``. The
    metagraph name is the checkpoint prefix plus ``'.meta'``.

    Raises:
        ValueError: If no checkpoint state (or no checkpoint path) is found.
    """
    state = tf.train.get_checkpoint_state(model_dir)
    if not (state and state.model_checkpoint_path):
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)
    prefix = os.path.basename(state.model_checkpoint_path)
    return prefix + '.meta', prefix
# Example usage: build a preprocessing front-end and splice it into the
# loaded graph in place of its "sources_ph" input.
# NOTE(review): shape [None, None, None] fixes a rank-3 input; presumably a
# single (height, width, channels) image -- confirm against the callers.
images_placeholder = tf.placeholder(dtype=tf.uint8, shape=[None, None, None], name='images_ph')
# uint8 -> float32; per the TF docs this also rescales values to [0, 1].
image = tf.image.convert_image_dtype(images_placeholder, dtype=tf.float32)
# Resize to 256 (height) x 256 (width).
image = tf.image.resize_images(image, (256, 256))
# Add a leading axis of size 1 (presumably the batch axis) -> rank 4.
image = tf.expand_dims(image, axis=0)
# NOTE(review): assumes the saved graph has an input named "sources_ph" --
# confirm against the exported model.
input_map = {"sources_ph": image}
# CHECKPOINT_DIR is defined outside this view. input_map may be None when no
# preprocessing sub-graph needs to be spliced in.
load_model(CHECKPOINT_DIR, input_map=input_map)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment