Skip to content

Instantly share code, notes, and snippets.

@05jd
Last active June 10, 2018 08:27
Show Gist options
  • Save 05jd/93eeceaabe0e4bfeac6938ce2b2a673e to your computer and use it in GitHub Desktop.
Save 05jd/93eeceaabe0e4bfeac6938ce2b2a673e to your computer and use it in GitHub Desktop.
Simple context manager that makes a TensorFlow session (and its graph) the default for both TensorFlow and Keras
import tensorflow as tf
class SessionContextManager(object):
    """Make one TensorFlow session (and its graph) the default for TF and Keras.

    On entry this installs ``session.graph`` as the default TensorFlow graph,
    ``session`` as the default TensorFlow session, and ``session`` as the
    Keras backend session. On exit all three are undone; the Keras backend
    session is set back to the one that was current when the context was
    entered.

    Example::

        with SessionContextManager(sess):
            ...  # tf / tf.keras code here runs against ``sess``
    """

    def __init__(self, session):
        """
        Arguments:
            session (tf.Session): the session to make the default while the
                context is active.
        """
        self._keras_backend = tf.keras.backend
        self._session = session
        # Keras session to restore on exit. It is captured in __enter__ (not
        # here) so that delayed or repeated entry restores the state that was
        # actually current at entry time.
        self._prev_keras_default_sess = None
        self._session_context_manager = None
        self._graph_context_manager = None

    def __enter__(self):
        """Install the session/graph defaults.

        Returns:
            Whatever entering ``session.as_default()`` yields (in TF 1.x this
            is the session itself).
        """
        # Snapshot before anything is changed, so __exit__ restores the
        # pre-entry Keras session rather than the one we are about to set.
        # NOTE(review): in TF 1.x Keras, get_session() may lazily create a
        # session and initialize variables as a side effect — confirm that is
        # acceptable here.
        self._prev_keras_default_sess = self._keras_backend.get_session()
        self._graph_context_manager = self._session.graph.as_default()
        self._session_context_manager = self._session.as_default()
        self._graph_context_manager.__enter__()
        self._keras_backend.set_session(self._session)
        return self._session_context_manager.__enter__()

    def __exit__(self, exec_type, exec_value, exec_tb):
        """Undo the defaults installed by ``__enter__``, innermost first.

        Each restoration step runs in a ``finally`` so that a failure while
        leaving one context cannot leave the graph stack or the Keras backend
        session pointing at this session. Exceptions raised inside the
        ``with`` body are never suppressed.
        """
        session_cm = self._session_context_manager
        graph_cm = self._graph_context_manager
        prev_keras_sess = self._prev_keras_default_sess
        # Clear state up front so the instance can be entered again and does
        # not keep the previous Keras session alive.
        self._session_context_manager = None
        self._graph_context_manager = None
        self._prev_keras_default_sess = None
        try:
            try:
                # Reverse order of entry: session context first, then graph.
                session_cm.__exit__(exec_type, exec_value, exec_tb)
            finally:
                graph_cm.__exit__(exec_type, exec_value, exec_tb)
        finally:
            self._keras_backend.set_session(prev_keras_sess)
        return False
def test_session_manager():
    """Smoke-test that SessionContextManager installs and restores defaults.

    Checks the Keras backend session, the TF default session and the TF
    default graph:
    - inside the ``with`` block;
    - after a normal exit;
    - after an exit caused by an exception in the body.
    """
    # Use the same public module the class uses instead of the private
    # ``tensorflow.python.keras`` path.
    K = tf.keras.backend
    default_sess = tf.get_default_session()
    default_graph = tf.get_default_graph()
    default_keras_sess = K.get_session()

    def assert_defaults_restored():
        assert K.get_session() is default_keras_sess
        assert tf.get_default_session() is default_sess
        assert tf.get_default_graph() is default_graph

    graph = tf.Graph()
    sess = tf.Session(graph=graph)
    try:
        with SessionContextManager(sess):
            assert K.get_session() is sess
            assert tf.get_default_session() is sess
            assert tf.get_default_graph() is graph
        assert_defaults_restored()

        # Defaults must also be restored when the body raises, and the
        # exception must propagate out of the context manager.
        try:
            with SessionContextManager(sess):
                raise ValueError("boom")
        except ValueError:
            pass
        else:
            raise AssertionError("exception was swallowed")
        assert_defaults_restored()
    finally:
        # Release the session's resources; the original test leaked it.
        sess.close()
if __name__ == "__main__":
    # Executing this file directly runs the smoke test.
    test_session_manager()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment