Skip to content

Instantly share code, notes, and snippets.

@e-nomem
Last active June 4, 2022 22:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save e-nomem/5b267f98ade768bcb24034615f81c8d6 to your computer and use it in GitHub Desktop.
Save e-nomem/5b267f98ade768bcb24034615f81c8d6 to your computer and use it in GitHub Desktop.
Hyper Acceptor
use std::fmt::Debug;
use std::io;
use std::marker::PhantomData;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use futures::ready;
use futures::Stream;
use pin_project_lite::pin_project;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
#[cfg(unix)]
use tokio::net::unix::SocketAddr as UnixSocketAddr;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixListener;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::sync::OwnedSemaphorePermit;
use tokio::sync::Semaphore;
use tokio_util::sync::PollSemaphore;
/// A listener that can be polled for incoming connections.
///
/// This abstracts over the concrete listener types so that the accept loop can
/// be written once for every transport that implements it.
pub trait AcceptableListener {
    /// Connection type produced for each accepted peer.
    type Stream: AsyncRead + AsyncWrite;

    /// Address type reported alongside each accepted connection.
    type Addr;

    /// Polls for the next incoming connection.
    ///
    /// Returns `Poll::Ready` with the accepted stream and its peer address, or
    /// with the I/O error reported by the listener; returns `Poll::Pending`
    /// when no connection is available yet.
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<(Self::Stream, Self::Addr)>>;
}
impl AcceptableListener for TcpListener {
type Stream = TcpStream;
type Addr = SocketAddr;
fn poll_accept(&self, cx: &'_ mut Context) -> Poll<io::Result<(Self::Stream, Self::Addr)>> {
self.poll_accept(cx)
}
}
#[cfg(unix)]
impl AcceptableListener for UnixListener {
type Stream = UnixStream;
type Addr = UnixSocketAddr;
fn poll_accept(&self, cx: &'_ mut Context) -> Poll<io::Result<(Self::Stream, Self::Addr)>> {
self.poll_accept(cx)
}
}
/// Conversion into a value that implements [`AcceptableListener`].
///
/// The conversion is fallible because producing a listener may involve I/O.
pub trait IntoAcceptableListener {
/// The listener type produced by the conversion.
type Listener: AcceptableListener;
/// Consumes `self` and returns the listener, or the I/O error that
/// prevented it from being created.
fn into_listener(self) -> io::Result<Self::Listener>;
}
/// Anything that already is an [`AcceptableListener`] converts into itself;
/// this conversion never fails.
impl<T: AcceptableListener> IntoAcceptableListener for T {
    type Listener = T;

    fn into_listener(self) -> io::Result<Self::Listener> {
        io::Result::Ok(self)
    }
}
/// Adopts an already-bound standard-library TCP listener.
impl IntoAcceptableListener for std::net::TcpListener {
    type Listener = TcpListener;

    fn into_listener(self) -> io::Result<Self::Listener> {
        // The socket is switched to non-blocking mode before it is handed to
        // tokio; a failure to do so is returned without converting.
        self.set_nonblocking(true)
            .and_then(|()| TcpListener::from_std(self))
    }
}
/// Adopts an already-bound standard-library Unix-domain listener.
#[cfg(unix)]
impl IntoAcceptableListener for std::os::unix::net::UnixListener {
    type Listener = UnixListener;

    fn into_listener(self) -> io::Result<Self::Listener> {
        // The socket is switched to non-blocking mode before it is handed to
        // tokio; a failure to do so is returned without converting.
        self.set_nonblocking(true)
            .and_then(|()| UnixListener::from_std(self))
    }
}
/// Binds a new TCP listener to the given socket address.
impl IntoAcceptableListener for SocketAddr {
    type Listener = TcpListener;

    fn into_listener(self) -> io::Result<Self::Listener> {
        // Bind with the standard library first, then reuse the conversion
        // defined for `std::net::TcpListener`.
        std::net::TcpListener::bind(self).and_then(|bound| bound.into_listener())
    }
}
/// Binds a new Unix-domain listener to the given socket address.
#[cfg(unix)]
impl IntoAcceptableListener for std::os::unix::net::SocketAddr {
    type Listener = UnixListener;

    fn into_listener(self) -> io::Result<Self::Listener> {
        // Bind with the standard library first, then reuse the conversion
        // defined for `std::os::unix::net::UnixListener`.
        std::os::unix::net::UnixListener::bind_addr(&self).and_then(|bound| bound.into_listener())
    }
}
pin_project! {
    /// A connection bundled with the semaphore permit (if any) that was taken
    /// out when it was accepted.
    ///
    /// The permit is stored alongside the connection, so it is dropped
    /// together with this wrapper.
    pub struct PermittedConnection<C> {
        // The wrapped connection; all I/O is forwarded to it.
        #[pin]
        pub(crate) conn: C,
        // Never read: held only so the permit lives as long as the connection.
        pub(crate) _permit: Option<OwnedSemaphorePermit>,
    }
}
/// Reads are forwarded unchanged to the wrapped connection.
impl<C: AsyncRead> AsyncRead for PermittedConnection<C> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        self.project().conn.poll_read(cx, buf)
    }
}
/// Writes are forwarded unchanged to the wrapped connection.
///
/// The vectored-write methods are forwarded explicitly as well. Without that,
/// the trait's default implementations would be used and callers that check
/// `is_write_vectored` would never see the inner connection's vectored-write
/// support.
impl<C> AsyncWrite for PermittedConnection<C>
where
    C: AsyncWrite,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        self.project().conn.poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.project().conn.poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.project().conn.poll_shutdown(cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        self.project().conn.poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.conn.is_write_vectored()
    }
}
/// A stream of accepted connections produced by polling a listener.
///
/// `L` is the listener, `F` the mapper applied to every accepted connection,
/// and `C` the connection type the mapper returns.
pub struct AcceptStream<L, F, C> {
/// The listener that is polled for new connections.
listener: L,
/// Optional semaphore; when present, a permit is acquired from it before a
/// connection is accepted.
semaphore: Option<PollSemaphore>,
/// Called with each accepted stream and address. Returning `None` skips that
/// connection and the stream goes on to accept the next one.
mapper: Pin<Box<F>>,
/// Marker that ties the `C` type parameter to the struct without storing a `C`.
_mapper: PhantomData<fn() -> C>,
}
impl AcceptStream<(), (), ()> {
pub fn builder() -> AcceptStreamBuilder {
Default::default()
}
}
// `AcceptStream` never relies on any of its fields being pinned in place: the
// listener is only used through `&L`, the mapper lives in its own
// `Pin<Box<F>>`, and the semaphore is `Unpin`. Declaring the type `Unpin`
// lets `poll_next` obtain `&mut self` safely, which it needs in order to poll
// the stored semaphore in place.
impl<L, F, C> Unpin for AcceptStream<L, F, C> {}

/// Yields one [`PermittedConnection`] per accepted connection that the mapper
/// lets through.
///
/// The stream ends (`None`) when the semaphore has been closed, and yields an
/// `Err` item when the listener reports an I/O error.
impl<L, F, C> Stream for AcceptStream<L, F, C>
where
    L: AcceptableListener,
    C: AsyncRead + AsyncWrite,
    F: Fn(L::Stream, L::Addr) -> Option<C>,
{
    type Item = io::Result<Pin<Box<PermittedConnection<C>>>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // Poll the *stored* semaphore. A `PollSemaphore` keeps its pending
        // acquire future (and with it the registered waker) between polls;
        // polling a temporary clone would drop that future on return, so the
        // task would not be woken when a permit is released.
        let permit = match this.semaphore.as_mut() {
            Some(semaphore) => match ready!(semaphore.poll_acquire(cx)) {
                Some(permit) => Some(permit),
                // Semaphore is closed, end the stream.
                None => return Poll::Ready(None),
            },
            None => None,
        };

        // If the listener is not ready, `ready!` returns `Pending` and the
        // permit acquired above is dropped (released); it is acquired again
        // on the next poll.
        loop {
            let (stream, addr) = ready!(this.listener.poll_accept(cx))?;
            // A mapper result of `None` rejects the connection; keep the
            // permit and try the next incoming connection.
            if let Some(conn) = this.mapper.as_ref()(stream, addr) {
                break Poll::Ready(Some(Ok(Box::pin(PermittedConnection {
                    conn,
                    _permit: permit,
                }))));
            }
        }
    }
}
/// Builder for [`AcceptStream`].
///
/// By default no connection limit is configured.
#[derive(Debug, Default)]
pub struct AcceptStreamBuilder {
    // `None` means no semaphore is created and connections are not limited.
    max_connections: Option<usize>,
}

impl AcceptStreamBuilder {
    /// Creates a builder with default settings (no connection limit).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds an [`AcceptStream`] that yields the listener's own stream type
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced while turning `addr` into a listener.
    pub fn build<A>(
        self,
        addr: A,
    ) -> io::Result<
        AcceptStream<
            A::Listener,
            impl Fn(
                <A::Listener as AcceptableListener>::Stream,
                <A::Listener as AcceptableListener>::Addr,
            ) -> Option<<A::Listener as AcceptableListener>::Stream>,
            <A::Listener as AcceptableListener>::Stream,
        >,
    >
    where
        A: IntoAcceptableListener,
    {
        // Identity mapper: accept every connection and ignore the address.
        self.build_with_mapper(addr, |stream, _peer| Some(stream))
    }

    /// Builds an [`AcceptStream`] whose connections are passed through
    /// `mapper` before being yielded.
    ///
    /// `mapper` receives each accepted stream together with its address and
    /// returns the connection to yield, or `None` to skip it.
    ///
    /// # Errors
    ///
    /// Returns the I/O error produced while turning `addr` into a listener.
    pub fn build_with_mapper<A, F, C>(
        self,
        addr: A,
        mapper: F,
    ) -> io::Result<AcceptStream<A::Listener, F, C>>
    where
        A: IntoAcceptableListener,
        C: AsyncRead + AsyncWrite,
        F: Fn(
            <A::Listener as AcceptableListener>::Stream,
            <A::Listener as AcceptableListener>::Addr,
        ) -> Option<C>,
    {
        let listener = addr.into_listener()?;

        // A semaphore is only created when a connection limit was configured.
        let semaphore = match self.max_connections {
            Some(limit) => Some(PollSemaphore::new(Arc::new(Semaphore::new(limit)))),
            None => None,
        };

        Ok(AcceptStream {
            listener,
            semaphore,
            mapper: Box::pin(mapper),
            _mapper: PhantomData,
        })
    }

    /// Sets the connection limit; `None` removes any limit.
    pub fn with_max_connections(mut self, limit: Option<usize>) -> Self {
        self.max_connections = limit;
        self
    }
}
use std::net::SocketAddr;
use std::time::Duration;
use axum::routing::on;
use axum::routing::MethodFilter;
use axum::Router;
use color_eyre::Result;
use hyper::server::accept;
use hyper::Server;
use tokio_io_timeout::TimeoutStream;
use tracing::debug;
use server::accept_stream::AcceptStream;
/// Serves the application on `127.0.0.1:8080`.
///
/// Connections are accepted through an [`AcceptStream`] configured with a
/// limit of 512 connections, and every accepted stream is wrapped in a
/// [`TimeoutStream`] with 5-second read and write timeouts.
async fn run_server() -> Result<()> {
    // A Unix-domain socket can be used instead, e.g. with
    // `std::os::unix::net::SocketAddr::from_pathname("./test.sock")?`.
    let bind_addr = SocketAddr::new("127.0.0.1".parse()?, 8080);
    let io_timeout = Duration::from_secs(5);

    let with_timeouts = move |conn, peer: SocketAddr| {
        let mut timed = TimeoutStream::new(conn);
        debug!(remote_addr = ?peer, "applying timeouts to connection stream");
        timed.set_read_timeout(Some(io_timeout));
        timed.set_write_timeout(Some(io_timeout));
        Some(timed)
    };

    let incoming = AcceptStream::builder()
        .with_max_connections(Some(512))
        .build_with_mapper(bind_addr, with_timeouts)?;

    let server = Server::builder(accept::from_stream(incoming)).serve(routes().into_make_service());
    server.await?;
    Ok(())
}
/// Builds the application router: `GET /` is handled by [`root`].
fn routes() -> Router {
    let index = on(MethodFilter::GET, root);
    Router::new().route("/", index)
}
/// Request handler that responds with a fixed greeting.
async fn root() -> &'static str {
"Hello, World!"
}
/// Program entry point: runs the server and returns its result.
#[tokio::main]
async fn main() -> Result<()> {
run_server().await
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment