Skip to content

Instantly share code, notes, and snippets.

@iwinux

iwinux/main.rs Secret

Last active Dec 17, 2021
Embed
What would you like to do?
hyper-openssl + axum-server
use std::{
env, io,
net::SocketAddr,
path::Path,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use anyhow::Result;
use axum::{routing::get, Router};
use axum_server::{accept::Accept, service::SendService, Server};
use futures_util::{future::poll_fn, join, Future};
use hyper::server::{
accept::Accept as HyperAccept,
conn::{AddrIncoming, Http},
};
use log::{error, info};
use openssl::ssl::{Ssl, SslAcceptor, SslContext, SslContextRef, SslFiletype, SslMethod};
use pin_project_lite::pin_project;
use tokio::{
io::{AsyncRead, AsyncWrite},
task::spawn,
};
use tokio_openssl::SslStream;
/// TLS acceptor backed by a shared OpenSSL server context.
///
/// Cloning is cheap: only the `Arc` around the context is cloned, so every
/// clone hands out sessions from the same `SslContext`.
#[derive(Clone)]
struct TLSAcceptor {
// Server-side OpenSSL context (certificate + private key) that each new
// TLS session is created from.
ctx: Arc<SslContext>,
}
impl TLSAcceptor {
    /// Creates an acceptor from a PEM certificate file and a PEM private-key file.
    ///
    /// The underlying OpenSSL acceptor is configured through
    /// `SslAcceptor::mozilla_modern` for the `tls_server` method.
    ///
    /// # Errors
    ///
    /// Returns an error if the OpenSSL builder cannot be created or if either
    /// file is rejected by OpenSSL.
    fn new<P: AsRef<Path>>(cert_path: P, key_path: P) -> Result<Self> {
        let mut acceptor_builder = SslAcceptor::mozilla_modern(SslMethod::tls_server())?;
        acceptor_builder.set_certificate_file(cert_path.as_ref(), SslFiletype::PEM)?;
        acceptor_builder.set_private_key_file(key_path.as_ref(), SslFiletype::PEM)?;

        // Only the context is kept; it is shared between all clones of the acceptor.
        let shared_ctx = acceptor_builder.build().into_context();
        Ok(Self {
            ctx: Arc::new(shared_ctx),
        })
    }
}
/// Performs the server side of a TLS handshake on `stream`.
///
/// A fresh `Ssl` session is created from `ctx`, wrapped around `stream`, and
/// the handshake is awaited to completion before the stream is returned.
///
/// # Errors
///
/// Returns an error if the session or stream cannot be created, or if the
/// handshake itself fails.
async fn accept_stream<S>(ctx: &SslContextRef, stream: S) -> Result<SslStream<S>>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let ssl = Ssl::new(ctx)?;
    let mut tls = SslStream::new(ssl, stream)?;

    // `accept` needs a pinned receiver; `SslStream<S>` is `Unpin` here, so
    // pinning a plain mutable borrow is enough.
    let handshake = Pin::new(&mut tls).accept();
    handshake.await?;

    Ok(tls)
}
impl<IO, Service> Accept<IO, Service> for TLSAcceptor
where
IO: AsyncRead + AsyncWrite + Unpin,
{
type Stream = SslStream<IO>;
type Service = Service;
type Future = AcceptFuture<IO, Service>;
fn accept(&self, stream: IO, service: Service) -> Self::Future {
AcceptFuture {
tls_ctx: self.ctx.clone(),
tcp_stream: Some((stream, service)),
}
}
}
pin_project! {
pub struct AcceptFuture<IO, Service> {
tls_ctx: Arc<SslContext>,
tcp_stream: Option<(IO, Service)>,
}
}
impl<IO, Service> Future for AcceptFuture<IO, Service>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
type Output = io::Result<(SslStream<IO>, Service)>;
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
let pin = self.project();
let (tcp_stream, service) = pin.tcp_stream.take().expect("must not be None");
let mut tls_future = Box::pin(accept_stream(&pin.tls_ctx, tcp_stream));
Pin::new(&mut tls_future)
.poll(ctx)
.map_ok(|stream| (stream, service))
.map_err(|err| io::Error::new(io::ErrorKind::Other, format!("TLS Error: {:?}", err)))
}
}
/// Builds the demo application: a router whose only route, `GET /`,
/// answers with the plain-text body `OK.`.
fn dummy_app() -> Router {
    async fn ok() -> &'static str {
        "OK."
    }

    Router::new().route("/", get(ok))
}
async fn serve_tls_loop(addr: SocketAddr, acceptor: TLSAcceptor) -> Result<()> {
let mut listener = AddrIncoming::bind(&addr)?;
let http = Http::new();
let service = dummy_app().into_service();
loop {
let addr_stream = poll_fn(|ctx| Pin::new(&mut listener).poll_accept(ctx))
.await
.unwrap()?;
let tls_ctx = acceptor.ctx.clone();
let http = http.clone();
let service = service.clone();
spawn(async move {
let tls_stream = match accept_stream(&tls_ctx, addr_stream).await {
Ok(tls_stream) => tls_stream,
Err(err) => {
error!("failed to create TLS stream: {:?}", err);
return;
}
};
if let Err(err) = http.serve_connection(tls_stream, service).await {
error!("failed to handle HTTP request: {:?}", err);
}
});
}
}
/// Serves the demo app through `axum_server::Server`, using `acceptor` to
/// upgrade connections to TLS.
///
/// Returns the server future; nothing runs until that future is polled.
fn serve_tls_axum_server(
    addr: SocketAddr,
    acceptor: TLSAcceptor,
) -> impl Future<Output = Result<(), io::Error>> {
    let tls_server = Server::bind(addr).acceptor(acceptor);
    tls_server.serve(dummy_app().into_make_service())
}
#[tokio::main]
async fn main() -> Result<()> {
if env::var_os("RUST_LOG").is_none() {
env::set_var("RUST_LOG", "INFO")
}
pretty_env_logger::init_timed();
let cert_path = env::var("SERVER_CERT").unwrap_or_else(|_| "server.crt".into());
let key_path = env::var("SERVER_KEY").unwrap_or_else(|_| "server.key".into());
let acceptor = TLSAcceptor::new(cert_path, key_path)?;
// `curl -k https://127.0.0.1:8443` works
let addr1: SocketAddr = "127.0.0.1:8443".parse()?;
let join_handle1 = spawn(serve_tls_loop(addr1, acceptor.clone()));
// `curl -k https://127.0.0.1:9443` fails with error "OpenSSL SSL_connect: SSL_ERROR_SYSCALL"
// tcpdump shows that the server sends `FIN` before client sends `TLS Client Hello`
let addr2: SocketAddr = "127.0.0.1:9443".parse()?;
let join_handle2 = spawn(serve_tls_axum_server(addr2, acceptor));
let _ = join!(join_handle1, join_handle2);
info!("server has been shut down");
Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment