Skip to content

Instantly share code, notes, and snippets.

@kemurphy
Created August 21, 2020 08:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kemurphy/6704f4c7c7510dcd71202be6d28a99d6 to your computer and use it in GitHub Desktop.
Save kemurphy/6704f4c7c7510dcd71202be6d28a99d6 to your computer and use it in GitHub Desktop.
use std::{
    collections::{btree_map, BTreeMap},
    env,
    fmt::Debug,
    io::Error,
    pin::Pin,
    sync::{Arc, Once, OnceLock},
    task::Poll,
};
use anyhow::Context;
use futures::{stream, Sink, SinkExt, Stream, StreamExt};
use log::info;
use serde_json::Value;
use thiserror;
use tokio::{
    net::{TcpListener, TcpStream},
    sync::{mpsc, oneshot, watch, Notify, RwLock},
};
use tokio_tungstenite::{
    tungstenite::{
        error::Error as WsError,
        handshake::server::{ErrorResponse, Request, Response},
        protocol::Message,
    },
    WebSocketStream,
};
/// Identifier of an authenticated user, taken from the `id` claim of the session JWT.
///
/// `Ord` is required because it keys the session registry `BTreeMap`; `Copy` because
/// accessors hand it out by value from `&self`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct UserId(i64);
/// Sequential identifier assigned to each new session.
///
/// `Copy` lets the counter hand out its current value by value (see `SessionId::next`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SessionId(i64);
type ClosingPacket = (watch::Sender, Arc<Notify>);
/// Items accepted by a session's [`EventSink`].
enum Event {
/// A payload for the client; `Session::new` serializes it with `to_string()` into a
/// text frame. (Presumably a `serde_json::Value` — confirm the imported `Value`.)
Message(Value),
/// Request that the session shut down; the receiver of this oneshot is expected to get
/// the session's [`ClosingPacket`] back so a replacement can take over.
Close(oneshot::Sender<ClosingPacket>),
}
// Opaque sink/stream types used by sessions, declared with type-alias `impl Trait`.
// NOTE(review): `type X = impl Trait;` is unstable and needs
// `#![feature(type_alias_impl_trait)]` (nightly) at the crate root, which is not visible here.
// Their defining uses are the generic `to_event_sink` / `to_message_sink` /
// `to_message_stream` helpers; as far as I know an opaque alias must be defined by a single
// concrete type, so returning a bare generic `S` is rejected — consider concrete types or
// `Pin<Box<dyn ...>>` instead.
/// Sink accepting [`Event`]s for one session.
type EventSink = impl Sink<Event, Error = WsError> + Debug;
/// Sink of outgoing WebSocket frames.
type MessageSink = impl Sink<Message, Error = WsError> + Debug;
/// Stream of incoming WebSocket frames or protocol errors.
type MessageStream = impl Stream<Item = Result<Message, WsError>> + Debug;
/// Description of a user's live session; the registry stores a `watch` receiver that
/// observes the current value for each user.
pub struct SessionData {
/// Sequential id of this session.
/// NOTE(review): `Session::new` initializes this struct with a `session_id` field; the two
/// names must agree.
id: SessionId,
/// User the session belongs to.
user_id: UserId,
/// Sink used to send [`Event`]s (payloads or a close request) to this session.
sink: EventSink,
}
impl SessionData {
async fn for_user(uid: UserId) -> Option<SessionData> {
let mut session_map = get_session_map().read().await;
match session_map.entry(uid) {
btree_map::Entry::Vacant(_) => None,
btree_map::Entry::Occupied(entry) => {
let rx = entry.get().clone();
Some(rx.recv().await)
}
}
}
async fn close(&self) -> oneshot::Receiver<ClosingChannel> {
let (sender, receiver) = oneshot::channel();
self.sink.send(Event::Close(sender)).await;
receiver
}
}
/// Process-wide source of fresh [`SessionId`]s, yielding 1, 2, 3, ...
///
/// An atomic replaces the former `static mut SessionId`. Every access to a `static mut`
/// needs `unsafe`, and concurrent `next()` calls from different tasks would race. The
/// existing `session_counter.next()` call syntax keeps working unchanged.
#[allow(non_upper_case_globals)] // keep the name existing callers already use
static session_counter: SessionCounter = SessionCounter(std::sync::atomic::AtomicI64::new(0));

/// Thread-safe holder of the most recently issued session id.
struct SessionCounter(std::sync::atomic::AtomicI64);

impl SessionCounter {
    /// Returns the next session id; the first call yields `SessionId(1)`.
    fn next(&self) -> SessionId {
        SessionId(self.0.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1)
    }
}
fn get_session_map() -> RwLock<BTreeMap<UserId, SessionData>> {
static mut session_map_init: Once = Once::new();
static mut session_map: Option<RwLock<BTreeMap<UserId, watch::Receiver<SessionData>>>> = None;
session_map_init.call_once(|| {
session_map = Some(RwLock::new(BTreeMap::new()));
});
session_map.unwrap()
}
impl UserId {
    /// Extracts the authenticated user id from a WebSocket upgrade request.
    ///
    /// The request must carry:
    /// * an `x-session-secret` header (the last one wins), used as the HMAC-SHA256 key;
    /// * a `hanabi.sid` cookie (the last one wins) holding a JWT signed with that key,
    ///   whose claims contain an integer `id`.
    ///
    /// NOTE(review): the signing key comes from the request itself. This is only sound if a
    /// trusted proxy always sets `x-session-secret` and strips any client-supplied copy;
    /// otherwise a client can sign its own token. Confirm the deployment guarantees this.
    ///
    /// # Errors
    /// Fails when the header or cookie is missing, `Hmac::new_varkey` rejects the key, the
    /// JWT does not verify, or the `id` claim is absent or not an integer.
    fn from_request(req: &Request) -> anyhow::Result<UserId> {
        use cookie::Cookie;
        use hmac::{Hmac, NewMac};
        use http::header::COOKIE;
        use jwt::VerifyWithKey;
        use sha2::Sha256;

        let headers = req.headers();
        let secret = headers
            .get_all("x-session-secret")
            .iter()
            .last()
            .context(ConnReqError::NoSecret)?
            .as_bytes();
        // One `Cookie` header carries every cookie as `a=1; b=2; ...`, and `Cookie::parse`
        // only reads the first pair, so split each header value into pairs first.
        let token = headers
            .get_all(COOKIE)
            .iter()
            .filter_map(|v| v.to_str().ok())
            .flat_map(|v| v.split(';'))
            .filter_map(|pair| Cookie::parse(pair.trim()).ok())
            .filter(|c| c.name() == "hanabi.sid")
            .last()
            .context(ConnReqError::NoCookie)?;
        let key: Hmac<Sha256> = Hmac::new_varkey(secret).map_err(|_| ConnReqError::BadSecret)?;
        let claims: BTreeMap<String, serde_json::Value> = token.value().verify_with_key(&key)?;
        claims
            .get("id")
            .and_then(|v| v.as_i64())
            .map(UserId)
            .context(ConnReqError::BadCookie)
    }
}
impl SessionId {
fn next(&mut self) -> SessionId {
self.0 += 1;
*self
}
}
fn to_event_sink<S>(s: S) -> EventSink
where
S: Sink<Event> + Debug,
<S as Sink>::Error = WsError,
{
s
}
fn to_message_sink<S>(s: S) -> MessageSink
where
S: Sink<Message> + Debug,
<S as Sink>::Error = WsError,
{
s
}
/// Erases a concrete stream of incoming frames into the opaque [`MessageStream`] type.
///
/// The item type is pinned with `Stream<Item = ..>`; equality constraints such as
/// `<S as Stream>::Item = ..` are not accepted in `where` clauses.
/// NOTE(review): an opaque type alias generally cannot be defined by a generic parameter;
/// this may need a concrete `S` or a boxed trait object — confirm on the toolchain in use.
fn to_message_stream<S>(s: S) -> MessageStream
where
    S: Stream<Item = Result<Message, WsError>> + Debug,
{
    s
}
/// One user's WebSocket connection; implements `Stream` of incoming frames and
/// `Sink<Message>` for outgoing ones.
// Field order is drop order; keep it in mind before reordering.
pub struct Session {
/// User this connection belongs to.
user_id: UserId,
/// Incoming frames from the client.
stream: MessageStream,
/// Outgoing frames to the client.
sink: MessageSink,
/// Packet handed to a replacing session when this one receives `Event::Close`.
closing: ClosingPacket,
}
impl Session {
/// Registers a new WebSocket connection for `user_id` and wraps it as a [`Session`].
///
/// What the draft sets out to do:
/// 1. create an event channel plus a [`SessionData`] describing this connection;
/// 2. if the user has no registry entry, create a `watch` channel seeded with that data
///    and store its receiver; otherwise ask the previous session to close, wait for its
///    [`ClosingPacket`] and the accompanying `Notify`, then publish the new data;
/// 3. split the socket and adapt its sink so that [`Event`]s become frames.
///
/// NOTE(review): this body does not compile yet; the specific problems are flagged inline.
/// Among them: `new` is not `async` but uses `.await`, so it cannot be called as written.
fn new(user_id: UserId, ws: WebSocketStream<TcpStream>) -> Session {
// NOTE(review): tokio's mpsc constructor is `unbounded_channel()`; `unbounded()` is the
// `futures::channel::mpsc` name — confirm which channel is intended. The pair is
// (sender, receiver). `inner_stream` is the receiver and is never read, so nothing forwards
// events sent via `SessionData::sink` to the socket — presumably still TODO.
let (event_sink, inner_stream) = mpsc::unbounded();
// Adapter that accepts payloads and wraps each one as `Event::Message`.
// NOTE(review): in futures 0.3 `stream::once` takes a future, and `with_flat_map` expects a
// stream of `Result<Event, _>`; `stream::iter(Some(Ok(..)))` or `.with(..)` may be intended.
let sink = event_sink
.clone()
.with_flat_map(|msg| stream::once(Event::Message(msg)));
// NOTE(review): `SessionData` declares this field as `id`, not `session_id`. `uid` is not
// defined here (the parameter is `user_id`). `session_counter` is a `static mut`, so this
// access needs `unsafe` as written.
let data = SessionData {
session_id: session_counter.next(),
user_id: uid,
sink: to_event_sink(event_sink),
};
// Outcome of inspecting the registry while holding its write lock.
// NOTE(review): `Insert` is a tuple variant but is constructed below as `Insert { tx }`.
// `Replace::rx` borrows from the map through the write guard, which is dropped at the end of
// the `let action = { .. }` block, so that borrow cannot escape it.
enum Action<'a> {
Insert(watch::Sender<SessionData>),
Replace {
old: SessionData,
new: SessionData,
rx: &'a mut watch::Receiver<SessionData>,
},
}
let action = {
// NOTE(review): this write guard is held across the `recv().await` below, blocking every
// other registry reader and writer until it completes.
let mut session_map = get_session_map().write().await;
match session_map.entry(uid) {
btree_map::Entry::Occupied(entry) => {
let rx = entry.into_mut();
Action::Replace {
// NOTE(review): `watch_rx` is undefined — presumably `rx`. `recv()` may yield an
// `Option` rather than `SessionData`; confirm for the tokio version in use.
old: watch_rx.recv().await,
new: data,
rx,
}
}
btree_map::Entry::Vacant(entry) => {
// NOTE(review): the new receiver is bound to `rx`, but `watch_rx` (undefined) is inserted.
let (tx, rx) = watch::channel(data);
entry.insert(watch_rx);
Action::Insert { tx }
}
}
};
// Obtain the `watch::Sender` for this user's registry entry: freshly created, or handed
// over by the session being replaced.
let watch_tx = match action {
Action::Insert(tx) => tx,
Action::Replace { old, new, rx } => {
let chan = old.close().await;
// NOTE(review): awaiting a oneshot receiver yields a `Result` (an error if the sender is
// dropped), not the bare tuple; the error case needs handling.
let (watch_tx, notify) = chan.await;
// NOTE(review): nothing in this file ever signals this `Notify`. Unless the old session's
// owner does so elsewhere, this waits forever.
notify.notified().await;
// NOTE(review): `data` was already moved into `new` above (presumably `new` is meant).
// This arm must evaluate to the `watch::Sender` to match the `Insert` arm, but it evaluates
// to the result of `send`. `new` and `rx` are otherwise unused.
watch_tx.send(data).await
}
};
// Handover packet this session gives to a future replacement on `Event::Close`.
let closing = (watch_tx, Arc::new(Notify::new()));
let (inner_sink, stream) = ws.split();
// Adapt the socket sink to accept `Event`s: payloads become text frames, and a close request
// is answered with this session's `closing` packet.
// NOTE(review): the closure uses `.await` but is not async (a oneshot `send` is synchronous
// anyway). The two arms return different stream types. `closing` is captured here and also
// moved into the returned `Session`, and `closing.clone()` needs `ClosingPacket: Clone`.
let inner_sink = inner_sink.with_flat_map(|evt| match evt {
Event::Message(msg) => stream::once(Message::Text(msg.to_string())),
Event::Close(chan) => {
chan.send(closing.clone()).await;
stream::empty()
}
});
// NOTE(review): `inner_sink` is built above but never used. `sink` is stored instead; it
// accepts payloads wrapped as `Event::Message`, not `Message` frames. Confirm which sink
// was meant.
Session {
user_id,
stream: to_message_stream(stream),
sink: to_message_sink(sink),
closing,
}
}
/// Returns the id of the user this session belongs to.
// NOTE(review): returning the field by value from `&self` requires `UserId: Copy`.
fn user_id(&self) -> UserId {
self.user_id
}
}
impl Drop for Session {
fn drop(self) {
// XXX TODO probably want to call watch_tx.closing(), but need async drop
get_session_map().remove(self.0);
}
}
impl Sink<Message> for Session {
type Error = WsError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.stream.poll_ready(cx)
}
fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
self.stream.start_send(item)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.stream.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.stream.poll_close(cx)
}
}
impl Stream for Session {
    type Item = Result<Message, WsError>;

    /// Yields the next frame (or protocol error) received from the client.
    ///
    /// `std::task::Context` is spelled out because `Context` names `anyhow::Context` in this
    /// file. Projecting with `Pin::new(&mut self.stream)` relies on `Session` and
    /// `MessageStream` being `Unpin`. NOTE(review): if the concrete stream is `!Unpin`, store
    /// it as a pinned box instead.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        Pin::new(&mut self.stream).poll_next(cx)
    }

    /// Forwards the inbound stream's size hint.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment