Skip to content

Instantly share code, notes, and snippets.

@izderadicka
Created April 2, 2019 13:04
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save izderadicka/bb25bf19b05945c9945dcf2526b1e33c to your computer and use it in GitHub Desktop.
Save izderadicka/bb25bf19b05945c9945dcf2526b1e33c to your computer and use it in GitHub Desktop.
Hyper Websockets
#[macro_use]
extern crate log;
use hyper::rt;
use hyper::server::Server;
use hyper::service::service_fn_ok;
use hyper::{Body, Request, Response};
use tokio::prelude::*;
mod ws;
/// Our server HTTP handler to initiate HTTP upgrades.
fn server_upgrade(req: Request<Body>) -> Response<Body> {
debug!("We got these headers: {:?}", req.headers());
let res = match ws::upgrade_connection(req) {
Err(r) => r,
Ok((r, ws_future)) => {
let ws_process = ws_future
.map_err(|err| error!("Cannot create websocket: {} ", err))
.and_then(|ws| {
let (tx, rc) = ws.split();
rc.map(|m| {
debug!("Got message {:?}", m);
m
})
.forward(tx)
.map(|_| debug!("Websocket has ended"))
.map_err(|err| error!("Socket error {}", err))
});
rt::spawn(ws_process);
r
}
};
res
}
/// Starts the WebSocket echo server on 127.0.0.1:5000 and blocks until it stops.
fn main() {
    pretty_env_logger::init();

    let listen_on: std::net::SocketAddr = ([127, 0, 0, 1], 5000).into();
    let make_service = || service_fn_ok(server_upgrade);

    let server_task = Server::bind(&listen_on)
        .serve(make_service)
        .map_err(|e| eprintln!("server error: {}", e));
    rt::run(server_task);
}
use futures::future::poll_fn;
use futures::prelude::*;
use headers::{self, HeaderMapExt};
use hyper::header::{self, AsHeaderName, HeaderMap, HeaderValue};
use hyper::{Body, Request, Response, StatusCode};
use quick_error::quick_error;
use std::fmt;
use std::io;
use tungstenite::protocol;
// Error type for the WebSocket wrapper.
//
// `Ws` and `Io` carry an explicit `display` so that formatting the error (e.g. in
// `error!("Socket error {}", err)`) shows the underlying cause instead of only the
// variant name, which is what quick_error falls back to without a `display` clause.
quick_error! {
    #[derive(Debug)]
    pub enum Error {
        // Failure reported by tungstenite while reading, writing or closing.
        Ws(err: tungstenite::Error) {
            from()
            display("WebSocket error: {}", err)
        }
        // Plain I/O failure.
        Io(err: io::Error) {
            from()
            display("I/O error: {}", err)
        }
        // Returned by `Message::to_str` when the message is not a Text message.
        InvalidMessageType {
            description("Message is of incorrect type (binary vs text)")
        }
    }
}
/// Returns `true` if header `name` is present, is valid visible-ASCII text, and equals
/// `value` ignoring ASCII case.
///
/// Uses `eq_ignore_ascii_case`, so no lowercase copy of the header is allocated and
/// `value` may be written in any case. A missing or non-text header yields `false`.
fn header_matches<S: AsHeaderName>(
    headers: &HeaderMap<HeaderValue>,
    name: S,
    value: &str,
) -> bool {
    headers
        .get(name)
        .and_then(|v| v.to_str().ok())
        .map_or(false, |v| v.eq_ignore_ascii_case(value))
}
pub fn upgrade_connection(
req: Request<Body>,
) -> Result<
(
Response<Body>,
impl Future<Item = WebSocket, Error = hyper::Error>+Send,
),
Response<Body>,
> {
let mut res = Response::new(Body::empty());
let mut header_error = false;
debug!("We got these headers: {:?}", req.headers());
if !header_matches(req.headers(), header::UPGRADE, "websocket") {
error!("Upgrade is not to websocket");
header_error = true;
}
if !header_matches(req.headers(), header::SEC_WEBSOCKET_VERSION, "13") {
error!("Websocket protocol version must be 13");
header_error = true;
}
if !req
.headers()
.typed_get::<headers::Connection>()
.map(|h| h.contains("Upgrade"))
.unwrap_or(false)
{
error!("It must be upgrade connection");
header_error = true;
}
let key = req.headers().typed_get::<headers::SecWebsocketKey>();
if key.is_none() {
error!("Websocket key missing");
header_error = true;
}
if header_error {
*res.status_mut() = StatusCode::BAD_REQUEST;
return Err(res);
}
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
let h = res.headers_mut();
h.typed_insert(headers::Upgrade::websocket());
h.typed_insert(headers::SecWebsocketAccept::from(key.unwrap()));
h.typed_insert(headers::Connection::upgrade());
let upgraded = req.into_body().on_upgrade().map(|upgraded| {
debug!("Connection upgraded to websocket");
WebSocket::new(upgraded)
});
Ok((res, upgraded))
}
/// A websocket `Stream` and `Sink`
pub struct WebSocket {
inner: protocol::WebSocket<::hyper::upgrade::Upgraded>,
}
impl WebSocket {
pub(crate) fn new(upgraded: hyper::upgrade::Upgraded) -> Self {
let inner = protocol::WebSocket::from_raw_socket(upgraded, protocol::Role::Server, None);
WebSocket { inner }
}
/// Gracefully close this websocket.
pub fn close(mut self) -> impl Future<Item = (), Error = Error> {
poll_fn(move || Sink::close(&mut self))
}
}
impl Stream for WebSocket {
type Item = Message;
type Error = Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop {
let msg = match self.inner.read_message() {
Ok(item) => item,
Err(::tungstenite::Error::Io(ref err))
if err.kind() == io::ErrorKind::WouldBlock =>
{
return Ok(Async::NotReady);
}
Err(::tungstenite::Error::ConnectionClosed(frame)) => {
trace!("websocket closed: {:?}", frame);
return Ok(Async::Ready(None));
}
Err(e) => {
debug!("websocket poll error: {}", e);
return Err(Error::Ws(e));
}
};
match msg {
msg @ protocol::Message::Text(..)
| msg @ protocol::Message::Binary(..)
| msg @ protocol::Message::Ping(..) => {
return Ok(Async::Ready(Some(Message { inner: msg })));
}
protocol::Message::Pong(payload) => {
trace!("websocket client pong: {:?}", payload);
}
}
}
}
}
impl Sink for WebSocket {
type SinkItem = Message;
type SinkError = Error;
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
match item.inner {
protocol::Message::Ping(..) => {
// warp doesn't yet expose a way to construct a `Ping` message,
// so the only way this could is if the user is forwarding the
// received `Ping`s straight back.
//
// tungstenite already auto-reponds to `Ping`s with a `Pong`,
// so this just prevents accidentally sending extra pings.
return Ok(AsyncSink::Ready);
}
_ => (),
}
match self.inner.write_message(item.inner) {
Ok(()) => Ok(AsyncSink::Ready),
Err(::tungstenite::Error::SendQueueFull(inner)) => {
debug!("websocket send queue full");
Ok(AsyncSink::NotReady(Message { inner }))
}
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => {
// the message was accepted and partly written, so this
// isn't an error.
Ok(AsyncSink::Ready)
}
Err(e) => {
debug!("websocket start_send error: {}", e);
Err(Error::Ws(e))
}
}
}
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
match self.inner.write_pending() {
Ok(()) => Ok(Async::Ready(())),
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => {
Ok(Async::NotReady)
}
Err(err) => {
debug!("websocket poll_complete error: {}", err);
Err(Error::Ws(err))
}
}
}
fn close(&mut self) -> Poll<(), Self::SinkError> {
match self.inner.close(None) {
Ok(()) => Ok(Async::Ready(())),
Err(::tungstenite::Error::Io(ref err)) if err.kind() == io::ErrorKind::WouldBlock => {
Ok(Async::NotReady)
}
Err(::tungstenite::Error::ConnectionClosed(frame)) => {
trace!("websocket closed: {:?}", frame);
return Ok(Async::Ready(()));
}
Err(err) => {
debug!("websocket close error: {}", err);
Err(Error::Ws(err))
}
}
}
}
/// Opaque `Debug` output: prints only the type name, never the wrapped connection.
impl fmt::Debug for WebSocket {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("WebSocket")
    }
}
/// A WebSocket message.
///
/// Wraps a tungstenite message. Values built with [`Message::text`] and
/// [`Message::binary`] are Text and Binary messages. Messages read from a
/// [`WebSocket`] stream may also be Ping messages; Pong frames are skipped by the stream
/// and never handed out.
#[derive(Eq, PartialEq, Clone)]
pub struct Message {
    inner: protocol::Message,
}
impl Message {
    /// Construct a new Text `Message`.
    pub fn text<S: Into<String>>(s: S) -> Message {
        Message {
            inner: protocol::Message::text(s),
        }
    }

    /// Construct a new Binary `Message`.
    pub fn binary<V: Into<Vec<u8>>>(v: V) -> Message {
        Message {
            inner: protocol::Message::binary(v),
        }
    }

    /// Returns true if this message is a Text message.
    pub fn is_text(&self) -> bool {
        self.inner.is_text()
    }

    /// Returns true if this message is a Binary message.
    pub fn is_binary(&self) -> bool {
        self.inner.is_binary()
    }

    /// Returns true if this message is a Ping message.
    pub fn is_ping(&self) -> bool {
        self.inner.is_ping()
    }

    /// Try to get a reference to the string text, if this is a Text message.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessageType` for any non-Text message.
    pub fn to_str(&self) -> Result<&str, Error> {
        match self.inner {
            protocol::Message::Text(ref s) => Ok(s),
            _ => Err(Error::InvalidMessageType),
        }
    }

    /// Return the payload bytes of this message.
    ///
    /// For Text messages this is the UTF-8 encoding of the text. For Binary, Ping and Pong
    /// messages it is the raw payload. Ping messages are yielded by the `WebSocket`
    /// stream, so they must be handled here rather than treated as unreachable.
    pub fn as_bytes(&self) -> &[u8] {
        match self.inner {
            protocol::Message::Text(ref s) => s.as_bytes(),
            protocol::Message::Binary(ref v)
            | protocol::Message::Ping(ref v)
            | protocol::Message::Pong(ref v) => v,
        }
    }
}
impl fmt::Debug for Message {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&self.inner, f)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment