Skip to content

Instantly share code, notes, and snippets.

@ruberthbarros
Created October 29, 2020 20:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ruberthbarros/f586b091742d7467839e44031976d732 to your computer and use it in GitHub Desktop.
Save ruberthbarros/f586b091742d7467839e44031976d732 to your computer and use it in GitHub Desktop.
Prototype for an editable table API using streamlit framework.
import numbers
import pandas as pd
import streamlit as st
from streamlit.hashing import _CodeHasher
from streamlit.report_thread import get_report_ctx
from streamlit.server.server import Server
def main():
    """Render the demo page: an editable table backed by per-session state."""
    session_state = get_state()

    # Seed the session with a small sample frame the first time the app runs.
    if session_state.data is None:
        session_state.data = pd.DataFrame({"A": [1, 2], "B": [3, 4]})

    table_editor = TableEditor("table1", session_state.data)

    # Apply whichever button was pressed in this run, then persist the result.
    table_editor.interact()
    session_state.data = table_editor.data

    st.table(session_state.data)
    session_state.sync()
class TableEditor:
    """Encapsulates editable tables using streamlit.

    Every dataframe value is rendered as an input widget, one layout column
    per dataframe column. One extra layout column on the right holds a
    "Delete" button per row, and "Add Row" / "Update Data" buttons are
    rendered after the table.

    Usage:

    >>> original_df = pd.DataFrame(...)
    >>> editor = TableEditor("editor uid", original_df)
    >>> editor.interact()
    >>> edited_df = editor.data
    """

    def __init__(self, uid, dataframe, layout=None):
        """Initialize TableEditor instance and render its widgets.

        Args:
            uid (str): Table unique identifier to avoid widget key conflict.
            dataframe (pandas.DataFrame): Data to be edited. A copy is kept,
                the caller's frame is never modified.
            layout (list, optional): List of column proportions, one entry
                per dataframe column plus a final entry for the Delete
                buttons column. See
                https://docs.streamlit.io/en/stable/api.html#streamlit.beta_columns.
                A list of already created column containers is also accepted
                and used as is. Defaults to None, in which case each data
                column is 5 times wider than the Delete buttons column.
        """
        self._uid = uid
        self._data = dataframe.copy()
        self._n_rows, self._n_cols = dataframe.shape
        self._cells = {}
        self._update_button = None
        self._add_row_button = None
        self._delete_buttons = {}

        if layout is None:
            # If layout not defined the dataframe columns will be 5 times bigger
            # than Delete buttons column
            layout = [5] * self._n_cols + [1]
        if all(isinstance(width, numbers.Real) for width in layout):
            # Proportions (the documented form) must be turned into column
            # containers; anything else is assumed to be containers already.
            layout = st.beta_columns(layout)
        self._layout = layout

        self._create_table()
        self._create_buttons()

    @property
    def data(self):
        """pandas.DataFrame: Current state of the edited data."""
        return self._data

    def interact(self):
        """Apply the action of whichever button was pressed in this run."""
        if self._update_button:
            self._update()
        if self._add_row_button:
            self._add_row()
        for key, pressed in self._delete_buttons.items():
            if pressed:
                # key[1] is always the row index
                self._delete_row(key[1])
                break

    def _create_table(self):
        """Render the column headers and one input cell per dataframe value."""
        # Gets only layout columns to put actual data from dataframe - indices [0:n_cols]
        data_columns = self._layout[:self._n_cols]
        for col_index, column in enumerate(data_columns):
            # Writes column names
            column.markdown(f"**{self._data.columns[col_index]}**")
            with column:
                for row_index in range(self._n_rows):
                    key = (self._uid, col_index, row_index)
                    self._add_cell(key, self._data.iloc[row_index, col_index])

    def _create_buttons(self):
        """Render the per-row Delete buttons and the Add Row / Update Data buttons."""
        # Always the last column in column layout
        button_del_column = self._layout[self._n_cols]
        # The buttons are not horizontally aligned with input widgets, so we need this little hack
        button_del_column.markdown("<div style='margin-top:4.2em;'></div>", unsafe_allow_html=True)
        for row_index in range(self._n_rows):
            key = (self._uid, row_index)
            self._delete_buttons[key] = button_del_column.button("Delete", key=str(key))
            button_del_column.markdown(
                "<div style='margin-top:2.43em;'></div>", unsafe_allow_html=True
            )
        self._add_row_button = st.button("Add Row", key=f"add_row_button_{self._uid}")
        self._update_button = st.button(label="Update Data", key=f"update_button_{self._uid}")

    def _update(self):
        """Copy every cell value back into the data, then sort by all columns."""
        for col_index in range(self._n_cols):
            for row_index in range(self._n_rows):
                new_value = self._cells[(self._uid, col_index, row_index)].value
                self._data.iloc[row_index, col_index] = new_value
        self._data = self._data.sort_values(by=self._data.columns.to_list(), ignore_index=True)

    def _add_row(self):
        """Append a row holding 1 in every column and render its input cells."""
        new_row = pd.DataFrame({name: [1] for name in self._data.columns})
        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
        self._data = pd.concat([self._data, new_row], ignore_index=True)
        # The new row sits right after the rows that existed at construction.
        row_index = self._n_rows
        for col_index in range(self._n_cols):
            key = (self._uid, col_index, row_index)
            with self._layout[col_index]:
                self._add_cell(key, self._data.iloc[row_index, col_index])

    def _delete_row(self, row_index):
        """Drop the row at ``row_index``, renumber the index and forget its cells."""
        self._data = self._data.drop([row_index]).reset_index(drop=True)
        for col_index in range(self._n_cols):
            key = (self._uid, col_index, row_index)
            self._delete_cell(key)

    def _add_cell(self, key, value):
        """Create the input widget for ``value`` and register it under ``key``."""
        new_cell = _Cell(key)
        new_cell.value = value
        self._cells[key] = new_cell

    def _delete_cell(self, key):
        """Remove the cell registered under ``key``."""
        del self._cells[key]
class _Cell:
    """A single editable table cell rendered inside an ``st.empty`` placeholder."""

    def __init__(self, uid):
        self._uid = uid
        self._value = None
        self._widget = st.empty()

    @property
    def value(self):
        """Value returned by the input widget the last time it was rendered."""
        return self._value

    @value.setter
    def value(self, value):
        # Choose the widget and its extra options from the type of the value:
        # integers and floats get a number input, everything else a text input.
        if isinstance(value, numbers.Integral):
            render = self._widget.number_input
            extra = {"min_value": 1, "step": 1}
        elif isinstance(value, float):
            render = self._widget.number_input
            extra = {"min_value": 0., "step": 0.001, "format": "%.3f"}
        else:
            render = self._widget.text_input
            extra = {}
        self._value = render(label="", value=value, key=self._uid, **extra)
#
# Code below is from https://gist.github.com/okld/0aba4869ba6fdc8d49132e6974e2e662.
#
class _SessionState:
    """Key/value store attached to a session, readable as attributes or items."""

    def __init__(self, session, hash_funcs):
        """Initialize SessionState instance."""
        # Stored through __dict__ because __setattr__ is overridden below and
        # would otherwise file "_state" as a regular state value.
        internal = {
            "data": {},
            "hash": None,
            "hasher": _CodeHasher(hash_funcs),
            "is_rerun": False,
            "session": session,
        }
        self.__dict__["_state"] = internal

    def __call__(self, **kwargs):
        """Initialize state data once."""
        stored = self._state["data"]
        for name, initial in kwargs.items():
            stored.setdefault(name, initial)

    def __getitem__(self, item):
        """Return a saved state value, None if item is undefined."""
        return self._state["data"].get(item)

    def __getattr__(self, item):
        """Return a saved state value, None if item is undefined."""
        return self._state["data"].get(item)

    def __setitem__(self, item, value):
        """Set state value."""
        self._state["data"][item] = value

    def __setattr__(self, item, value):
        """Set state value."""
        self._state["data"][item] = value

    def clear(self):
        """Clear session state and request a rerun."""
        internal = self._state
        internal["data"].clear()
        internal["session"].request_rerun()

    def sync(self):
        """Rerun the app with all state values up to date from the beginning to fix rollbacks."""
        internal = self._state

        def current_hash():
            return internal["hasher"].to_bytes(internal["data"], None)

        # Ensure to rerun only once to avoid infinite loops
        # caused by a constantly changing state value at each run.
        #
        # Example: state.value += 1
        if internal["is_rerun"]:
            internal["is_rerun"] = False
        elif internal["hash"] is not None and internal["hash"] != current_hash():
            internal["is_rerun"] = True
            internal["session"].request_rerun()

        internal["hash"] = current_hash()
def _get_session():
    """Return the session object the server holds for the current report context.

    Raises:
        RuntimeError: If the server has no session info for the current session id.
    """
    current_id = get_report_ctx().session_id
    info = Server.get_current()._get_session_info(current_id)
    if info is not None:
        return info.session
    raise RuntimeError("Couldn't get your Streamlit Session object.")
def get_state(hash_funcs=None):
    """Return the `_SessionState` attached to the current session, creating it on first use.

    Args:
        hash_funcs: Passed to `_SessionState` when the state is first created;
            ignored on later calls because the existing state is reused.
    """
    session = _get_session()
    try:
        return session._custom_session_state
    except AttributeError:
        # First call for this session: attach a fresh state object to it.
        session._custom_session_state = _SessionState(session, hash_funcs)
        return session._custom_session_state
# Build the page only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment