Last active
June 4, 2022 22:54
-
-
Save e-nomem/5b267f98ade768bcb24034615f81c8d6 to your computer and use it in GitHub Desktop.
Hyper Acceptor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::fmt::Debug; | |
use std::io; | |
use std::marker::PhantomData; | |
use std::net::SocketAddr; | |
use std::pin::Pin; | |
use std::sync::Arc; | |
use std::task::Context; | |
use std::task::Poll; | |
use futures::ready; | |
use futures::Stream; | |
use pin_project_lite::pin_project; | |
use tokio::io::AsyncRead; | |
use tokio::io::AsyncWrite; | |
#[cfg(unix)] | |
use tokio::net::unix::SocketAddr as UnixSocketAddr; | |
use tokio::net::TcpListener; | |
use tokio::net::TcpStream; | |
#[cfg(unix)] | |
use tokio::net::UnixListener; | |
#[cfg(unix)] | |
use tokio::net::UnixStream; | |
use tokio::sync::OwnedSemaphorePermit; | |
use tokio::sync::Semaphore; | |
use tokio_util::sync::PollSemaphore; | |
pub trait AcceptableListener { | |
type Stream: AsyncRead + AsyncWrite; | |
type Addr; | |
fn poll_accept(&self, cx: &'_ mut Context) -> Poll<io::Result<(Self::Stream, Self::Addr)>>; | |
} | |
impl AcceptableListener for TcpListener { | |
type Stream = TcpStream; | |
type Addr = SocketAddr; | |
fn poll_accept(&self, cx: &'_ mut Context) -> Poll<io::Result<(Self::Stream, Self::Addr)>> { | |
self.poll_accept(cx) | |
} | |
} | |
#[cfg(unix)] | |
impl AcceptableListener for UnixListener { | |
type Stream = UnixStream; | |
type Addr = UnixSocketAddr; | |
fn poll_accept(&self, cx: &'_ mut Context) -> Poll<io::Result<(Self::Stream, Self::Addr)>> { | |
self.poll_accept(cx) | |
} | |
} | |
pub trait IntoAcceptableListener { | |
type Listener: AcceptableListener; | |
fn into_listener(self) -> io::Result<Self::Listener>; | |
} | |
impl<T> IntoAcceptableListener for T | |
where | |
T: AcceptableListener, | |
{ | |
type Listener = Self; | |
fn into_listener(self) -> io::Result<Self::Listener> { | |
Ok(self) | |
} | |
} | |
impl IntoAcceptableListener for std::net::TcpListener { | |
type Listener = TcpListener; | |
fn into_listener(self) -> io::Result<Self::Listener> { | |
self.set_nonblocking(true)?; | |
TcpListener::from_std(self) | |
} | |
} | |
#[cfg(unix)] | |
impl IntoAcceptableListener for std::os::unix::net::UnixListener { | |
type Listener = UnixListener; | |
fn into_listener(self) -> io::Result<Self::Listener> { | |
self.set_nonblocking(true)?; | |
UnixListener::from_std(self) | |
} | |
} | |
impl IntoAcceptableListener for SocketAddr { | |
type Listener = TcpListener; | |
fn into_listener(self) -> io::Result<Self::Listener> { | |
let listener = std::net::TcpListener::bind(self)?; | |
listener.into_listener() | |
} | |
} | |
#[cfg(unix)] | |
impl IntoAcceptableListener for std::os::unix::net::SocketAddr { | |
type Listener = UnixListener; | |
fn into_listener(self) -> io::Result<Self::Listener> { | |
let listener = std::os::unix::net::UnixListener::bind_addr(&self)?; | |
listener.into_listener() | |
} | |
} | |
pin_project!( | |
pub struct PermittedConnection<C> { | |
#[pin] | |
pub(crate) conn: C, | |
pub(crate) _permit: Option<OwnedSemaphorePermit>, | |
} | |
); | |
impl<C> AsyncRead for PermittedConnection<C> | |
where | |
C: AsyncRead, | |
{ | |
fn poll_read( | |
self: Pin<&mut Self>, | |
cx: &mut Context<'_>, | |
buf: &mut tokio::io::ReadBuf<'_>, | |
) -> Poll<io::Result<()>> { | |
let this = self.project(); | |
this.conn.poll_read(cx, buf) | |
} | |
} | |
impl<C> AsyncWrite for PermittedConnection<C> | |
where | |
C: AsyncWrite, | |
{ | |
fn poll_write( | |
self: Pin<&mut Self>, | |
cx: &mut Context<'_>, | |
buf: &[u8], | |
) -> Poll<Result<usize, io::Error>> { | |
let this = self.project(); | |
this.conn.poll_write(cx, buf) | |
} | |
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { | |
let this = self.project(); | |
this.conn.poll_flush(cx) | |
} | |
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { | |
let this = self.project(); | |
this.conn.poll_shutdown(cx) | |
} | |
} | |
pub struct AcceptStream<L, F, C> { | |
listener: L, | |
semaphore: Option<PollSemaphore>, | |
mapper: Pin<Box<F>>, | |
_mapper: PhantomData<fn() -> C>, | |
} | |
impl AcceptStream<(), (), ()> { | |
pub fn builder() -> AcceptStreamBuilder { | |
Default::default() | |
} | |
} | |
impl<L, F, C> Stream for AcceptStream<L, F, C> | |
where | |
L: AcceptableListener, | |
C: AsyncRead + AsyncWrite, | |
F: Fn(L::Stream, L::Addr) -> Option<C>, | |
{ | |
type Item = io::Result<Pin<Box<PermittedConnection<C>>>>; | |
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { | |
let permit = match self.semaphore.clone() { | |
Some(mut semaphore) => { | |
match ready!(semaphore.poll_acquire(cx)) { | |
Some(p) => Some(p), | |
None => return Poll::Ready(None), // Semaphore is closed, end the stream | |
} | |
} | |
None => None, | |
}; | |
loop { | |
let (stream, addr) = ready!(self.listener.poll_accept(cx))?; | |
if let Some(conn) = self.mapper.as_ref()(stream, addr) { | |
break Poll::Ready(Some(Ok(Box::pin(PermittedConnection { | |
conn, | |
_permit: permit, | |
})))); | |
} | |
} | |
} | |
} | |
#[derive(Debug, Default)] | |
pub struct AcceptStreamBuilder { | |
max_connections: Option<usize>, | |
} | |
impl AcceptStreamBuilder { | |
pub fn new() -> Self { | |
Default::default() | |
} | |
pub fn build<A>( | |
self, | |
addr: A, | |
) -> io::Result< | |
AcceptStream< | |
A::Listener, | |
impl Fn( | |
<A::Listener as AcceptableListener>::Stream, | |
<A::Listener as AcceptableListener>::Addr, | |
) -> Option<<A::Listener as AcceptableListener>::Stream>, | |
<A::Listener as AcceptableListener>::Stream, | |
>, | |
> | |
where | |
A: IntoAcceptableListener, | |
{ | |
self.build_with_mapper(addr, |s, _| Some(s)) | |
} | |
pub fn build_with_mapper<A, F, C>( | |
self, | |
addr: A, | |
mapper: F, | |
) -> io::Result<AcceptStream<A::Listener, F, C>> | |
where | |
A: IntoAcceptableListener, | |
C: AsyncRead + AsyncWrite, | |
F: Fn( | |
<A::Listener as AcceptableListener>::Stream, | |
<A::Listener as AcceptableListener>::Addr, | |
) -> Option<C>, | |
{ | |
let listener = addr.into_listener()?; | |
let semaphore = self.max_connections.map(|limit| { | |
let s = Arc::new(Semaphore::new(limit)); | |
PollSemaphore::new(s) | |
}); | |
Ok(AcceptStream { | |
listener, | |
semaphore, | |
mapper: Box::pin(mapper), | |
_mapper: PhantomData, | |
}) | |
} | |
pub fn with_max_connections(mut self, limit: Option<usize>) -> Self { | |
self.max_connections = limit; | |
self | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::net::SocketAddr; | |
use std::time::Duration; | |
use axum::routing::on; | |
use axum::routing::MethodFilter; | |
use axum::Router; | |
use color_eyre::Result; | |
use hyper::server::accept; | |
use hyper::Server; | |
use tokio_io_timeout::TimeoutStream; | |
use tracing::debug; | |
use server::accept_stream::AcceptStream; | |
async fn run_server() -> Result<()> { | |
// let addr = std::os::unix::net::SocketAddr::from_pathname("./test.sock")?; | |
let addr: SocketAddr = SocketAddr::new("127.0.0.1".parse()?, 8080); | |
let mapper = |s, a: SocketAddr| { | |
let mut stream = TimeoutStream::new(s); | |
debug!(remote_addr = ?a, "applying timeouts to connection stream"); | |
stream.set_read_timeout(Some(Duration::from_secs(5))); | |
stream.set_write_timeout(Some(Duration::from_secs(5))); | |
Some(stream) | |
}; | |
let acceptor = AcceptStream::builder() | |
.with_max_connections(Some(512)) | |
.build_with_mapper(addr, mapper)?; | |
Server::builder(accept::from_stream(acceptor)) | |
.serve(routes().into_make_service()) | |
.await?; | |
Ok(()) | |
} | |
fn routes() -> Router { | |
Router::new().route("/", on(MethodFilter::GET, root)) | |
} | |
/// Handler for `GET /`: always responds with a fixed plain-text greeting.
async fn root() -> &'static str {
    const GREETING: &str = "Hello, World!";
    GREETING
}
#[tokio::main] | |
async fn main() -> Result<()> { | |
run_server().await | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment