Skip to content

Instantly share code, notes, and snippets.

@th-yoo
Created July 9, 2023 10:49
Show Gist options
  • Save th-yoo/603e8e85c5ae58a3d2f7127e33bedd32 to your computer and use it in GitHub Desktop.
Save th-yoo/603e8e85c5ae58a3d2f7127e33bedd32 to your computer and use it in GitHub Desktop.
Bokeh with FastAPI
from __future__ import annotations
from pprint import pprint
from var_dump import var_dump
from fastapi import FastAPI, WebSocket, Request, HTTPException, WebSocketDisconnect
from fastapi.responses import HTMLResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
# ASGI application object; it is served by uvicorn.run() in the __main__ guard.
app = FastAPI()
#app.mount('/static', StaticFiles(directory='bokeh/server/static'), name='static')
# Serve static assets (presumably the BokehJS build loaded by the generated
# pages — confirm) under /static.
# NOTE(review): the relative directory is resolved against the process's
# current working directory, not this file's location; confirm the server is
# always launched from the expected directory.
app.mount('/static', StaticFiles(directory='../bokeh/bokehjs/build'), name='static')
from bokeh.settings import settings
from bokeh.core.types import ID
from bokeh.util.token import (
check_token_signature,
generate_jwt_token,
generate_session_id,
get_session_id,
get_token_payload
)
from bokeh.application.application import Application, SessionContext
from bokeh.document import Document
import weakref
class BokashiSessionContext(SessionContext):
    """Session context for serving a Bokeh document from FastAPI.

    Holds the document, the session token and an optional logout URL. No
    server context is available here, so ``None`` is handed to the base
    class in its place.
    """

    _token: str | None

    def __init__(
        self,
        session_id: ID,
        document: Document,
        logout_url: str | None = None,
    ) -> None:
        self._doc = document
        self._logout_url = logout_url
        # No ServerContext exists outside the Bokeh server; pass None instead.
        super().__init__(None, session_id)
        self._request = None
        self._token = None

    def _set_session(self, session: ServerSession) -> None:
        """Accept and ignore the session (no server session is tracked)."""
        return None

    @property
    def destroyed(self) -> bool:
        """Always False for now.

        TODO: report real state once server-session tracking exists.
        """
        return False

    async def with_locked_document(
        self, func: Callable[[Document], Awaitable[None]]
    ) -> None:
        """Await ``func`` on this context's document.

        NOTE(review): despite the name, no lock is taken here.
        """
        await func(self._doc)
class HTTPError(Exception):
    """Framework-neutral HTTP error raised by the session helpers.

    Attributes:
        status: HTTP status code to report.
        message: human-readable reason (empty string by default).
    """

    def __init__(self, status: int, message: str = ''):
        self.status, self.message = status, message
# all the keys should be lower cased.
def get_session(headers: dict[str, str], cookies: dict[str, str], qs: dict[str, str]) -> SessionContext:
    """Resolve or create the Bokeh session for an incoming HTTP request.

    The session is identified by the first source present. The sources are
    a ``bokeh-token`` query argument, a ``bokeh-session-id`` query argument
    or header, or a freshly generated session id. When no token was
    supplied, one is minted whose extra payload carries the request headers,
    cookies and query arguments.

    Args:
        headers: request headers; keys must be lower-cased. Not modified.
        cookies: request cookies.
        qs: query-string arguments.

    Returns:
        A new ``BokashiSessionContext`` wrapping a fresh ``Document``, with
        ``_token`` set. The document holds only a weak reference to the
        context, so the caller must keep the context alive.

    Raises:
        HTTPError(403): if the session id is given both as an argument and
            a header, if both a token and a session id are given, or if the
            token fails the signature check.
    """
    token = qs.get('bokeh-token', None)
    session_id: ID | None = qs.get('bokeh-session-id', None)

    if 'bokeh-session-id' in headers:
        if session_id:
            raise HTTPError(403, 'session ID was provided as an argument and header')
        session_id = headers.get('bokeh-session-id')

    if token:
        if session_id:
            raise HTTPError(403, 'Both token and session ID were provided')
        session_id = get_session_id(token)
    elif not session_id:
        session_id = generate_session_id(
            settings.secret_key_bytes(),
            settings.sign_sessions()
        )

    if not token:
        # Copy so the caller's mapping is left untouched.
        headers = dict(headers)
        if cookies and 'cookie' in headers:
            # Cookies are carried separately in the payload; drop the raw header.
            del headers['cookie']
        payload = {'headers': headers, 'cookies': cookies, 'arguments': qs}
        token = generate_jwt_token(
            session_id,
            secret_key=settings.secret_key_bytes(),
            signed=settings.sign_sessions(),
            expiration=300,
            extra_payload=payload
        )

    if not check_token_signature(
        token,
        secret_key=settings.secret_key_bytes(),
        signed=settings.sign_sessions()
    ):
        raise HTTPError(403, 'Invalid token or session ID')

    doc = Document()
    session_ctx = BokashiSessionContext(session_id, doc)
    session_ctx._token = token
    # Weak reference: the document does not keep the context alive.
    doc._session_context = weakref.ref(session_ctx)
    return session_ctx
from main import bkapp
from bokeh.application.handlers.function import FunctionHandler
# Rebind ``bkapp``: the object imported from ``main`` (presumably a function
# that populates a Document — confirm) is wrapped in a Bokeh Application via
# FunctionHandler. From here on ``bkapp`` names the Application, which the
# "/" handler uses to create and initialize sessions.
bkapp = Application(FunctionHandler(bkapp))
def html_page_for_session(ctx: SessionContext, root_url: str):
    """Render the standalone HTML page that bootstraps ``ctx``'s document.

    Args:
        ctx: a session context exposing ``_token`` and ``_doc``.
        root_url: base URL handed to ``Resources`` in ``server`` mode.

    Returns:
        The HTML produced by ``html_page_for_render_items``.
    """
    from bokeh.embed.bundle import bundle_for_objs_and_resources
    from bokeh.embed.elements import html_page_for_render_items
    from bokeh.embed.util import RenderItem
    from bokeh.resources import Resources

    doc = ctx._doc
    item = RenderItem(token=ctx._token, roots=doc.roots, use_for_title=True)
    server_resources = Resources(mode='server', root_url=root_url)
    bundle = bundle_for_objs_and_resources(None, server_resources)

    variables = doc.template_variables or {}
    return html_page_for_render_items(
        bundle,
        {},
        [item],
        doc.title,
        template=doc.template,
        template_variables=variables,
    )
# Most recently created session. It is module-global, so all clients share one
# slot: each page load on "/" replaces it, and the websocket endpoint reads it.
# NOTE(review): concurrent clients will clash; a dict keyed by session id would
# be needed to support more than one at a time.
session_ctx: SessionContext | None = None
@app.get("/")
async def get(req: Request):
    """Create a Bokeh session for this request and return its HTML page.

    The module-level ``session_ctx`` is replaced only after the session has
    been created, initialized and rendered successfully. A failing request
    therefore never leaves a half-initialized session behind for the
    websocket endpoint.

    Raises:
        HTTPException: carrying the status/message of any ``HTTPError``
            raised while resolving the session.
    """
    global session_ctx
    url = req.url
    try:
        ctx = get_session(dict(req.headers), dict(req.cookies), dict(req.query_params))
        await bkapp.on_session_created(ctx)
        bkapp.initialize_document(ctx._doc)
        # FIXME: url — only scheme and host are used; any mount prefix
        # (root_path) is ignored.
        html = html_page_for_session(ctx, f'{url.scheme}://{url.netloc}')
    except HTTPError as e:
        raise HTTPException(status_code=e.status, detail=e.message) from e
    session_ctx = ctx
    return HTMLResponse(html)
import asyncio
class TornadoWSAdapter:
    """Expose a Starlette/FastAPI ``WebSocket`` through the Tornado-style
    ``write_message`` / ``read_message`` / ``close`` calls used by the
    connection wrapper.

    ``write_message`` and ``close`` are synchronous (as in Tornado). They
    schedule the real send on the running event loop. The sends are
    serialized through an ``asyncio.Lock`` so frames reach the socket in call
    order, even if an individual send suspends.
    """

    def __init__(self, ws: WebSocket):
        self._sock = ws
        # FIFO lock: queued sends run in the order write_message() was called.
        self._send_lock = asyncio.Lock()
        # Strong references to in-flight tasks; the event loop only keeps weak
        # ones, so unreferenced fire-and-forget tasks may be collected early.
        self._pending: set[asyncio.Task] = set()

    def _spawn(self, coro) -> None:
        """Schedule ``coro`` as a task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)

    async def _send(self, msg: str | bytes) -> None:
        """Send one frame: text for ``str``, binary for ``bytes``."""
        async with self._send_lock:
            if isinstance(msg, str):
                await self._sock.send_text(msg)
            else:
                await self._sock.send_bytes(msg)

    async def _close(self, *args) -> None:
        """Close the socket after every previously queued send has run."""
        async with self._send_lock:
            await self._sock.close(*args)

    def write_message(self, msg: str | bytes, binary: bool = False) -> None:
        """Queue ``msg`` for sending.

        ``str`` is sent as a text frame, unless ``binary`` is true, in which
        case it is UTF-8 encoded and sent as a binary frame. ``bytes``,
        ``bytearray`` and ``memoryview`` are sent as binary frames.

        Raises:
            TypeError: for any other message type.
        """
        if isinstance(msg, str):
            if binary:
                msg = msg.encode('utf-8')
        elif isinstance(msg, (bytes, bytearray, memoryview)):
            msg = bytes(msg)
        else:
            raise TypeError(f'cannot send {type(msg).__name__} over a websocket')
        self._spawn(self._send(msg))

    async def read_message(self, callback: Callable[..., None] | None = None) -> str | bytes | None:
        """Receive one ASGI message and return its text or bytes payload.

        Returns None when the message carries neither (e.g. a disconnect).
        If ``callback`` is given it is called with the payload; its result is
        awaited when it is awaitable, so sync and async callbacks both work.
        """
        import inspect

        msg = await self._sock.receive()
        # ASGI allows both keys to be present with the unused one set to None.
        text = msg.get('text')
        rv = text if text is not None else msg.get('bytes')
        if callback is not None:
            result = callback(rv)
            if inspect.isawaitable(result):
                await result
        return rv

    def close(self, *args):
        """Schedule closing the socket (after any queued sends)."""
        self._spawn(self._close(*args))
class TornadoLockAdapter:
    """Lock usable as ``with await lock.acquire(): ...``.

    ``acquire`` awaits an ``asyncio.Lock`` and returns this adapter. The
    synchronous ``with`` block then releases the lock on exit.
    """

    def __init__(self):
        self._lck = asyncio.Lock()

    async def acquire(self):
        """Wait for the lock; return self so the result can drive ``with``."""
        inner = self._lck
        await inner.acquire()
        return self

    def __enter__(self):
        # The lock is already held by acquire(); nothing to do on entry.
        return None

    def __exit__(self, exc_type, exc, tb):
        # Always release; exceptions are not suppressed (returns None).
        self._lck.release()
from bokeh.client.websocket import WebSocketClientConnectionWrapper
class WSConnAdapter(WebSocketClientConnectionWrapper):
    """Bokeh connection wrapper backed by a Starlette/FastAPI WebSocket.

    The base-class ``__init__`` is not called. Instead the socket is wrapped
    in ``TornadoWSAdapter`` and an awaitable ``TornadoLockAdapter`` is
    installed as ``write_lock``. NOTE(review): presumably this is because the
    base class expects a Tornado connection and lock — confirm against the
    Bokeh version in use.
    """

    def __init__(self, socket: WebSocket) -> None:
        adapted = TornadoWSAdapter(socket)
        self._socket = adapted
        self.write_lock = TornadoLockAdapter()
from bokeh.protocol import Protocol
from bokeh.protocol import messages as msg
from bokeh.protocol.exceptions import MessageError, ProtocolError, ValidationError
from bokeh.protocol.message import Message
from bokeh.protocol.receiver import Receiver
from bokeh.document.events import DocumentPatchedEvent
from typing import (
TYPE_CHECKING,
cast,
Any,
Optional,
Dict,
Union,
List,
Awaitable,
Callable,
Tuple,
Type,
)
# server/connection.py
class ServerConnection:
    """Pair a Bokeh ``Protocol`` with one websocket connection adapter.

    Builds reply messages (OK / ERROR) tied to a request's ``msgid`` and sends
    document patches over the wrapped socket.
    """

    def __init__(self, proto: Protocol, sock: WSConnAdapter):
        self._protocol, self._sock = proto, sock

    @property
    def protocol(self) -> Protocol:
        """The protocol used to create messages for this connection."""
        return self._protocol

    def ok(self, message: Message[Any]) -> msg.ok:
        """Create an OK reply for ``message``."""
        request_id = message.header['msgid']
        return self.protocol.create('OK', request_id)

    def error(self, message: Message[Any], text: str) -> msg.error:
        """Create an ERROR reply for ``message`` carrying ``text``."""
        request_id = message.header['msgid']
        return self.protocol.create('ERROR', request_id, text)

    def send_patch_document(self, event: DocumentPatchedEvent) -> Awaitable[None]:
        """Create a PATCH-DOC message for ``event``; return its send() result."""
        patch = self.protocol.create('PATCH-DOC', [event])
        return patch.send(self._sock)
# bokeh/server/protocol_handler.py
# TODO: document lock
class ProtocolHandler:
    """Dispatch incoming Bokeh protocol messages to per-type coroutines.

    Supported message types: PULL-DOC-REQ, PUSH-DOC, PATCH-DOC and
    SERVER-INFO-REQ. Each handler returns the reply message to send, or an
    ERROR reply if the handler raised.
    """

    _handlers: dict[str, Callable[..., Any]]

    def __init__(self, doc: Document, ss: ServerSession) -> None:
        self._doc = doc
        self._ss = ss
        self._handlers = {
            'PULL-DOC-REQ': self.pull,
            'PUSH-DOC': self.push,
            'PATCH-DOC': self.patch,
            'SERVER-INFO-REQ': self.server_info,
        }

    async def pull(self, msg: msg.pull_doc_req, conn: ServerConnection) -> msg.pull_doc_reply:
        """Reply with the current document."""
        return conn.protocol.create('PULL-DOC-REPLY', msg.header['msgid'], self._doc)

    async def push(self, msg: msg.push_doc, conn: ServerConnection) -> msg.ok:
        """Replace the document contents with the pushed document."""
        msg.push_to_document(self._doc)
        return conn.ok(msg)

    async def patch(self, msg: msg.patch_doc, conn: ServerConnection) -> msg.ok:
        """Apply a client patch, tagging the server session as the setter."""
        msg.apply_to_document(self._doc, self._ss)
        return conn.ok(msg)

    async def server_info(self, msg: msg.server_info_req, conn: ServerConnection) -> msg.server_info_reply:
        """Reply to a server-info request."""
        return conn.protocol.create('SERVER-INFO-REPLY', msg.header['msgid'])

    async def handle(self, message, conn):
        """Run the handler registered for ``message.msgtype``.

        Returns:
            The handler's reply. If the handler raised, returns an ERROR reply
            (built via ``conn.error``) instead, after logging the traceback.

        Raises:
            ProtocolError: if no handler is registered for the message type.
        """
        import logging

        handler = self._handlers.get(message.msgtype)
        if handler is None:
            raise ProtocolError(f"{message} not expected on server")
        try:
            return await handler(message, conn)
        except Exception as e:
            # Report the failure to the client but keep a server-side record too.
            logging.getLogger(__name__).exception("error handling message %r", message)
            return conn.error(message, repr(e))
# server/session.py
class ServerSession:
    """Forward document change events to the client over one connection.

    Registers itself as the document's change receiver, via
    ``doc.callbacks.on_change_dispatch_to``. Document patches are then sent
    to the client as PATCH-DOC messages.
    """

    def __init__(self, session_id: ID, doc: Document, conn: ServerConnection):
        self._id = session_id
        self._doc = doc
        self._conn = conn
        # Strong references to in-flight send tasks; the event loop keeps only
        # weak references, so unreferenced tasks may be collected mid-flight.
        self._pending: set[asyncio.Task] = set()
        self._doc.callbacks.on_change_dispatch_to(self)

    def _document_patched(self, event: DocumentPatchedEvent) -> None:
        """Send ``event`` to the client unless this session caused it.

        Patches that arrive from the client are applied with this session as
        the setter. Echoing them back to the same, single connection would be
        redundant, so they are suppressed.
        """
        if event.setter is self:
            return
        # TODO: broadcast to all ServerConnection instances once more than one
        # connection per session is supported.
        task = asyncio.create_task(self._conn.send_patch_document(event))
        self._pending.add(task)
        task.add_done_callback(self._pending.discard)
import calendar
import datetime as dt
def _session_id_from_subprotocols(subprotocols: list[str] | None) -> ID:
    """Validate the ``bokeh`` subprotocol/token pair and return its session id.

    Expects exactly two subprotocol entries: ``'bokeh'`` followed by a token.

    Raises:
        HTTPError(403): on a malformed subprotocol list, an undecodable token,
            a missing or expired ``session_expiry``, or a bad signature.
    """
    if not subprotocols or len(subprotocols) != 2:
        raise HTTPError(403, 'Invalid subprotocols')
    sub_proto, token = subprotocols
    if sub_proto != 'bokeh' or not token:
        raise HTTPError(403, 'Invalid subprotocols')
    try:
        payload = get_token_payload(token)
    except Exception as e:
        # Untrusted input at an auth boundary: any decode failure is a rejection.
        raise HTTPError(403, 'Malformed token') from e
    now = calendar.timegm(dt.datetime.now(tz=dt.timezone.utc).timetuple())
    if 'session_expiry' not in payload:
        raise HTTPError(403, 'Session expiry has not been provided')
    if now >= payload['session_expiry']:
        raise HTTPError(403, 'Token is expired')
    if not check_token_signature(
        token,
        secret_key=settings.secret_key_bytes(),
        signed=settings.sign_sessions()
    ):
        raise HTTPError(403, 'Invalid token signature')
    return get_session_id(token)


@app.websocket('/ws')
async def ws(ws: WebSocket):
    """Bokeh websocket endpoint for the session created by the "/" handler.

    Rejected handshakes are closed before ``accept()`` with code 1008 (policy
    violation). The ASGI server typically turns this into an HTTP 403. The
    connection is rejected on an invalid token, or when the token's session
    is not the one stored in ``session_ctx``.

    After accepting, an ACK is sent, then incoming frames are fed to a
    ``Receiver``. Each complete message is dispatched to ``ProtocolHandler``
    and its reply is sent back. On exit, the session's change receiver is
    removed from the document.
    """
    import logging

    try:
        session_id = _session_id_from_subprotocols(ws.get('subprotocols'))
    except HTTPError as e:
        await ws.close(code=1008, reason=e.message)
        return

    ctx = session_ctx
    if ctx is None or ctx.id != session_id:
        # No page has been served yet, or the token belongs to another session.
        await ws.close(code=1008, reason='Unknown session')
        return
    doc = ctx._doc

    proto = Protocol()
    receiver = Receiver(proto)
    conn = WSConnAdapter(ws)
    sconn = ServerConnection(proto, conn)
    ss = ServerSession(session_id, doc, sconn)
    handler = ProtocolHandler(doc, ss)

    try:
        await ws.accept(subprotocol='bokeh')
        await proto.create('ACK').send(conn)
        while True:
            data = await ws.receive()
            if data['type'] == 'websocket.disconnect':
                break
            # An empty text frame ('') is still a text frame; only fall back to
            # bytes when there is no text payload at all.
            frag = data.get('text')
            if frag is None:
                frag = data.get('bytes')
            try:
                msg = await receiver.consume(frag)
                if msg is not None:
                    work = await handler.handle(msg, sconn)
                    if isinstance(work, Message):
                        await work.send(conn)
            except Exception:
                # Best effort: one bad message should not drop the connection.
                logging.getLogger(__name__).exception('error processing websocket message')
        # TODO: ping/pong
    finally:
        # Stop dispatching document changes to this (now closed) connection.
        doc.callbacks.remove_on_change(ss)
import uvicorn
if __name__ == '__main__':
    # Development entry point: serve the app on all interfaces, port 5050.
    uvicorn.run(app, port=5050, host='0.0.0.0')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment