Skip to content

Instantly share code, notes, and snippets.

@yoel-zeldes
Last active December 9, 2018 20:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yoel-zeldes/1d2501af7a7fd6449bfdc04cb4b48b0c to your computer and use it in GitHub Desktop.
Save yoel-zeldes/1d2501af7a7fd6449bfdc04cb4b48b0c to your computer and use it in GitHub Desktop.
class _InvalidModelPathError(ValueError, AssertionError):
    """Raised when ``_load_model`` cannot resolve its argument to a .meta file.

    Inherits from ``AssertionError`` as well as ``ValueError`` so callers written
    against the earlier ``assert``-based validation keep working, while the check
    itself is no longer stripped when Python runs with ``-O``.
    """


def _load_model(model_path):
    """
    Load a tensorflow model from the given path.

    It's assumed the path is either a directory containing exactly one .meta
    file, or the .meta file itself. If a checkpoint with the same name as the
    .meta file (without the .meta extension) also exists, its weights are
    restored into a new session as well. Two checkpoint layouts are recognised:
    the legacy V1 format, where the weights live in a single file named exactly
    like the prefix, and the V2 format, where the prefix is accompanied by a
    ``<prefix>.index`` file (plus ``<prefix>.data-*`` shards).

    Args:
        model_path: path to a .meta file, or to a directory that contains
            exactly one file ending in ``.meta``.

    Returns:
        A ``(graph, session)`` tuple. ``graph`` is a new ``tf.Graph`` into which
        the meta graph was imported. ``session`` is a ``tf.Session`` bound to
        ``graph`` with the checkpoint restored (when the graph has variables),
        or ``None`` if no checkpoint was found next to the .meta file.

    Raises:
        ValueError: (also an ``AssertionError``) if ``model_path`` is a
            directory without exactly one .meta file, or is not a directory and
            does not end with ``.meta``.
    """
    if os.path.isdir(model_path):
        meta_filenames = [filename for filename in os.listdir(model_path)
                          if filename.endswith('.meta')]
        if len(meta_filenames) != 1:
            raise _InvalidModelPathError(
                'expecting to get a .meta file or a directory containing a .meta file'
                ' (found %d .meta files in %r)' % (len(meta_filenames), model_path))
        model_path = os.path.join(model_path, meta_filenames[0])
    elif not model_path.endswith('.meta'):
        raise _InvalidModelPathError(
            'expecting to get a .meta file or a directory containing a .meta file'
            ' (got %r)' % (model_path,))

    # Checkpoint prefix: the .meta path with its extension removed.
    weights_path = model_path[:-len('.meta')]
    # V1 checkpoints are a single file at the prefix; V2 checkpoints (the TF
    # default) have no such file, only '<prefix>.index' and data shards.
    has_weights = (os.path.isfile(weights_path)
                   or os.path.isfile(weights_path + '.index'))

    graph = tf.Graph()
    with graph.as_default():
        # import_meta_graph returns None when the graph has no variables.
        saver = tf.train.import_meta_graph(model_path)
        if not has_weights:
            return graph, None
        session = tf.Session(graph=graph)
        if saver is not None:
            try:
                saver.restore(session, weights_path)
            except BaseException:
                # Don't leak the session's resources if restoring fails.
                session.close()
                raise
    return graph, session
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment