// Created August 21, 2020 08:43
// Save kemurphy/695f9758d8cdb6846d7237909c3ee65f to your computer and use it in GitHub Desktop.
// This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
// Learn more about bidirectional Unicode characters.
use std::{
    collections::{btree_map, BTreeMap},
    env,
    fmt::Debug,
    io::Error,
    pin::Pin,
    sync::{Arc, Once},
    task::Poll,
};

use anyhow::Context;
use futures::{stream, Sink, SinkExt, Stream, StreamExt};
use log::info;
use serde_json::Value;
use thiserror;
use tokio::{
    net::{TcpListener, TcpStream},
    sync::{mpsc, oneshot, watch, Notify, RwLock},
};
use tokio_tungstenite::{
    tungstenite::{
        error::Error as WsError,
        handshake::server::{ErrorResponse, Request, Response},
        protocol::Message,
    },
    WebSocketStream,
};
/// Identifier of an authenticated user, taken from the integer `id` claim of the session token.
///
/// It keys the global session map (so it needs a total order) and is returned by value from
/// `Session::user_id` (so it is `Copy`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct UserId(i64);
/// Per-connection identifier, handed out by `session_counter`.
///
/// `SessionId::next` returns `*self`, which requires `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SessionId(i64);
type ClosingPacket = (watch::Sender, Arc<Notify>); | |
enum Event { | |
Message(Value), | |
Close(oneshot::Sender<ClosingPacket>), | |
} | |
type EventSink = impl Sink<Event, Error = WsError> + Debug; | |
type MessageSink = impl Sink<Message, Error = WsError> + Debug; | |
type MessageStream = impl Stream<Item = Result<Message, WsError>> + Debug; | |
pub struct SessionData { | |
id: SessionId, | |
user_id: UserId, | |
sink: EventSink, | |
}; | |
impl SessionData { | |
async fn close(&self) -> oneshot::Receiver<ClosingChannel> { | |
let (sender, receiver) = oneshot::channel(); | |
self.sink.send(Event::Close(sender)).await; | |
receiver | |
} | |
} | |
static mut session_counter: SessionId = SessionId(0); | |
fn get_session_map() -> RwLock<BTreeMap<UserId, SessionData>> { | |
static mut session_map_init: Once = Once::new(); | |
static mut session_map: Option<RwLock<BTreeMap<UserId, watch::Receiver<SessionData>>>> = None; | |
session_map_init.call_once(|| { | |
session_map = Some(RwLock::new(BTreeMap::new())); | |
}); | |
session_map.unwrap() | |
} | |
impl UserId { | |
fn from_request(req: &Request) -> anyhow::Result<Session> { | |
use cookie::Cookie; | |
use hmac::{Hmac, NewMac}; | |
use http::header::COOKIE; | |
use jwt::VerifyWithKey; | |
use sha2::Sha256; | |
let headers = req.headers(); | |
let secret = headers | |
.get_all("x-session-secret") | |
.iter() | |
.last() | |
.context(ConnReqError::NoSecret)? | |
.as_bytes(); | |
let token = headers | |
.get_all(COOKIE) | |
.iter() | |
.filter_map(|v| v.to_str().ok()) | |
.filter_map(|v| Cookie::parse(v).ok()) | |
.filter(|c| c.name() == "hanabi.sid") | |
.last() | |
.context(ConnReqError::NoCookie)?; | |
let key: Hmac<Sha256> = Hmac::new_varkey(secret) | |
.or(Err(ConnReqError::BadSecret))?; | |
let claims: BTreeMap<String, serde_json::Value> = | |
token.value().verify_with_key(&key)?; | |
claims | |
.get("id") | |
.and_then(|v| v.as_i64()) | |
.map(UserId) | |
.context(ConnReqError::BadCookie) | |
} | |
} | |
impl SessionId { | |
fn next(&mut self) -> SessionId { | |
self.0++; | |
*self | |
} | |
} | |
fn to_event_sink<S>(s: S) -> EventSink { | |
where S: Sink<Event> + Debug, | |
<S as Sink>::Error = WsError { | |
s | |
} | |
fn to_message_sink<S>(s: S) -> MessageSink { | |
where S: Sink<Message> + Debug, | |
<S as Sink>::Error = WsError { | |
s | |
} | |
fn to_message_stream<S>(s: S) -> MessageStream { | |
where S: + Stream + Debug, | |
<S as Stream>::Item = Result<Message, WsError> { | |
s | |
} | |
pub struct Session { | |
user_id: UserId, | |
stream: MessageStream, | |
sink: MessageSink, | |
closing: ClosingPacket, | |
} | |
/// Construction and accessors for `Session`.
impl Session {
    /// Builds the session for `user_id` on top of an accepted WebSocket and registers it in
    /// the global session map. If the user already has a session, that session is asked to
    /// close and its watch channel is taken over.
    ///
    /// NOTE(review): this body is an unfinished draft and does not compile as written:
    /// - it uses `.await` but is not an `async fn`;
    /// - `uid` and `watch_rx` are never defined (the parameter is `user_id`; the local
    ///   receiver is `rx`);
    /// - both `with_flat_map` closures end in `};` instead of `});`, and the `SessionData`
    ///   literal, the `action` block and the `watch_tx` match lack their trailing `;`;
    /// - `SessionData` has a field `id`, not `session_id`;
    /// - `Action::Insert` is a tuple variant but is built as `Action::Insert { tx }`;
    /// - `data` is moved into `Action::Replace` and then sent again on the watch channel.
    fn new(user_id: UserId, ws: WebSocketStream<TcpStream>) -> Session {
        // NOTE(review): `mpsc` is `tokio::sync::mpsc` (see imports), whose unbounded
        // constructor is `unbounded_channel()`; `unbounded()` is the `futures::channel::mpsc`
        // name. Confirm which channel is intended.
        let (event_sink, inner_stream) = mpsc::unbounded();
        // Wraps a clone of the event sender so each JSON value pushed into it becomes one
        // `Event::Message`.
        let sink = event_sink.clone().with_flat_map(|msg| {
            stream::once(Event::Message(msg))
        };
        // Registry entry for this connection.
        // NOTE(review): if `session_counter` is a `static mut`, this access needs `unsafe`.
        let data = SessionData {
            session_id: session_counter.next(),
            user_id: uid,
            sink: to_event_sink(event_sink),
        }
        // What to do once the map lock is released: either this call created the user's entry
        // (and owns the watch sender), or an existing session must be closed first.
        enum Action<'a> {
            Insert(watch::Sender<SessionData>),
            Replace {
                old: SessionData,
                new: SessionData,
                rx: &'a mut watch::Receiver<SessionData>,
            },
        }
        // NOTE(review): the map's write lock is held across `watch_rx.recv().await` below, and
        // `Action::Replace` borrows from the guard (`rx`) after the guard goes out of scope.
        let action = {
            let mut session_map = get_session_map().write().await;
            match session_map.entry(uid) {
                // The user already has a session: keep the existing receiver and read the
                // entry currently published on it.
                btree_map::Entry::Occupied(entry) => {
                    let rx = entry.into_mut();
                    Action::Replace {
                        old: watch_rx.recv().await,
                        new: data,
                        rx,
                    }
                },
                // First session for this user: create a watch channel seeded with `data` and
                // store its receiver (presumably `rx` was meant where `watch_rx` is written).
                btree_map::Entry::Vacant(entry) => {
                    let (tx, rx) = watch::channel(data);
                    entry.insert(watch_rx);
                    Action::Insert {
                        tx,
                    }
                }
            }
        }
        // Obtain the watch sender for this user's entry. When replacing, ask the old session
        // to close, wait for its `ClosingPacket` (watch sender + `Notify`), wait on the
        // notification, then publish the new entry.
        // NOTE(review): the `Replace` arm evaluates to the result of `send(...)`, not to a
        // `watch::Sender`, so the arm types disagree; `new` and `rx` are unused; and
        // `old.close()` needs `old` to be `mut` if `close` takes `&mut self`.
        let watch_tx = match action {
            Action::Insert(tx) => tx,
            Action::Replace { old, new, rx } => {
                let chan = old.close().await;
                let (watch_tx, notify) = chan.await;
                notify.notified().await;
                watch_tx.send(data).await
            }
        }
        // This session's own hand-over packet, given to whichever session later replaces it.
        let closing = (watch_tx, Arc::new(Notify::new()));
        let (inner_sink, stream) = ws.split();
        // Translates events for the socket: JSON messages become text frames; a close request
        // is answered with this session's closing packet and produces no frame.
        // NOTE(review): `.await` inside this non-async closure does not compile. `closing` is
        // cloned here but also moved into the returned `Session`; it contains a
        // `watch::Sender`, so confirm it is `Clone`. The two arms return different stream
        // types (`once` vs `empty`).
        let inner_sink = inner_sink.with_flat_map(|evt| {
            match evt {
                Event::Message(msg) => stream::once(Message::Text(msg.to_string())),
                Event::Close(chan) => {
                    chan.send(closing.clone()).await;
                    stream::empty()
                }
            }
        };
        // NOTE(review): neither `inner_stream` (the event receiver) nor the wrapped
        // `inner_sink` is used after this point, so nothing forwards events to the socket.
        // Also, `sink` below is the JSON-value wrapper over the event channel, not a
        // `Sink<Message>` on the WebSocket.
        Session {
            user_id,
            stream: to_message_stream(stream),
            sink: to_message_sink(sink),
            closing,
        }
    }
    /// Returns the user this session belongs to (requires `UserId: Copy`).
    fn user_id(&self) -> UserId {
        self.user_id
    }
}
impl Drop for Session { | |
fn drop(self) { | |
// XXX TODO probably want to call watch_tx.closing(), but need async drop | |
get_session_map().remove(self.0); | |
} | |
} | |
impl Sink<Message> for Session { | |
type Error = WsError; | |
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> { | |
self.stream.poll_ready(cx) | |
} | |
fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { | |
self.stream.start_send(item) | |
} | |
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> { | |
self.stream.poll_flush(cx) | |
} | |
fn poll_close( self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> { | |
self.stream.poll_close(cx) | |
} | |
} | |
impl Stream for Session { | |
type Item = Result<Message, WsError>; | |
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> { | |
self.stream.poll_next(cx) | |
} | |
fn size_hint(&self) -> (usize, Option<usize>) { | |
self.stream.size_hint() | |
} | |
} |
// Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.