Created
April 2, 2019 13:04
-
-
Save izderadicka/bb25bf19b05945c9945dcf2526b1e33c to your computer and use it in GitHub Desktop.
Hyper Websockets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#[macro_use] | |
extern crate log; | |
use hyper::rt; | |
use hyper::server::Server; | |
use hyper::service::service_fn_ok; | |
use hyper::{Body, Request, Response}; | |
use tokio::prelude::*; | |
mod ws; | |
/// Our server HTTP handler to initiate HTTP upgrades. | |
fn server_upgrade(req: Request<Body>) -> Response<Body> { | |
debug!("We got these headers: {:?}", req.headers()); | |
let res = match ws::upgrade_connection(req) { | |
Err(r) => r, | |
Ok((r, ws_future)) => { | |
let ws_process = ws_future | |
.map_err(|err| error!("Cannot create websocket: {} ", err)) | |
.and_then(|ws| { | |
let (tx, rc) = ws.split(); | |
rc.map(|m| { | |
debug!("Got message {:?}", m); | |
m | |
}) | |
.forward(tx) | |
.map(|_| debug!("Websocket has ended")) | |
.map_err(|err| error!("Socket error {}", err)) | |
}); | |
rt::spawn(ws_process); | |
r | |
} | |
}; | |
res | |
} | |
fn main() { | |
pretty_env_logger::init(); | |
let addr = ([127, 0, 0, 1], 5000).into(); | |
let server = Server::bind(&addr) | |
.serve(|| service_fn_ok(server_upgrade)) | |
.map_err(|e| eprintln!("server error: {}", e)); | |
rt::run(server); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use futures::future::poll_fn; | |
use futures::prelude::*; | |
use headers::{self, HeaderMapExt}; | |
use hyper::header::{self, AsHeaderName, HeaderMap, HeaderValue}; | |
use hyper::{Body, Request, Response, StatusCode}; | |
use quick_error::quick_error; | |
use std::fmt; | |
use std::io; | |
use tungstenite::protocol; | |
quick_error! { | |
#[derive(Debug)] | |
pub enum Error { | |
Ws(err: tungstenite::Error) { | |
from() | |
} | |
Io(err: io::Error) { | |
from() | |
} | |
InvalidMessageType { | |
description("Message is of incorrect type (binary vs text)") | |
} | |
} | |
} | |
fn header_matches<S: AsHeaderName>(headers: &HeaderMap<HeaderValue>, name: S, value: &str) -> bool { | |
headers | |
.get(name) | |
.and_then(|v| v.to_str().ok()) | |
.map(|v| v.to_lowercase() == value) | |
.unwrap_or(false) | |
} | |
pub fn upgrade_connection( | |
req: Request<Body>, | |
) -> Result< | |
( | |
Response<Body>, | |
impl Future<Item = WebSocket, Error = hyper::Error>+Send, | |
), | |
Response<Body>, | |
> { | |
let mut res = Response::new(Body::empty()); | |
let mut header_error = false; | |
debug!("We got these headers: {:?}", req.headers()); | |
if !header_matches(req.headers(), header::UPGRADE, "websocket") { | |
error!("Upgrade is not to websocket"); | |
header_error = true; | |
} | |
if !header_matches(req.headers(), header::SEC_WEBSOCKET_VERSION, "13") { | |
error!("Websocket protocol version must be 13"); | |
header_error = true; | |
} | |
if !req | |
.headers() | |
.typed_get::<headers::Connection>() | |
.map(|h| h.contains("Upgrade")) | |
.unwrap_or(false) | |
{ | |
error!("It must be upgrade connection"); | |
header_error = true; | |
} | |
let key = req.headers().typed_get::<headers::SecWebsocketKey>(); | |
if key.is_none() { | |
error!("Websocket key missing"); | |
header_error = true; | |
} | |
if header_error { | |
*res.status_mut() = StatusCode::BAD_REQUEST; | |
return Err(res); | |
} | |
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; | |
let h = res.headers_mut(); | |
h.typed_insert(headers::Upgrade::websocket()); | |
h.typed_insert(headers::SecWebsocketAccept::from(key.unwrap())); | |
h.typed_insert(headers::Connection::upgrade()); | |
let upgraded = req.into_body().on_upgrade().map(|upgraded| { | |
debug!("Connection upgraded to websocket"); | |
WebSocket::new(upgraded) | |
}); | |
Ok((res, upgraded)) | |
} | |
/// A websocket `Stream` and `Sink` | |
pub struct WebSocket { | |
inner: protocol::WebSocket<::hyper::upgrade::Upgraded>, | |
} | |
impl WebSocket { | |
pub(crate) fn new(upgraded: hyper::upgrade::Upgraded) -> Self { | |
let inner = protocol::WebSocket::from_raw_socket(upgraded, protocol::Role::Server, None); | |
WebSocket { inner } | |
} | |
/// Gracefully close this websocket. | |
pub fn close(mut self) -> impl Future<Item = (), Error = Error> { | |
poll_fn(move || Sink::close(&mut self)) | |
} | |
} | |
impl Stream for WebSocket { | |
type Item = Message; | |
type Error = Error; | |
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { | |
loop { | |
let msg = match self.inner.read_message() { | |
Ok(item) => item, | |
Err(::tungstenite::Error::Io(ref err)) | |
if err.kind() == io::ErrorKind::WouldBlock => | |
{ | |
return Ok(Async::NotReady); | |
} | |
Err(::tungstenite::Error::ConnectionClosed(frame)) => { | |
trace!("websocket closed: {:?}", frame); | |
return Ok(Async::Ready(None)); | |
} | |
Err(e) => { | |
debug!("websocket poll error: {}", e); | |
return Err(Error::Ws(e)); | |
} | |
}; | |
match msg { | |
msg @ protocol::Message::Text(..) | |
| msg @ protocol::Message::Binary(..) | |
| msg @ protocol::Message::Ping(..) => { | |
return Ok(Async::Ready(Some(Message { inner: msg }))); | |
} | |
protocol::Message::Pong(payload) => { | |
trace!("websocket client pong: {:?}", payload); | |
} | |
} | |
} | |
} | |
} | |
impl Sink for WebSocket { | |
type SinkItem = Message; | |
type SinkError = Error; | |
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> { | |
match item.inner { | |
protocol::Message::Ping(..) => { | |
// warp doesn't yet expose a way to construct a `Ping` message, | |
// so the only way this could is if the user is forwarding the | |
// received `Ping`s straight back. | |
// | |
// tungstenite already auto-reponds to `Ping`s with a `Pong`, | |
// so this just prevents accidentally sending extra pings. | |
return Ok(AsyncSink::Ready); | |
} | |
_ => (), | |
} | |
match self.inner.write_message(item.inner) { | |
Ok(()) => Ok(AsyncSink::Ready), | |
Err(::tungstenite::Error::SendQueueFull(inner)) => { | |
debug!("websocket send queue full"); | |
Ok(AsyncSink::NotReady(Message { inner })) | |
} | |
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => { | |
// the message was accepted and partly written, so this | |
// isn't an error. | |
Ok(AsyncSink::Ready) | |
} | |
Err(e) => { | |
debug!("websocket start_send error: {}", e); | |
Err(Error::Ws(e)) | |
} | |
} | |
} | |
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { | |
match self.inner.write_pending() { | |
Ok(()) => Ok(Async::Ready(())), | |
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => { | |
Ok(Async::NotReady) | |
} | |
Err(err) => { | |
debug!("websocket poll_complete error: {}", err); | |
Err(Error::Ws(err)) | |
} | |
} | |
} | |
fn close(&mut self) -> Poll<(), Self::SinkError> { | |
match self.inner.close(None) { | |
Ok(()) => Ok(Async::Ready(())), | |
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => { | |
Ok(Async::NotReady) | |
} | |
Err(::tungstenite::Error::ConnectionClosed(frame)) => { | |
trace!("websocket closed: {:?}", frame); | |
return Ok(Async::Ready(())); | |
} | |
Err(err) => { | |
debug!("websocket close error: {}", err); | |
Err(Error::Ws(err)) | |
} | |
} | |
} | |
} | |
impl fmt::Debug for WebSocket { | |
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
f.debug_struct("WebSocket").finish() | |
} | |
} | |
/// A WebSocket message. | |
/// | |
/// Only repesents Text and Binary messages. | |
/// | |
/// This will likely become a `non-exhaustive` enum in the future, once that | |
/// language feature has stabilized. | |
#[derive(Eq, PartialEq, Clone)] | |
pub struct Message { | |
inner: protocol::Message, | |
} | |
impl Message { | |
/// Construct a new Text `Message`. | |
pub fn text<S: Into<String>>(s: S) -> Message { | |
Message { | |
inner: protocol::Message::text(s), | |
} | |
} | |
/// Construct a new Binary `Message`. | |
pub fn binary<V: Into<Vec<u8>>>(v: V) -> Message { | |
Message { | |
inner: protocol::Message::binary(v), | |
} | |
} | |
/// Returns true if this message is a Text message. | |
pub fn is_text(&self) -> bool { | |
self.inner.is_text() | |
} | |
/// Returns true if this message is a Binary message. | |
pub fn is_binary(&self) -> bool { | |
self.inner.is_binary() | |
} | |
/// Returns true if this message is a Ping message. | |
pub fn is_ping(&self) -> bool { | |
self.inner.is_ping() | |
} | |
/// Try to get a reference to the string text, if this is a Text message. | |
pub fn to_str(&self) -> Result<&str, Error> { | |
match self.inner { | |
protocol::Message::Text(ref s) => Ok(s), | |
_ => Err(Error::InvalidMessageType), | |
} | |
} | |
/// Return the bytes of this message. | |
pub fn as_bytes(&self) -> &[u8] { | |
match self.inner { | |
protocol::Message::Text(ref s) => s.as_bytes(), | |
protocol::Message::Binary(ref v) => v, | |
_ => unreachable!(), | |
} | |
} | |
} | |
impl fmt::Debug for Message { | |
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
fmt::Debug::fmt(&self.inner, f) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment