Skip to content

Instantly share code, notes, and snippets.

@ntjess
Created October 7, 2024 21:35
Show Gist options
  • Save ntjess/b6ab3725a73aaef79ff3abd5a5e41887 to your computer and use it in GitHub Desktop.
Save ntjess/b6ab3725a73aaef79ff3abd5a5e41887 to your computer and use it in GitHub Desktop.
import solara as sl
import typing as t
import numpy as np
from plotly import graph_objects as go
from typing_extensions import TypedDict as TypedDictExt
# An (x, y) coordinate pair in the plot's data space (e.g. a hovered point).
Point = t.Tuple[float, float]
class PlotlyPointData(TypedDictExt):
    """Per-point section of a Plotly event payload received by an event callback.

    NOTE(review): the four lists are presumably parallel (entry ``i`` of each
    describes the same point) — confirm against solara's FigurePlotly payload.
    """

    # Index of the trace each reported point belongs to.
    trace_indexes: list[int]
    # Presumably the index of each point within its trace — confirm.
    point_indexes: list[int]
    # Data-space coordinates. These are floats (the plotted traces are float
    # samples), so they were previously mis-annotated as ``list[int]``.
    xs: list[float]
    ys: list[float]
# Modifier-key and mouse-button state reported alongside a Plotly event.
# NOTE(review): ``button``/``buttons`` look like the DOM MouseEvent fields of
# the same name — confirm against solara's FigurePlotly payload.
PlotlyDeviceState = TypedDictExt(
    "PlotlyDeviceState",
    {
        "alt": bool,
        "ctrl": bool,
        "meta": bool,
        "shift": bool,
        "button": int,
        "buttons": int,
    },
)
# Top-level payload handed to a ``sl.FigurePlotly`` event callback (this file
# wires it to ``on_hover``): the event kind, the affected points, and the
# keyboard/mouse state at the time of the event.
PlotlyEventData = TypedDictExt(
    "PlotlyEventData",
    {
        "event_type": str,
        "points": PlotlyPointData,
        "device_state": PlotlyDeviceState,
        # NOTE(review): presumably only populated for selection events — confirm.
        "selector": t.Optional[t.Any],
    },
)
def circle_trace(radius: float = 5.0, n_points=1000):
    """Build a Plotly line trace outlining a circle centred on the origin.

    Args:
        radius: Circle radius in data units.
        n_points: Number of samples along the circumference. ``np.linspace``
            includes both endpoints, so the first (angle 0) and last
            (angle 2*pi) samples coincide and the outline is closed.

    Returns:
        A ``go.Scatter`` drawn in ``"lines"`` mode.
    """
    angles = np.linspace(0, 2 * np.pi, n_points)
    return go.Scatter(
        x=radius * np.cos(angles),
        y=radius * np.sin(angles),
        mode="lines",
    )
def line_trace(tangent_xy: Point | None, radius: float = 5.0, n_points=1000):
    """Build the red tangent line touching an origin-centred circle at ``tangent_xy``.

    The tangent at (x1, y1) is perpendicular to the radius vector through that
    point, i.e. the line ``x1*x + y1*y = x1**2 + y1**2``.

    The line is drawn with ``n_points`` samples spanning ``[-radius, radius]``:
    - When it is at most 45 degrees steep (``|y1| >= |x1|``), x is sampled and
      y is solved for.
    - Otherwise y is sampled and x is solved for.
    The divisor is therefore always the larger of the two coefficients, and a
    vertical tangent (``y1 == 0``) is drawn exactly instead of via an epsilon.

    Args:
        tangent_xy: Point of tangency, or None when nothing is hovered yet.
        radius: Half-width of the sampled interval along the sampling axis.
        n_points: Number of samples along the line.

    Returns:
        A red ``go.Scatter`` line, or None when ``tangent_xy`` is None or is
        the origin (where no tangent is defined).
    """
    if tangent_xy is None:
        return None
    x1, y1 = tangent_xy
    if x1 == 0 and y1 == 0:
        # The centre has no radius vector, so there is no tangent direction.
        return None
    rhs = x1 * x1 + y1 * y1
    samples = np.linspace(-radius, radius, n_points)
    if abs(x1) > abs(y1):
        # Steeper than 45 degrees (including vertical): sample y, solve for x.
        y = samples
        x = (rhs - y1 * y) / x1
    else:
        # At most 45 degrees steep (horizontal when x1 == 0): sample x, solve for y.
        x = samples
        y = (rhs - x1 * x) / y1
    return go.Scatter(x=x, y=y, mode="lines", line=dict(color="red"))
@sl.component
def InterceptPlot(radius=5.0, n_points=1000):
    """Interactive plot of a circle and its tangent line at the hovered point.

    Hovering a point on the circle stores it in reactive state. The re-render
    then draws the tangent line through that point.

    Args:
        radius: Radius of the origin-centred circle; the axes span 1.5x this.
        n_points: Number of samples used for the circle and the tangent line.
    """
    # The circle is always the first trace added below.
    circle_trace_index = 0

    # Last hovered point on the circle; None until the first hover.
    px_data = sl.use_reactive(t.cast(Point | None, None))

    def on_hover(data: PlotlyEventData):
        points = data["points"]
        # Only accept points on the circle. Hovering the tangent line itself
        # (trace 1) would otherwise move the tangent point off the circle, and
        # the redrawn line would no longer touch it.
        for trace_index, x, y in zip(points["trace_indexes"], points["xs"], points["ys"]):
            if trace_index == circle_trace_index:
                px_data.set((x, y))
                return

    fig = go.Figure()
    fig.add_trace(circle_trace(radius, n_points))
    line = line_trace(px_data.value, radius, n_points)
    if line is not None:
        fig.add_trace(line)

    # Fully configure the figure before handing it to the widget. Equal axis
    # scaling keeps the circle round and the tangent visibly perpendicular.
    bound = radius * 1.5
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    fig.update_layout(
        width=500,
        height=500,
        yaxis_range=[-bound, bound],
        xaxis_range=[-bound, bound],
    )
    sl.FigurePlotly(fig, on_hover=on_hover)
@sl.component
def Page():
    """Page component rendering InterceptPlot with its default radius and sample count."""
    InterceptPlot()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment