|
import asyncio |
|
import json |
|
import os |
|
import threading |
|
from dataclasses import dataclass |
|
from itertools import chain |
|
from multiprocessing.shared_memory import SharedMemory |
|
from typing import Any |
|
|
|
import pandas as pd |
|
import streamlit as st |
|
from streamlit import session_state as ss |
|
from streamlit.components.v1 import html |
|
from streamlit.web.server import Server |
|
from streamlit.web.server.server import start_listening |
|
from tornado.web import RequestHandler |
|
|
|
|
|
_JS_TO_PD_COL_OFFSET: int = -2 |
|
|
|
# Create shared memory for the payload |
|
try: |
|
payload_memory = SharedMemory(name="JS_PAYLOAD", create=True, size=128) |
|
except FileExistsError: |
|
payload_memory = SharedMemory(name="JS_PAYLOAD", create=False, size=128) |
|
|
|
|
|
@dataclass
class Selection:
    """A grid cell picked in the browser, together with the active sort column.

    Instances are built positionally (``Selection(col, row, sorted_by)``), so
    the field order below is part of the contract.
    """

    # Column position of the clicked cell.
    col: int
    # Row position of the clicked cell in the grid as currently displayed
    # (i.e. after any browser-side sorting).
    row: int
    # Signed aria-colindex of the sort column: positive means ascending,
    # negative means descending, and 1 means "no sort arrow" / index order.
    sorted_by: int
|
|
|
|
|
def sort_df_by_selected_col(table: pd.DataFrame, js_sorted_by: int) -> pd.DataFrame:
    """Reproduce the browser grid's row order for *table*.

    ``js_sorted_by`` is the signed ``aria-colindex`` of the header showing the
    sort arrow, as tracked by the injected JavaScript.
    Its magnitude selects the column: 1 is the index column, 2 the first data
    column, and so on. Its sign selects the direction: positive is ascending,
    negative is descending. ``1`` is also what the JavaScript reports when no
    arrow is shown, so the table is returned in its original order.

    Raises:
        ValueError: if ``js_sorted_by`` is 0, which names no column.
        IndexError: if ``abs(js_sorted_by)`` points past the last column.
    """
    if js_sorted_by == 0:
        raise ValueError("js_sorted_by must be non-zero (aria-colindex is 1-based)")
    if js_sorted_by == 1:
        return table
    if js_sorted_by == -1:
        return table.sort_index(axis=0, ascending=False)

    sorting_col: str = table.columns[abs(js_sorted_by) + _JS_TO_PD_COL_OFFSET]
    # Use a stable sort so rows with equal values keep their original relative
    # order. That gives the positional row index sent by the browser a
    # deterministic meaning. The default quicksort makes no promise about how
    # ties are ordered.
    return table.sort_values(by=sorting_col, ascending=js_sorted_by > 0, kind="stable")
|
|
|
|
|
def _retrieve_payload() -> Selection:
    """Consume the click payload waiting in shared memory, if there is one.

    The payload is the JSON object posted by the injected JavaScript. It has a
    ``cellId`` key, a ``"col,row"`` string taken from the grid cell's DOM id.
    It also has ``sortedByCol``, the signed sort column (see
    ``sort_df_by_selected_col``). The segment is zeroed as soon as it is read,
    so the same click is never reported twice.

    Returns:
        A ``Selection`` whose ``col`` is shifted down by one relative to the
        DOM id, or ``Selection(-1, -1, -1)`` when no payload is pending.

    Raises:
        ValueError: if the stored bytes are not valid UTF-8 JSON
            (``UnicodeDecodeError`` / ``json.JSONDecodeError``), or if
            ``cellId`` does not hold exactly two integers.
        AttributeError / TypeError: if ``cellId`` or ``sortedByCol`` is
            missing.
    """
    buf = payload_memory.buf
    payload = {}
    if buf[0] != 0:
        # The handler writes the payload at the start of the segment; the
        # remainder is zero-filled, so the payload ends at the first NUL.
        raw = bytes(buf).split(b"\x00", 1)[0]
        # Zero the *whole* segment before parsing. A payload that fails to
        # parse must not stay behind: the handler only writes into an empty
        # segment, so a stale payload would block every later click. Clearing
        # everything also avoids counting characters instead of bytes.
        buf[:] = bytes(len(buf))
        payload = json.loads(raw.decode("utf-8"))

    if not payload:
        print(f"{os.getpid()}::{threading.get_ident()}: Streamlit callback saw no payload!")
        return Selection(-1, -1, -1)

    col, row = (int(part) for part in payload.get("cellId").split(","))
    selected_cell_info = Selection(col, row, int(payload.get("sortedByCol")))
    print(f"{os.getpid()}::{threading.get_ident()}: Streamlit callback received payload: {selected_cell_info}")
    # The DOM column number presumably counts the grid's leading index column.
    # Shift it so 0 is the first DataFrame column and -1 the index column,
    # which _interpret_payload treats as "no data column".
    return Selection(selected_cell_info.col - 1, selected_cell_info.row, selected_cell_info.sorted_by)
|
|
|
|
|
def _interpret_payload(payload: Selection) -> tuple[Any, Any]:
    """Map a browser-side ``Selection`` onto labels of the module-level ``df``.

    The grid may have been sorted in the browser. So ``df`` is first sorted the
    same way (``sort_df_by_selected_col``), and only then is ``payload.row``
    used as a position. A human-readable description is written to
    ``st.session_state["CELL_ID"]``, which backs the read-only text input.

    Returns:
        ``(row_label, col_label)``. ``col_label`` is ``None`` when the index
        column was clicked (``payload.col < 0``). For the "no payload" sentinel
        (``payload.row < 0``), ``(None, None)`` is returned and session state
        is left untouched.
    """
    if payload.row < 0:
        # _retrieve_payload returns Selection(-1, -1, -1) when nothing was
        # clicked. Indexing with -1 would silently report a real row instead.
        return None, None

    sorted_df = sort_df_by_selected_col(df, payload.sorted_by)
    selected_row = sorted_df.index[payload.row]
    has_data_col = payload.col >= 0
    selected_col = sorted_df.columns[payload.col] if has_data_col else None

    # Update a text field.
    # Test the position, not the label's truthiness: a column labelled 0 or ""
    # is a valid selection and its contents must still be shown.
    selection_str = (
        f", with contents: `{sorted_df.iat[payload.row, payload.col]}`" if has_data_col else ""
    )
    ss["CELL_ID"] = (
        f"Clicked on cell with index [{selected_row}, {selected_col}]"
        f" (at position [{payload.row}, {payload.col}])"
        f"{selection_str}."
    )
    return selected_row, selected_col
|
|
|
|
|
def fake_click(*args, **kwargs):
    """``on_click`` callback of the hidden button that the JavaScript clicks.

    Consumes the pending click payload and describes it in ``CELL_ID`` (via
    ``_interpret_payload``). It then stores the ``(row_label, col_label)`` pair
    in ``st.session_state["SELECTION"]``, so the rest of the script can use it
    on the rerun triggered by the button, e.g.::

        selection = ss.get("SELECTION")  # key does not exist before the first click
        if selection is not None:
            row_label, col_label = selection
            ...

    ``*args`` and ``**kwargs`` are accepted and ignored, so the callback works
    whatever ``args``/``kwargs`` the button is configured with.
    """
    parsed_payload: Selection = _retrieve_payload()
    selected_row, selected_col = _interpret_payload(parsed_payload)

    # Selection(-1, -1, -1) from _retrieve_payload means no click was pending.
    # In that case keep whatever an earlier click stored.
    if parsed_payload.row >= 0:
        ss["SELECTION"] = (selected_row, selected_col)
|
|
|
|
|
# Sample table shown in the grid; _interpret_payload maps clicks back onto it.
df = pd.DataFrame(
    [
        [1, 4, 7],
        [2, 5, 8],
        [3, 6, 9],
    ],
    columns=["A", "B", "C"],
)
|
|
|
# JavaScript injected through components.html(). It runs inside the component's
# iframe and reaches into the parent page to:
# - track which column the grid is sorted by,
# - detect the cell selected by a click,
# - POST both to /js_callback,
# - click the hidden Streamlit button so fake_click() picks the payload up.
html_contents = """
<script defer>
    const fakeButton = window.parent.document.querySelector("[data-testid^='baseButton-secondary']");
    const tbl = window.parent.document.querySelector("[data-testid^='stDataFrameResizable']");
    const canvas = window.parent.document.querySelector("[data-testid^='data-grid-canvas']");
    let sortedBy = 1

    function sendPayload(obj) {
        const payloadStr = JSON.stringify(obj);
        window.sessionStorage.setItem("payload", payloadStr);
        fetch('/js_callback', {
            method: 'POST',
            body: payloadStr,
            headers: {
                'Content-Type': 'application/json'
            }
        })
        .then(response => {
            // Do not gate this on response.ok: JSCallbackHandler defines no
            // post(), so the status is presumably an error even though the
            // payload has already been stored.
            fakeButton.click();
        });
    }

    function updateColumnValue() {
        const headers = canvas.querySelectorAll('th[role="columnheader"]');
        let arrowFound = false;

        headers.forEach(header => {
            const textContent = header.textContent.trim();
            const colIndex = parseInt(header.getAttribute('aria-colindex'), 10);

            if (textContent.startsWith('↑')) {
                sortedBy = colIndex;
                arrowFound = true;
            } else if (textContent.startsWith('↓')) {
                sortedBy = -colIndex;
                arrowFound = true;
            }
        });
        if (!arrowFound) {
            sortedBy = 1;
        }
        console.log(`Sorting column is now: ${sortedBy}`);
    }

    const sortObserver = new MutationObserver((mutations) => {
        mutations.forEach((mutation) => {
            if (mutation.type === 'characterData' || mutation.type === 'childList') {
                updateColumnValue();
            }
        });
    });

    // Observe changes in the canvas element and its subtree
    sortObserver.observe(canvas, {
        characterData: true,
        childList: true,
        subtree: true
    });

    // Observer waiting for the cell selected by the latest click. Only one may
    // exist at a time. A click that does not change the selection (e.g. on the
    // already-selected cell) never triggers its observer. If that observer
    // were kept, it would fire on a later selection alongside the newer one
    // and send a duplicate payload.
    let cellObserver = null;

    function handleTableClick(event) {
        if (cellObserver !== null) {
            cellObserver.disconnect();
        }

        cellObserver = new MutationObserver((mutationsList, observer) => {
            for (const mutation of mutationsList) {
                if (mutation.type === 'attributes' && mutation.attributeName === 'aria-selected') {
                    const target = mutation.target;
                    if (target.tagName === 'TD' && target.getAttribute('aria-selected') === 'true') {
                        const cellCoords = target.id.replace('glide-cell-','').replace('-',',');
                        console.log(`Detected click on cell {${cellCoords}}, sorted by column "${sortedBy}"`);
                        observer.disconnect(); // Stop observing once the element is found
                        if (cellObserver === observer) {
                            cellObserver = null;
                        }
                        sendPayload({"action": "click", "cellId": cellCoords, "sortedByCol": sortedBy});
                        return; // Report at most one cell per click
                    }
                }
            }
        });

        // Observe changes in attributes in the subtree of the canvas element
        cellObserver.observe(canvas, { attributes: true, subtree: true });
    }

    tbl.addEventListener('click', handleTableClick)
    console.log("Event listeners added!");
</script>
"""
|
|
|
# Read-only grid; clicks on it are picked up by the injected JavaScript.
st.data_editor(df, key="DATAFRAME", disabled=True)

# Disabled input with a hidden label. It only displays the description that
# _interpret_payload writes into session_state["CELL_ID"].
st.text_input(
    label="N/A",
    key="CELL_ID",
    disabled=True,
    label_visibility="hidden",
    help="Click on a cell...",
)

# Label-less helper button; the JavaScript clicks it to run fake_click().
st.button("", on_click=fake_click, key="fakeButton")

# NOTE(review): this selector hides *every* button and iframe on the page, not
# only the helper button and the component iframe. Scope it if the app grows.
_HIDE_HELPERS_CSS = """
<style>
button, iframe {
visibility: hidden;
}
</style>
"""
st.markdown(_HIDE_HELPERS_CSS, unsafe_allow_html=True)

html(html_contents)
|
|
|
|
|
class JSCallbackHandler(RequestHandler):
    """Tornado handler for ``/js_callback``.

    It copies the click payload POSTed by the injected JavaScript into the
    ``payload_memory`` shared-memory segment.

    The work is done in ``set_default_headers`` rather than in ``post``.
    Tornado calls that hook while constructing the handler, and again from
    ``clear()`` when it sends an error page. That happens before any
    per-method processing, which is presumably why it was chosen (TODO
    confirm, e.g. an XSRF check on POST). Because no ``post`` is defined, the
    HTTP response itself is an error status; the JavaScript ignores it.
    """

    def set_default_headers(self):
        # We hijack this method to store the JS payload.
        # The hook may run more than once per request; handle the body once.
        if getattr(self, "_payload_handled", False):
            return
        self._payload_handled = True

        # Only the JavaScript's POST carries a payload. Any other method (e.g. a
        # GET probe) has an empty body, which would fail JSON parsing inside
        # the handler's constructor.
        if self.request.method != "POST":
            return

        payload: bytes = self.request.body
        try:
            decoded = json.loads(payload)
        except ValueError as err:  # JSONDecodeError, or UnicodeDecodeError for non-UTF bytes
            raise ValueError("Invalid JSON payload!") from err
        print(f"{os.getpid()}::{threading.get_ident()}: Python received payload: {decoded}")

        buf = payload_memory.buf
        if len(payload) > len(buf):
            # Assigning more bytes than the segment holds would raise from the
            # memoryview assignment; drop this click instead.
            print(
                f"{os.getpid()}::{threading.get_ident()}: Payload of {len(payload)} bytes "
                f"does not fit in {len(buf)} bytes of shared memory; dropped"
            )
            return

        # A non-zero first byte means the previous payload has not been
        # consumed yet; never overwrite it.
        if buf[0] == 0:
            buf[:len(payload)] = payload
            print(f"{os.getpid()}::{threading.get_ident()}: Payload {payload} stored in shared memory")
|
|
|
|
|
class CustomServer(Server):
    """Streamlit ``Server`` that additionally routes ``/js_callback``.

    ``start`` presumably mirrors ``Server.start`` of the installed Streamlit:
    create the Tornado app, start listening, start the runtime. TODO re-check
    on upgrades. It registers one extra route before listening.
    """

    async def start(self):
        # Override the start of the Tornado server, so we can add custom handlers
        app = self._create_app()
        router = app.default_router

        # add_rules() appends, but the existing rule matches everything, so an
        # appended rule would never be reached. Rotate just the new rule(s) to
        # the front and keep every pre-existing rule in its original order. A
        # full reversal would also reorder any other existing rules.
        n_existing = len(router.rules)
        router.add_rules([
            (r"/js_callback", JSCallbackHandler),
        ])
        router.rules = router.rules[n_existing:] + router.rules[:n_existing]

        start_listening(app)
        await self._runtime.start()
|
|
|
|
|
if __name__ == "__main__":
    # Launch the app with CustomServer via `python <this file>`.
    # See: https://bartbroere.eu/2024/03/27/adding-custom-tornado-handlers-to-streamlit/
    from streamlit.web import bootstrap

    # `__streamlitmagic__` is presumably only defined when Streamlit itself is
    # executing this file as the app script. In that case the server must not
    # be started again. At module level, locals() is the module namespace.
    if "__streamlitmagic__" not in locals():
        # Same preparation steps as streamlit's bootstrap.py, in the same order.
        bootstrap._fix_sys_path(__file__)
        bootstrap._fix_tornado_crash()
        bootstrap._fix_sys_argv(__file__, [])
        bootstrap._fix_pydeck_mapbox_api_warning()
        bootstrap._fix_pydantic_duplicate_validators_error()
        # bootstrap._install_pages_watcher(__file__)  # Uncomment if on Streamlit < v1.36.0

        server = CustomServer(__file__, is_hello=False)

        async def run_server():
            """Start the server, run the bootstrap hooks, then wait for shutdown."""
            await server.start()
            bootstrap._on_server_start(server)
            bootstrap._set_up_signal_handler(server)
            await server.stopped

        asyncio.run(run_server())
Hi, first of all, thanks for the great work!
Could you show how the cell location values can be used in the Streamlit app? I couldn't get this to work; the values only showed up in the JS console.
I also get an error saying that the PAYLOAD variable is not initialized.