|
import asyncio |
|
import json |
|
import os |
|
import threading |
|
from dataclasses import dataclass |
|
from itertools import chain |
|
from multiprocessing.shared_memory import SharedMemory |
|
from typing import Any |
|
|
|
import pandas as pd |
|
import streamlit as st |
|
from streamlit import session_state as ss |
|
from streamlit.components.v1 import html |
|
from streamlit.web.server import Server |
|
from streamlit.web.server.server import start_listening |
|
from tornado.web import RequestHandler |
|
|
|
|
|
# Offset from the grid's `aria-colindex` (as tracked by the injected JS in
# `sortedByCol`) to a pandas column position: colindex 1 is treated as the
# index column, so data column k (0-based) shows up as colindex k + 2.
_JS_TO_PD_COL_OFFSET: int = -2

# Shared-memory hand-off channel: JSCallbackHandler writes the raw JSON body
# here and the Streamlit callback reads and clears it. If a segment with this
# name already exists (created by an earlier run), attach to it instead.
_PAYLOAD_SEGMENT_NAME = "JS_PAYLOAD"
_PAYLOAD_SEGMENT_SIZE = 128

try:
    payload_memory = SharedMemory(name=_PAYLOAD_SEGMENT_NAME, create=True, size=_PAYLOAD_SEGMENT_SIZE)
except FileExistsError:
    payload_memory = SharedMemory(name=_PAYLOAD_SEGMENT_NAME, create=False, size=_PAYLOAD_SEGMENT_SIZE)
|
|
|
|
|
@dataclass
class Selection:
    """A clicked grid cell, together with how the grid was sorted at click time.

    Field order is significant: instances are built positionally as
    ``Selection(col, row, sorted_by)``.
    """

    col: int  # column position of the clicked cell
    row: int  # row position of the clicked cell in the grid as displayed
    sorted_by: int  # signed sort column reported by the JS (negative = descending)
|
|
|
|
|
def sort_df_by_selected_col(table: pd.DataFrame, js_sorted_by: int) -> pd.DataFrame:
    """Return *table* ordered the same way as the front-end grid.

    Args:
        table: the frame displayed in the grid.
        js_sorted_by: signed ``aria-colindex`` of the sorted header, as tracked
            by the injected JavaScript (negative for a descending arrow). ``±1``
            addresses the index column; ``1`` is also what the JS reports when
            no column is sorted.

    Returns:
        ``table`` itself for ``1``; otherwise a newly sorted frame.
    """
    grid_col = abs(js_sorted_by)
    if grid_col == 1:
        # Index column: keep the original order, or reverse-sort the index.
        return table if js_sorted_by > 0 else table.sort_index(axis=0, ascending=False)

    label = table.columns[grid_col + _JS_TO_PD_COL_OFFSET]
    return table.sort_values(by=label, ascending=js_sorted_by > 0)
|
|
|
|
|
def _retrieve_payload() -> Selection:
    """Consume the click payload that ``JSCallbackHandler`` left in shared memory.

    The buffer holds NUL-padded UTF-8 JSON as sent by the injected JavaScript,
    e.g. ``{"action": "click", "cellId": "<col>,<row>", "sortedByCol": <int>}``.
    A non-zero first byte means a payload is pending.

    Returns:
        ``Selection(col - 1, row, sortedByCol)``. The column is shifted by one
        so that the grid's leading index column becomes -1 (presumably the JS
        cell id counts it as column 0). Returns ``Selection(-1, -1, -1)`` when
        no payload is pending.

    Raises:
        ValueError: the stored bytes are not valid UTF-8 JSON. The buffer has
            already been cleared by then, so the next click is unaffected.
        AttributeError, TypeError, ValueError: the JSON lacks a well-formed
            ``cellId`` / ``sortedByCol``.
    """
    payload = {}
    buf = payload_memory.buf
    if buf[0] != 0:
        # Take everything up to the first NUL pad byte, then zero the *whole*
        # buffer BEFORE parsing. The handler only writes into an empty buffer
        # (buf[0] == 0), so a payload that fails to parse must not be left
        # behind to block every later click. Clearing by byte length (not by
        # decoded character count) also leaves no residue when the JSON
        # contains multi-byte UTF-8 characters.
        raw = bytes(buf).split(b"\x00", 1)[0]
        buf[:] = bytes(len(buf))
        payload = json.loads(raw.decode("utf-8"))

    if payload:
        # cellId is "<col>,<row>"; append sortedByCol to get Selection's 3 fields.
        selected_cell_info = Selection(*(
            int(val) for val in chain(payload.get('cellId').split(','), [payload.get('sortedByCol')])
        ))
        print(f"{os.getpid()}::{threading.get_ident()}: Streamlit callback received payload: {selected_cell_info}")
        return Selection(selected_cell_info.col - 1, selected_cell_info.row, selected_cell_info.sorted_by)

    print(f"{os.getpid()}::{threading.get_ident()}: Streamlit callback saw no payload!")
    return Selection(-1, -1, -1)
|
|
|
|
|
def _interpret_payload(payload: Selection) -> tuple[Any, Any]:
    """Resolve a ``Selection`` to row/column labels of ``df`` and update the status text.

    ``df`` is re-sorted the way the grid was sorted at click time, so the
    on-screen ``payload.row`` position maps to the correct index label.

    Side effect:
        For a real click, stores a human-readable description in
        ``ss["CELL_ID"]``.

    Returns:
        ``(row_label, column_label)``, where ``column_label`` is ``None`` if the
        index column was clicked (``payload.col < 0``). For the "no payload"
        sentinel (``payload.row < 0``), returns ``(None, None)`` and leaves
        session state untouched.
    """
    if payload.row < 0:
        # `_retrieve_payload` signals "nothing pending" with Selection(-1, -1, -1).
        # Indexing with -1 would otherwise report the last row as clicked.
        return None, None

    sorted_df = sort_df_by_selected_col(df, payload.sorted_by)
    selected_row = sorted_df.index[payload.row]
    clicked_data_col = payload.col >= 0
    selected_col = sorted_df.columns[payload.col] if clicked_data_col else None

    # Update a text field. Decide on the column *position*, not on the label's
    # truthiness, so columns labelled 0 or "" still show their contents.
    selection_str = f", with contents: `{sorted_df.iat[payload.row, payload.col]}`" if clicked_data_col else ""
    ss["CELL_ID"] = (
        f"Clicked on cell with index [{selected_row}, {selected_col}]"
        f" (at position [{payload.row}, {payload.col}])"
        f"{selection_str}."
    )

    return selected_row, selected_col
|
|
|
|
|
def fake_click(*args, **kwargs):
    """``on_click`` callback of the hidden button that the injected JS clicks.

    Reads (and clears) the pending payload, then resolves it with
    ``_interpret_payload``. Any positional or keyword arguments are accepted
    and ignored.
    """
    selection = _retrieve_payload()
    selected_row, selected_col = _interpret_payload(selection)
    # Hook for app-specific handling of (selected_row, selected_col),
    # e.g. ss["SELECTION"] = selected_row, selected_col
|
|
|
|
|
# Sample data shown in the grid: integer columns A, B, C with a default RangeIndex.
df = pd.DataFrame(
    [[1, 4, 7],
     [2, 5, 8],
     [3, 6, 9]],
    columns=["A", "B", "C"],
)
|
|
|
# JavaScript injected via `html()`. It runs inside the component's iframe and
# reaches into the parent Streamlit page (`window.parent.document`). It:
#   * keeps `sortedBy` in sync with the column header that shows a sort arrow:
#     that header's aria-colindex, negated for a descending arrow, or 1 when no
#     header shows an arrow;
#   * on a click inside the dataframe, waits for a TD to become aria-selected,
#     then POSTs {"action", "cellId", "sortedByCol"} to /js_callback and clicks
#     the hidden Streamlit button so the Python-side callback runs.
html_contents = """
<script defer>
const fakeButton = window.parent.document.querySelector("[data-testid^='baseButton-secondary']");
const tbl = window.parent.document.querySelector("[data-testid^='stDataFrameResizable']");
const canvas = window.parent.document.querySelector("[data-testid^='data-grid-canvas']");
let sortedBy = 1;
// The single pending cell observer (or null). Each click replaces the previous
// observer instead of adding another one. Otherwise, clicks that never lead to
// a TD being selected leave observers behind, and a later selection change
// fires all of them: duplicate payloads and repeated hidden-button clicks.
let cellObserver = null;

function sendPayload(obj) {
    const payloadStr = JSON.stringify(obj);
    window.sessionStorage.setItem("payload", payloadStr);
    fetch('/js_callback', {
        method: 'POST',
        body: payloadStr,
        headers: {
            'Content-Type': 'application/json'
        }
    })
    // The response status is deliberately not checked: the Python handler
    // captures the body in a hook that runs before request dispatch, so the
    // response itself may well carry an error status.
    .then(response => {
        fakeButton.click();
    })
    .catch(error => {
        console.error("js_callback round-trip failed:", error);
    });
}

function updateColumnValue() {
    const headers = canvas.querySelectorAll('th[role="columnheader"]');
    let arrowFound = false;

    headers.forEach(header => {
        const textContent = header.textContent.trim();
        const colIndex = parseInt(header.getAttribute('aria-colindex'), 10);

        if (textContent.startsWith('↑')) {
            sortedBy = colIndex;
            arrowFound = true;
        } else if (textContent.startsWith('↓')) {
            sortedBy = -colIndex;
            arrowFound = true;
        }
    });
    if (!arrowFound) {
        sortedBy = 1;
    }
    console.log(`Sorting column is now: ${sortedBy}`);
}

const sortObserver = new MutationObserver((mutations) => {
    mutations.forEach((mutation) => {
        if (mutation.type === 'characterData' || mutation.type === 'childList') {
            updateColumnValue();
        }
    });
});

// Observe changes in the canvas element and its subtree
sortObserver.observe(canvas, {
    characterData: true,
    childList: true,
    subtree: true
});

function handleTableClick(event) {
    // Drop any observer left over from an earlier click that never selected a cell.
    if (cellObserver !== null) {
        cellObserver.disconnect();
    }

    // MutationObserver callback: report the first TD that becomes aria-selected.
    const cellObserverCallback = (mutationsList, observer) => {
        for (const mutation of mutationsList) {
            if (mutation.type === 'attributes' && mutation.attributeName === 'aria-selected') {
                const target = mutation.target;
                if (target.tagName === 'TD' && target.getAttribute('aria-selected') === 'true') {
                    const cellCoords = target.id.replace('glide-cell-','').replace('-',',');
                    console.log(`Detected click on cell {${cellCoords}}, sorted by column "${sortedBy}"`);
                    observer.disconnect(); // Stop observing once the element is found
                    if (cellObserver === observer) {
                        cellObserver = null;
                    }
                    sendPayload({"action": "click", "cellId": cellCoords, "sortedByCol": sortedBy});
                    return; // one payload per click, even if this batch holds more matches
                }
            }
        }
    };

    // Observe changes in attributes in the subtree of the canvas element
    cellObserver = new MutationObserver(cellObserverCallback);
    cellObserver.observe(canvas, { attributes: true, subtree: true });
}

tbl.addEventListener('click', handleTableClick);
console.log("Event listeners added!");
</script>
"""
|
|
|
# Display the dataframe read-only; the injected JS locates it in the parent page.
st.data_editor(df, key="DATAFRAME", disabled=True)

# Read-only text field whose value is set through session_state["CELL_ID"].
st.text_input(
    label="N/A",
    key="CELL_ID",
    label_visibility="hidden",
    help="Click on a cell...",
    disabled=True,
)

# Create a fake button: the injected JS clicks it to run `fake_click`.
st.button("", on_click=fake_click, key="fakeButton")

# NOTE: this selector hides *every* button and iframe on the page, including
# the fake button above and the html() component iframe below.
hide_helpers_css = """
<style>
button, iframe {
visibility: hidden;
}
</style>
"""
st.markdown(hide_helpers_css, unsafe_allow_html=True)

html(html_contents)
|
|
|
|
|
class JSCallbackHandler(RequestHandler):
    """Tornado handler for ``/js_callback``: copies the JSON request body into ``payload_memory``.

    The work is done in ``set_default_headers`` rather than in a ``post``
    method. Tornado documents that hook as running at the beginning of the
    request. NOTE(review): presumably it was chosen because it runs before
    method dispatch (and before any XSRF check the app enables). The body is
    therefore captured even though the request itself is then answered with an
    error status, which the injected JS ignores. Confirm before replacing this
    with ``post``.
    """

    def set_default_headers(self):
        """Validate the request body as JSON and store it if the shared buffer is free.

        Raises:
            ValueError: the body is not valid (UTF-8) JSON, or it is larger
                than the shared buffer.
        """
        # We hijack this method to store the JS payload
        payload: bytes = self.request.body
        try:
            decoded = json.loads(payload)
        except ValueError as err:  # JSONDecodeError, or UnicodeDecodeError for non-UTF-8 bytes
            raise ValueError("Invalid JSON payload!") from err
        print(f"{os.getpid()}::{threading.get_ident()}: Python received payload: {decoded}")

        # Reject oversized bodies explicitly instead of failing on the
        # memoryview slice assignment below with a size-mismatch error.
        capacity = len(payload_memory.buf)
        if len(payload) > capacity:
            raise ValueError(
                f"Payload of {len(payload)} bytes exceeds the {capacity}-byte shared buffer!"
            )

        # Only write into an empty buffer, so a payload the Streamlit callback
        # has not consumed yet is never overwritten (the newer one is dropped).
        # NOTE(review): Tornado may invoke this hook more than once per request
        # (e.g. when resetting headers for an error response); this guard also
        # keeps such a repeat call from writing again.
        if payload_memory.buf[0] == 0:
            payload_memory.buf[:len(payload)] = payload
            print(f"{os.getpid()}::{threading.get_ident()}: Payload {payload} stored in shared memory")
|
|
|
|
|
class CustomServer(Server):
    """Streamlit ``Server`` that additionally routes ``/js_callback`` to ``JSCallbackHandler``."""

    async def start(self):
        """Build the Tornado app with the extra route, start listening, then start the runtime."""
        # Override the start of the Tornado server, so we can add custom handlers
        app = self._create_app()

        # Register the route through Tornado's public `Application.add_handlers`
        # instead of reversing `default_router.rules`. Reversal only put the new
        # rule first while the router held exactly one pre-existing rule, and it
        # reordered every other rule too. Per the Tornado docs, host patterns are
        # processed in the order added and all matching patterns are considered,
        # so other paths still fall through to Streamlit's routes.
        # NOTE(review): relies on add_handlers placing this rule ahead of the
        # app's constructor routes; re-verify on Tornado upgrades.
        app.add_handlers(r".*", [
            (r"/js_callback", JSCallbackHandler),
        ])

        start_listening(app)
        await self._runtime.start()
|
|
|
|
|
if __name__ == '__main__':
    # Boot Streamlit programmatically so that CustomServer (with the extra
    # /js_callback route) is used instead of the stock server.
    # See: https://bartbroere.eu/2024/03/27/adding-custom-tornado-handlers-to-streamlit/
    import streamlit.web.bootstrap

    # Only start a server when `__streamlitmagic__` is absent from the module
    # namespace. NOTE(review): presumably that name appears when Streamlit
    # itself executes this file as a script run (via its "magic" rewrite);
    # confirm the guard still holds if magic is disabled.
    if '__streamlitmagic__' not in locals():
        bootstrap = streamlit.web.bootstrap

        # Environment fix-ups, adapted from bootstrap.py in streamlit (private API).
        bootstrap._fix_sys_path(__file__)
        bootstrap._fix_tornado_crash()
        bootstrap._fix_sys_argv(__file__, [])
        bootstrap._fix_pydeck_mapbox_api_warning()
        bootstrap._fix_pydantic_duplicate_validators_error()
        # bootstrap._install_pages_watcher(__file__)  # Uncomment if on Streamlit < v1.36.0

        server = CustomServer(__file__, is_hello=False)

        async def _serve_until_stopped():
            """Start the server, run Streamlit's start/signal hooks, then wait for shutdown."""
            await server.start()
            bootstrap._on_server_start(server)
            bootstrap._set_up_signal_handler(server)
            await server.stopped

        asyncio.run(_serve_until_stopped())
@enoch-nkm I made a better implementation and submitted a PR. It seems the company isn't convinced that this feature is necessary. If more people upvote the related issue, that might help change their minds.
Regardless, if the PR isn't merged in the near future, I might package my own version of Streamlit with the change included.