Skip to content

Instantly share code, notes, and snippets.

@cpsievert
Created July 15, 2024 15:07
Show Gist options
  • Save cpsievert/f79e4816b43659aed436284a166ba479 to your computer and use it in GitHub Desktop.
Save cpsievert/f79e4816b43659aed436284a166ba479 to your computer and use it in GitHub Desktop.
Filter a DataGrid using a Plotly scatter plot
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from shinywidgets import render_plotly
from shiny import reactive
from shiny.express import render
# Sample data: num_rows labelled points on a descending diagonal.
num_rows = 10
# Derive both coordinate columns from num_rows so every column always has
# num_rows entries (hard-coded lists would break the DataFrame constructor
# as soon as num_rows changed). X counts up 1..num_rows, Y counts down.
x_points = list(range(1, num_rows + 1))
y_points = list(range(num_rows, 0, -1))
row_labels = [f"label_{i + 1}" for i in range(num_rows)]
df = pd.DataFrame({"Row_Label": row_labels, "X": x_points, "Y": y_points})
@render_plotly
def scatterplot():
    """Render a clickable scatter plot of ``df``.

    Each point's ``Row_Label`` is shown on hover. Clicking a point invokes
    ``set_selected_row``, which updates the row shown by ``table``.

    Returns:
        A ``go.FigureWidget`` wrapping the Plotly Express figure. A widget is
        used rather than a plain figure so that Python-side click callbacks
        can be attached.
    """
    figure = px.scatter(
        df,
        x="X",
        y="Y",
        title="Scatter Plot",
        hover_data="Row_Label",
        custom_data="Row_Label",
    )
    widget = go.FigureWidget(figure)
    # Attach the click handler to the first trace of the widget.
    first_trace = widget.data[0]
    first_trace.on_click(set_selected_row)
    return widget
# Positional index (into df) of the row picked by clicking a scatter point.
# None means "no selection", in which case the table shows every row.
# The object is a reactive.Value container, so annotate it as such rather
# than as the int | None it holds.
selected_row_index: reactive.Value[int | None] = reactive.Value(None)
@render.data_frame
def table():
    """Show all of ``df``, or only the row selected in the scatter plot.

    Reads ``selected_row_index``. When it holds ``None`` the full frame is
    returned; otherwise only the row at that position is returned.
    """
    picked = selected_row_index.get()
    # A one-element list keeps the iloc result a DataFrame; a bare int would
    # return a Series instead.
    return df if picked is None else df.iloc[[picked]]
def set_selected_row(trace, points, selector):
    """Record the clicked scatter point as the selected table row.

    Registered as the Plotly ``on_click`` callback in ``scatterplot``.

    Args:
        trace: The Plotly trace the event was dispatched to (unused).
        points: Plotly ``Points`` object; ``points.point_inds`` lists the
            indices of the clicked points within ``trace``.
        selector: Input-device state for the click (unused).
    """
    # Plotly dispatches click events to every trace of a FigureWidget, so
    # point_inds can be empty. Ignore such events instead of raising IndexError.
    if not points.point_inds:
        return
    selected_row_index.set(points.point_inds[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment