Skip to content

Instantly share code, notes, and snippets.

@dacut
Created July 5, 2020 23:38
Show Gist options
  • Save dacut/13929d70eac74aa0f7f66331f38b4daa to your computer and use it in GitHub Desktop.
Save dacut/13929d70eac74aa0f7f66331f38b4daa to your computer and use it in GitHub Desktop.
TlsIncoming implementation for Hyper
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use hyper::server::accept::{Accept as HyperAccept};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::{Accept, TlsAcceptor};
use tokio_rustls::server::TlsStream;
pub struct TlsIncoming<'a> {
listener: &'a mut TcpListener,
acceptor: &'a mut TlsAcceptor,
tls_stream_accept: Option<Pin<Box<Accept<TcpStream>>>>,
}
impl <'a> TlsIncoming<'a> {
pub fn new(listener: &'a mut TcpListener, acceptor: &'a mut TlsAcceptor) -> TlsIncoming<'a> {
TlsIncoming { listener: listener, acceptor: acceptor, tls_stream_accept: None }
}
}
impl <'a> HyperAccept for TlsIncoming<'a> {
type Conn = TlsStream<TcpStream>;
type Error = io::Error;
/// Attempts to poll `TcpStream` by polling inner `TcpListener` to accept
/// connection.
///
/// If `TcpListener` isn't ready yet, `Poll::Pending` is returned and
/// current task will be notified by a waker.
fn poll_accept(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<io::Result<TlsStream<TcpStream>>>> {
if self.tls_stream_accept.is_none() {
// Need to poll the TCP listener
self.tls_stream_accept = match self.listener.poll_accept(cx) {
Poll::Ready(t) => match t {
Ok((tcp_stream, _)) => Some(Box::pin(self.acceptor.accept(tcp_stream))),
Err(e) => return Poll::Ready(Some(Err(e))),
},
Poll::Pending => return Poll::Pending,
};
};
// If we reach here, tls_stream_accept is guaranteed to be Some(...).
let accept: &mut Pin<Box<Accept<TcpStream>>> = self.tls_stream_accept.as_mut().unwrap();
match accept.as_mut().poll(cx) {
Poll::Ready(t) => Poll::Ready(Some(t)),
Poll::Pending => Poll::Pending,
}
}
}
@dacut
Copy link
Author

dacut commented Jul 5, 2020

Then using this:

use hyper::Server;
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
mod tls;
use crate::tls::TlsIncoming;

fn server() {
    let tls_config = ...;
    let mut tcp_listener = TcpListener::bind(...)).unwrap()
    let mut tls_acceptor = TlsAcceptor::from(tls_config);
    let incoming = TlsIncoming::new(&mut tcp_listener, &mut tls_acceptor);
    Server::builder(incoming).serve(...);
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment