Created July 28, 2021 07:21
-
-
Save vgthengane/e600f4af6c6e4699d91f1df74ce8f113 to your computer and use it in GitHub Desktop.
Session state for the Streamlit library.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import streamlit as st | |
from streamlit.hashing import _CodeHasher | |
try: | |
# Before Streamlit 0.65 | |
from streamlit.ReportThread import get_report_ctx | |
from streamlit.server.Server import Server | |
except ModuleNotFoundError: | |
# After Streamlit 0.65 | |
from streamlit.report_thread import get_report_ctx | |
from streamlit.server.server import Server | |
class _SessionState: | |
def __init__(self, session, hash_funcs): | |
"""Initialize SessionState instance.""" | |
self.__dict__["_state"] = { | |
"data": {}, | |
"hash": None, | |
"hasher": _CodeHasher(hash_funcs), | |
"is_rerun": False, | |
"session": session, | |
} | |
# for item, value in kwargs.items(): | |
# if item not in self._state["data"]: | |
# self._state["data"][item] = value | |
def __call__(self, **kwargs): | |
"""Initialize state data once.""" | |
for item, value in kwargs.items(): | |
if item not in self._state["data"]: | |
self._state["data"][item] = value | |
def __getitem__(self, item): | |
"""Return a saved state value, None if item is undefined.""" | |
return self._state["data"].get(item, None) | |
def __getattr__(self, item): | |
"""Return a saved state value, None if item is undefined.""" | |
return self._state["data"].get(item, None) | |
def __setitem__(self, item, value): | |
"""Set state value.""" | |
self._state["data"][item] = value | |
def __setattr__(self, item, value): | |
"""Set state value.""" | |
self._state["data"][item] = value | |
def clear(self): | |
"""Clear session state and request a rerun.""" | |
self._state["data"].clear() | |
self._state["session"].request_rerun() | |
def sync(self): | |
"""Rerun the app with all state values up to date from the beginning to fix rollbacks.""" | |
# Ensure to rerun only once to avoid infinite loops | |
# caused by a constantly changing state value at each run. | |
# | |
# Example: state.value += 1 | |
if self._state["is_rerun"]: | |
self._state["is_rerun"] = False | |
elif self._state["hash"] is not None: | |
if self._state["hash"] != self._state["hasher"].to_bytes(self._state["data"], None): | |
self._state["is_rerun"] = True | |
self._state["session"].request_rerun() | |
self._state["hash"] = self._state["hasher"].to_bytes(self._state["data"], None) | |
def _get_session():
    """Return the Streamlit session object running the current script.

    Raises:
        RuntimeError: if the current thread has no Streamlit report context,
            or the server has no session registered under the context's
            session id.
    """
    ctx = get_report_ctx()
    if ctx is None:
        # get_report_ctx() returns None on threads that carry no Streamlit
        # report context (e.g. threads started by user code). Without this
        # check, ``ctx.session_id`` would fail with an opaque AttributeError.
        raise RuntimeError(
            "Couldn't get your Streamlit Session object: "
            "no report context for the current thread."
        )
    session_info = Server.get_current()._get_session_info(ctx.session_id)
    if session_info is None:
        raise RuntimeError("Couldn't get your Streamlit Session object.")
    return session_info.session
def get(hash_funcs=None, **kwargs):
    """Return the session-state object for the current Streamlit session.

    The object is created on first use and cached on the session as
    ``_custom_session_state``, so ``hash_funcs`` only matters on that first
    call. Every call applies ``kwargs`` as defaults: missing keys are set and
    existing values are left untouched.

    Args:
        hash_funcs: custom hash functions used by the state's change
            detection. Only honoured when the state object is first created.
        **kwargs: default state values.

    Returns:
        The session's ``_SessionState`` instance.
    """
    session = _get_session()
    try:
        state = session._custom_session_state
    except AttributeError:
        state = _SessionState(session, hash_funcs)
        session._custom_session_state = state
    state(**kwargs)
    return state
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.