Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
[load_model] load a ckpt or frozen_pb #tf #dl
def load_model(model, input_map=None):
    """Load a TensorFlow model into the default graph.

    ``model`` may be either:

    * a path to a protobuf file holding a frozen graph -- the serialized
      ``GraphDef`` is parsed and imported with an empty name scope; no
      variables are restored in this case, or
    * a model directory containing a metagraph and a checkpoint -- the
      metagraph is imported and the checkpointed variables are restored
      into the current default session.

    Args:
        model: Path to a frozen-graph file or to a model directory. A
            leading ``~`` is expanded to the user's home directory.
        input_map: Optional dict forwarded unchanged to
            ``tf.import_graph_def`` / ``tf.train.import_meta_graph`` to
            replace input tensors of the imported graph. ``None`` imports
            the graph as-is.

    Raises:
        ValueError: If ``model`` is a directory without a usable checkpoint
            (raised by ``get_model_filenames``).
        RuntimeError: If ``model`` is a directory and no default session is
            active to restore the variables into.
    """
    model_exp = os.path.expanduser(model)

    if os.path.isfile(model_exp):
        # A single file is treated as a frozen graph protobuf.
        print('Model filename: %s' % model_exp)
        with tf.gfile.FastGFile(model_exp, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, input_map=input_map, name='')
        return

    # Otherwise treat the path as a directory with a metagraph + checkpoint.
    print('Model directory: %s' % model_exp)
    meta_file, ckpt_file = get_model_filenames(model_exp)
    print('Metagraph file: %s' % meta_file)
    print('Checkpoint file: %s' % ckpt_file)

    # Look the session up once and validate it *before* importing the
    # metagraph, so a missing session fails with a clear message and does
    # not leave a half-loaded graph behind.
    default_sess = tf.get_default_session()
    if default_sess is None:
        raise RuntimeError(
            'load_model requires an active default session to restore the '
            'checkpoint; call it inside a `with tf.Session():` block.')

    saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file), input_map=input_map)
    saver.restore(default_sess, os.path.join(model_exp, ckpt_file))
def get_model_filenames(model_dir):
    """Return the metagraph and checkpoint file names for a model directory.

    Args:
        model_dir: Directory whose checkpoint state is queried.

    Returns:
        A ``(meta_file, ckpt_file)`` tuple of base names (no directory
        part); ``meta_file`` is ``ckpt_file`` with ``'.meta'`` appended.

    Raises:
        ValueError: If no checkpoint state, or no checkpoint path within
            it, is found for ``model_dir``.
    """
    state = tf.train.get_checkpoint_state(model_dir)

    # Guard clause: bail out early when there is nothing to load.
    if not (state and state.model_checkpoint_path):
        raise ValueError('No checkpoint file found in the model directory (%s)' % model_dir)

    checkpoint_name = os.path.basename(state.model_checkpoint_path)
    return checkpoint_name + '.meta', checkpoint_name
# Example usage: build a small preprocessing sub-graph and splice it in front
# of the loaded model through `input_map`.
#
# Placeholder for one uint8 image with three dimensions of unknown size
# (presumably height, width, channels -- confirm against the feeding code).
images_placeholder = tf.placeholder(dtype=tf.uint8, shape=[None, None, None], name='images_ph')
# Convert to float32; per the TF docs, convert_image_dtype also rescales
# integer inputs into the [0, 1] range.
image = tf.image.convert_image_dtype(images_placeholder, dtype=tf.float32)
# Resize to a fixed 256x256 (presumably the size the model expects -- TODO confirm).
image = tf.image.resize_images(image, (256, 256))
# Prepend an axis of size 1 so the single image becomes a batch of one.
image = tf.expand_dims(image, axis=0)
# Replace the imported graph's tensor named "sources_ph" (presumably the
# model's original input placeholder -- verify) with the preprocessed image.
input_map = {"sources_ph": image}
load_model(CHECKPOINT_DIR, input_map=input_map) # input_map may be left as None if no extra preprocessing is needed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.