|
import asyncio |
|
import json |
|
import os |
|
import threading |
|
from dataclasses import dataclass |
|
from itertools import chain |
|
from multiprocessing.shared_memory import SharedMemory |
|
from typing import Any |
|
|
|
import pandas as pd |
|
import streamlit as st |
|
from streamlit import session_state as ss |
|
from streamlit.components.v1 import html |
|
from streamlit.web.server import Server |
|
from streamlit.web.server.server import start_listening |
|
from tornado.web import RequestHandler |
|
|
|
|
|
_JS_TO_PD_COL_OFFSET: int = -2 |
|
|
|
# Create shared memory for the payload |
|
try: |
|
payload_memory = SharedMemory(name="JS_PAYLOAD", create=True, size=128) |
|
except FileExistsError: |
|
payload_memory = SharedMemory(name="JS_PAYLOAD", create=False, size=128) |
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class Selection:
    """Dataclass to store the selected cell information.

    Fields are positional in this order: ``col``, ``row``, ``sorted_by``.
    """

    # Column position of the clicked cell; once ``_retrieve_payload`` has
    # shifted it, -1 denotes the index column (or "no selection").
    col: int
    # Row position of the clicked cell in the table *as displayed* (-1: no selection).
    row: int
    # Signed, 1-based column the frontend table is sorted by (negative = descending).
    sorted_by: int
|
|
|
|
|
def sort_df_by_selected_col(table: pd.DataFrame, js_sorted_by: int) -> pd.DataFrame:
    """Return ``table`` ordered the same way as the frontend grid.

    Args:
        table: The dataframe shown in the grid.
        js_sorted_by: Signed, 1-based ``aria-colindex`` of the column the grid is
            sorted by, as sent by the injected JS. ``1`` is the index column (the
            JS also sends ``1`` when no sort arrow is shown); a negative value
            means descending order.

    Returns:
        ``table`` itself for ``1``; otherwise a sorted copy.

    Raises:
        ValueError: If ``js_sorted_by`` is ``0`` (``aria-colindex`` is 1-based, so
            ``0`` can only come from a malformed payload).
        IndexError: If ``js_sorted_by`` points past the last column.
    """
    # Reject 0 explicitly: otherwise it maps to position -2 and silently sorts by
    # the second-to-last column.
    if js_sorted_by == 0:
        raise ValueError("js_sorted_by must be a non-zero, 1-based column index")
    if js_sorted_by == 1:
        return table
    if js_sorted_by == -1:
        return table.sort_index(axis=0, ascending=False)

    col_pos = abs(js_sorted_by) + _JS_TO_PD_COL_OFFSET
    if col_pos >= len(table.columns):
        raise IndexError(
            f"js_sorted_by={js_sorted_by} refers to column position {col_pos}, "
            f"but the table only has {len(table.columns)} columns"
        )
    sorting_col = table.columns[col_pos]
    # A stable sort keeps tied rows in their original order, which is what a
    # stable frontend sort would display.
    # NOTE(review): assumes the grid's own sort is stable — confirm.
    return table.sort_values(by=sorting_col, ascending=js_sorted_by > 0, kind="stable")
|
|
|
|
|
def _retrieve_payload() -> Selection:
    """Consume the pending JS payload from shared memory and return it as a ``Selection``.

    ``JSCallbackHandler`` writes the raw JSON request body (NUL-padded) into
    ``payload_memory``. This function copies it out, zeroes the bytes it
    occupied so that the handler can store the next payload, and parses it.

    Returns:
        The parsed selection with ``col`` shifted down by one (JS cell column
        ``0`` becomes ``-1``, which ``_interpret_payload`` treats as the index
        column), or ``Selection(-1, -1, -1)`` when no payload is pending.

    Raises:
        json.JSONDecodeError, UnicodeDecodeError: If the stored bytes are not
            valid UTF-8 JSON. The buffer is cleared beforehand, so a bad payload
            does not block later ones.
    """
    buf = payload_memory.buf
    payload = {}

    if buf[0] != 0:
        # The payload ends at the first NUL byte; everything after it is padding.
        raw = bytes(buf).split(b"\x00", 1)[0]
        # Clear *before* parsing and by *byte* count (not decoded character
        # count). A malformed or non-ASCII payload must not leave stale bytes
        # behind, because the handler only writes while buf[0] == 0.
        buf[:len(raw)] = bytes(len(raw))
        payload = json.loads(raw.decode("utf-8"))

    if not payload:
        print(f"{os.getpid()}::{threading.get_ident()}: Streamlit callback saw no payload!")
        return Selection(-1, -1, -1)

    # cellId is "col,row"; append sortedByCol to get Selection's positional fields.
    selected_cell_info = Selection(*(
        int(val) for val in chain(payload.get('cellId').split(','), [payload.get('sortedByCol')])
    ))
    print(f"{os.getpid()}::{threading.get_ident()}: Streamlit callback received payload: {selected_cell_info}")
    return Selection(selected_cell_info.col - 1, selected_cell_info.row, selected_cell_info.sorted_by)
|
|
|
|
|
def _interpret_payload(payload: Selection, table: "pd.DataFrame | None" = None) -> tuple[Any, Any]:
    """Translate a frontend ``Selection`` into row/column labels and report it.

    The frontend reports positions in the table *as displayed*. The dataframe is
    therefore first sorted the same way (``sort_df_by_selected_col``) before
    positions are mapped to labels. On a real selection, ``ss["CELL_ID"]`` is
    set to a human-readable description.

    Args:
        payload: Selection from ``_retrieve_payload``. ``col < 0`` means the index
            column; ``row < 0`` means no payload was available.
        table: The dataframe shown in the UI. Defaults to the module-level ``df``.

    Returns:
        ``(row_label, col_label)``, where ``col_label`` is ``None`` for the index
        column. Returns ``(None, None)`` when there is no selection; in that case
        ``ss["CELL_ID"]`` is left untouched.
    """
    if payload.row < 0:
        # No payload: without this guard, row -1 would be reported as a click
        # on the last displayed row.
        return None, None

    if table is None:
        table = df
    sorted_df = sort_df_by_selected_col(table, payload.sorted_by)

    selected_row = sorted_df.index[payload.row]
    selected_col = sorted_df.columns[payload.col] if payload.col >= 0 else None

    # Compare with None: labels such as 0 or "" are falsy but still real columns.
    selection_str = (
        f", with contents: `{sorted_df.iat[payload.row, payload.col]}`"
        if selected_col is not None
        else ""
    )

    # Update the text field bound to the "CELL_ID" session-state key.
    ss["CELL_ID"] = (
        f"Clicked on cell with index [{selected_row}, {selected_col}]"
        f" (at position [{payload.row}, {payload.col}])"
        f"{selection_str}."
    )

    return selected_row, selected_col
|
|
|
|
|
def fake_click(*args, **kwargs):
    """``on_click`` callback of the hidden button that the injected JS clicks.

    Pulls the pending payload out of shared memory and resolves it to row and
    column labels (``_interpret_payload`` may also write a description to
    ``ss["CELL_ID"]``). Any positional/keyword arguments are accepted and ignored.
    """
    row_label, col_label = _interpret_payload(_retrieve_payload())
    # Do something with the selection (row_label, col_label) here, e.g.:
    # ss["SELECTION"] = row_label, col_label
|
|
|
|
|
# Sample dataframe shown in the grid, built row by row: columns A, B, C hold
# 1-3, 4-6 and 7-9 respectively.
df = pd.DataFrame(
    [
        [1, 4, 7],
        [2, 5, 8],
        [3, 6, 9],
    ],
    columns=["A", "B", "C"],
)
|
|
|
# JavaScript to add event listeners to dataframe cells and send data to Streamlit.
# It is rendered inside a component iframe and reaches into ``window.parent`` to:
#   * track which column the grid is sorted by (from the ↑/↓ prefix of the header text),
#   * detect the cell that becomes aria-selected after a click on the table,
#   * POST {"action", "cellId", "sortedByCol"} to /js_callback, then click the first
#     secondary-style button on the page (the hidden Streamlit button).
html_contents = """
<script defer>
const fakeButton = window.parent.document.querySelector("[data-testid^='baseButton-secondary']");
const tbl = window.parent.document.querySelector("[data-testid^='stDataFrameResizable']");
const canvas = window.parent.document.querySelector("[data-testid^='data-grid-canvas']");
let sortedBy = 1;
// Observer waiting for the cell selected by the most recent click (null when none is pending).
let pendingCellObserver = null;

function sendPayload(obj) {
    const payloadStr = JSON.stringify(obj);
    window.sessionStorage.setItem("payload", payloadStr);
    fetch('/js_callback', {
        method: 'POST',
        body: payloadStr,
        headers: {
            'Content-Type': 'application/json'
        }
    })
    .then(response => {
        // fetch resolves for any HTTP status; click regardless so the Python callback runs.
        fakeButton.click();
    });
}

function updateColumnValue() {
    const headers = canvas.querySelectorAll('th[role="columnheader"]');
    let arrowFound = false;

    headers.forEach(header => {
        const textContent = header.textContent.trim();
        const colIndex = parseInt(header.getAttribute('aria-colindex'), 10);

        if (textContent.startsWith('↑')) {
            sortedBy = colIndex;
            arrowFound = true;
        } else if (textContent.startsWith('↓')) {
            sortedBy = -colIndex;
            arrowFound = true;
        }
    });
    if (!arrowFound) {
        sortedBy = 1;
    }
    console.log(`Sorting column is now: ${sortedBy}`);
}

const sortObserver = new MutationObserver((mutations) => {
    // Re-read the headers once per batch of mutations rather than once per mutation.
    if (mutations.some(m => m.type === 'characterData' || m.type === 'childList')) {
        updateColumnValue();
    }
});

// Observe changes in the canvas element and its subtree
sortObserver.observe(canvas, {
    characterData: true,
    childList: true,
    subtree: true
});

function handleTableClick(event) {
    // A click that never selects a cell (e.g. on a header) would leave its observer
    // connected; drop it so that a later selection is not reported more than once.
    if (pendingCellObserver !== null) {
        pendingCellObserver.disconnect();
        pendingCellObserver = null;
    }

    // MutationObserver callback: wait for a <td> to become aria-selected.
    const cellObserverCallback = (mutationsList, observer) => {
        for (const mutation of mutationsList) {
            if (mutation.type === 'attributes' && mutation.attributeName === 'aria-selected') {
                const target = mutation.target;
                if (target.tagName === 'TD' && target.getAttribute('aria-selected') === 'true') {
                    const cellCoords = target.id.replace('glide-cell-','').replace('-',',');
                    console.log(`Detected click on cell {${cellCoords}}, sorted by column "${sortedBy}"`);
                    observer.disconnect(); // Stop observing once the element is found
                    if (pendingCellObserver === observer) {
                        pendingCellObserver = null;
                    }
                    sendPayload({"action": "click", "cellId": cellCoords, "sortedByCol": sortedBy});
                    return; // Report at most one cell per click
                }
            }
        }
    };

    // Observe aria-selected changes in the subtree of the canvas element
    const cellObserver = new MutationObserver(cellObserverCallback);
    cellObserver.observe(canvas, { attributes: true, attributeFilter: ['aria-selected'], subtree: true });
    pendingCellObserver = cellObserver;
}

tbl.addEventListener('click', handleTableClick);
console.log("Event listeners added!");
</script>
"""
|
|
|
# Display the dataframe (cells are not editable).
st.data_editor(df, disabled=True, key="DATAFRAME")

# Read-only text field bound to the "CELL_ID" session-state key, which the click callback updates.
st.text_input(
    label="N/A",
    label_visibility="hidden",
    key="CELL_ID",
    disabled=True,
    help="Click on a cell...",
)

# Hidden button that the injected JS clicks after posting the payload;
# its ``fake_click`` callback consumes the payload.
st.button("", key="fakeButton", on_click=fake_click)

# NOTE: this CSS hides *every* button and iframe on the page, not only the fake
# button and the component iframe.
_HIDE_BUTTONS_AND_IFRAMES_CSS = """
<style>
button, iframe {
    visibility: hidden;
}
</style>
"""
st.markdown(_HIDE_BUTTONS_AND_IFRAMES_CSS, unsafe_allow_html=True)

# Inject the JavaScript (rendered inside a component iframe).
html(html_contents)
|
|
|
|
|
class JSCallbackHandler(RequestHandler):
    """Tornado handler for ``/js_callback``: stores the JS payload in shared memory.

    The payload is captured in ``set_default_headers`` rather than in a ``post``
    method, and no ``post`` method is defined.
    NOTE(review): presumably this is so the payload is stored while the handler
    is being set up, i.e. before Tornado's method dispatch and XSRF check. The
    HTTP response itself is then presumably an error status, which the JS
    ignores — confirm.
    """

    def set_default_headers(self):
        # We hijack this method to store the JS payload.
        # Tornado may call it more than once per request (whenever the response
        # is cleared, e.g. while sending an error), so handle the payload only once.
        if getattr(self, "_payload_handled", False):
            return
        self._payload_handled = True

        # Only the JS POST carries a payload; ignore anything else hitting this route.
        if self.request.method != "POST":
            return

        payload: bytes = self.request.body
        try:
            decoded = json.loads(payload)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # Raising here would abort construction of the handler; drop the request instead.
            print(f"{os.getpid()}::{threading.get_ident()}: Ignoring invalid JSON payload: {payload!r}")
            return
        if not isinstance(decoded, dict):
            # _retrieve_payload expects a JSON object (it calls .get on it).
            print(f"{os.getpid()}::{threading.get_ident()}: Ignoring non-object JSON payload: {decoded!r}")
            return
        print(f"{os.getpid()}::{threading.get_ident()}: Python received payload: {decoded}")

        buf = payload_memory.buf
        if len(payload) > len(buf):
            # Writing would fail (memoryview slice is shorter than the payload).
            print(
                f"{os.getpid()}::{threading.get_ident()}: Payload of {len(payload)} bytes "
                f"does not fit in {len(buf)} bytes of shared memory; dropped"
            )
            return

        # Only store if the previous payload has been consumed (first byte zeroed).
        if buf[0] == 0:
            buf[:len(payload)] = payload
            print(f"{os.getpid()}::{threading.get_ident()}: Payload {payload} stored in shared memory")
|
|
|
|
|
class CustomServer(Server):
    """Streamlit ``Server`` that additionally routes ``/js_callback`` to ``JSCallbackHandler``."""

    async def start(self):
        # Override the start of the Tornado server, so we can add custom handlers.
        app = self._create_app()
        router = app.default_router

        # Add a new handler
        n_existing = len(router.rules)
        router.add_rules([
            (r"/js_callback", JSCallbackHandler),
        ])

        # Our new rules must go before the rule matching everything. Rotate the
        # newly appended rules to the front instead of reversing the whole list,
        # so the relative order of the pre-existing rules is preserved.
        router.rules = router.rules[n_existing:] + router.rules[:n_existing]

        start_listening(app)
        await self._runtime.start()
|
|
|
|
|
if __name__ == '__main__':
    # Launch Streamlit programmatically with CustomServer so that the extra
    # /js_callback route is registered.
    # See: https://bartbroere.eu/2024/03/27/adding-custom-tornado-handlers-to-streamlit/
    import streamlit.web.bootstrap

    # NOTE(review): presumably `__streamlitmagic__` is only defined when the file
    # is executed by Streamlit's script runner (which also runs it as __main__).
    # This guard would then keep each rerun from starting another server — confirm.
    if '__streamlitmagic__' not in locals():
        # Code adapted from bootstrap.py in streamlit.
        # These are private (underscore-prefixed) Streamlit helpers and may change
        # between Streamlit versions.
        streamlit.web.bootstrap._fix_sys_path(__file__)
        streamlit.web.bootstrap._fix_tornado_crash()
        streamlit.web.bootstrap._fix_sys_argv(__file__, [])
        streamlit.web.bootstrap._fix_pydeck_mapbox_api_warning()
        streamlit.web.bootstrap._fix_pydantic_duplicate_validators_error()
        # streamlit.web.bootstrap._install_pages_watcher(__file__)  # Uncomment if on Streamlit < v1.36.0

        server = CustomServer(__file__, is_hello=False)

        async def run_server():
            """Start ``server``, call the ``_on_server_start`` / ``_set_up_signal_handler``
            bootstrap helpers, then await ``server.stopped``."""
            await server.start()
            streamlit.web.bootstrap._on_server_start(server)
            streamlit.web.bootstrap._set_up_signal_handler(server)
            await server.stopped

        asyncio.run(run_server())
Hi @enoch-nkm, could you please clarify: did you use `working_workaround.py` or something else? In one of the animations I demonstrate how the variable is received and printed in the Python console. Also, which OS and Python version are you using? The variable can only be uninitialized if shared memory isn't working correctly. The thread that receives the payload from the JS (the web-server thread) is usually different from the thread where the click is processed in Python (the script thread), so the payload has to be handed over explicitly, and that only works if it is set up correctly.
Other than that, I have some small improvements to the script, related to using the right cell when the frontend table is sorted, which I plan to upload eventually.