Skip to content

Instantly share code, notes, and snippets.

@edoakes
Created September 21, 2021 20:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save edoakes/af18a782cd607e3f5b4caa4a524a15ea to your computer and use it in GitHub Desktop.
Save edoakes/af18a782cd607e3f5b4caa4a524a15ea to your computer and use it in GitHub Desktop.
Ray Serve plotly wrapper (working but hacky)
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import pandas as pd
import plotly.graph_objs as obj
import uvicorn as uvicorn
from fastapi import FastAPI
from starlette.middleware.wsgi import WSGIMiddleware
import ray
from ray import serve
def build_plotly_app():
    """Build the Dash app that plots synthetic high/low temperatures by year.

    The app is configured with ``requests_pathname_prefix="/dash/"``, so the
    browser requests its pages, assets and callbacks under ``/dash/``; it has
    to be reachable at that path for the front end to work.

    Returns:
        The configured ``dash.Dash`` application. Its WSGI server is
        available as ``app.server``.
    """
    app = dash.Dash(__name__, requests_pathname_prefix="/dash/")

    first_year, last_year, mark_step = 1940, 2020, 5

    # Synthetic data: the "high" is year / 20 and the "low" is 20 below it.
    years = list(range(first_year, last_year + 1))
    temp_high = [year / 20 for year in years]
    temp_low = [high - 20 for high in temp_high]
    df = pd.DataFrame({"Year": years, "TempHigh": temp_high, "TempLow": temp_low})

    slider = dcc.RangeSlider(
        id="slider",
        value=[first_year, last_year],
        min=first_year,
        max=last_year,
        step=mark_step,
        # One labelled tick per step, derived from the same bounds as the
        # slider instead of hand-written, so labels stay in sync with it.
        marks={
            year: str(year)
            for year in range(first_year, last_year + 1, mark_step)
        },
    )

    app.layout = html.Div(
        children=[
            html.H1(children="Data Visualization with Dash"),
            html.Div(children="High/Low Temperatures Over Time"),
            dcc.Graph(id="temp-plot"),
            slider,
        ]
    )

    @app.callback(Output("temp-plot", "figure"), [Input("slider", "value")])
    def add_graph(year_range):
        """Redraw the scatter plot with the x-axis clipped to the slider range.

        Args:
            year_range: The ``[low, high]`` pair currently selected on the
                range slider.

        Returns:
            A ``plotly.graph_objs.Figure`` with the high and low series.
        """
        trace_high = obj.Scatter(
            x=df["Year"], y=df["TempHigh"], mode="markers", name="High Temperatures"
        )
        trace_low = obj.Scatter(
            x=df["Year"], y=df["TempLow"], mode="markers", name="Low Temperatures"
        )
        layout = obj.Layout(
            xaxis={"range": [year_range[0], year_range[1]]},
            yaxis={"title": "Temperature"},
        )
        return obj.Figure(data=[trace_high, trace_low], layout=layout)

    return app
def lazy_middleware(*args, **kwargs):
    """WSGI entry point that forwards each request to the replica's Dash app.

    HACK: the Dash app is only created inside the Serve deployment's
    constructor, as the ``app`` attribute of the deployment object. So this
    proxy cannot hold a reference to it up front. Instead, on every call it
    looks up the deployment object running in the current replica via
    ``serve.get_replica_context()`` and hands the WSGI arguments unchanged
    to that object's ``app.server``.

    Returns:
        Whatever ``app.server`` returns for this request.
    """
    replica_context = serve.get_replica_context()
    deployment_instance = replica_context.servable_object
    wsgi_server = deployment_instance.app.server
    return wsgi_server(*args, **kwargs)
if __name__ == "__main__":
    fastapi_app = FastAPI()
    # /dash is served through the lazily-resolved WSGI proxy: the Dash app
    # itself is only built inside the Serve replica, not in this process.
    dash_wsgi = WSGIMiddleware(lazy_middleware)
    fastapi_app.mount("/dash", dash_wsgi)

    # Connect to an existing Ray cluster and start Serve in detached mode.
    ray.init(address="auto", namespace="serve")
    serve.start(detached=True)

    # Decorator order matters: the FastAPI ingress wraps the class first,
    # then the result is turned into a Serve deployment mounted at "/".
    @serve.deployment(route_prefix="/")
    @serve.ingress(fastapi_app)
    class MyServeWrapper:
        """Serve deployment hosting the Dash app behind the FastAPI ingress."""

        def __init__(self):
            # Built here, inside the replica, rather than up front: the Dash
            # app is not serializable with cloudpickle (it holds weakrefs),
            # so it cannot be shipped along with the class definition.
            self.app = build_plotly_app()

    MyServeWrapper.deploy()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment