-
-
Save iwinux/b2949d129c381c051a1dfad79f6497fb to your computer and use it in GitHub Desktop.
hyper-openssl + axum-server
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
env, io, | |
net::SocketAddr, | |
path::Path, | |
pin::Pin, | |
sync::Arc, | |
task::{Context, Poll}, | |
}; | |
use anyhow::Result; | |
use axum::{routing::get, Router}; | |
use axum_server::{accept::Accept, service::SendService, Server}; | |
use futures_util::{future::poll_fn, join, Future}; | |
use hyper::server::{ | |
accept::Accept as HyperAccept, | |
conn::{AddrIncoming, Http}, | |
}; | |
use log::{error, info}; | |
use openssl::ssl::{Ssl, SslAcceptor, SslContext, SslContextRef, SslFiletype, SslMethod}; | |
use pin_project_lite::pin_project; | |
use tokio::{ | |
io::{AsyncRead, AsyncWrite}, | |
task::spawn, | |
}; | |
use tokio_openssl::SslStream; | |
#[derive(Clone)] | |
struct TLSAcceptor { | |
ctx: Arc<SslContext>, | |
} | |
impl TLSAcceptor { | |
fn new<P: AsRef<Path>>(cert_path: P, key_path: P) -> Result<Self> { | |
let mut builder = SslAcceptor::mozilla_modern(SslMethod::tls_server())?; | |
builder.set_certificate_file(&cert_path, SslFiletype::PEM)?; | |
builder.set_private_key_file(&key_path, SslFiletype::PEM)?; | |
let ctx = Arc::new(builder.build().into_context()); | |
Ok(Self { ctx }) | |
} | |
} | |
async fn accept_stream<S>(ctx: &SslContextRef, stream: S) -> Result<SslStream<S>> | |
where | |
S: AsyncRead + AsyncWrite + Unpin, | |
{ | |
let session = Ssl::new(&ctx)?; | |
let mut stream = SslStream::new(session, stream)?; | |
Pin::new(&mut stream).accept().await?; | |
Ok(stream) | |
} | |
impl<IO, Service> Accept<IO, Service> for TLSAcceptor | |
where | |
IO: AsyncRead + AsyncWrite + Unpin, | |
{ | |
type Stream = SslStream<IO>; | |
type Service = Service; | |
type Future = AcceptFuture<IO, Service>; | |
fn accept(&self, stream: IO, service: Service) -> Self::Future { | |
AcceptFuture { | |
tls_ctx: self.ctx.clone(), | |
tcp_stream: Some((stream, service)), | |
} | |
} | |
} | |
pin_project! { | |
pub struct AcceptFuture<IO, Service> { | |
tls_ctx: Arc<SslContext>, | |
tcp_stream: Option<(IO, Service)>, | |
} | |
} | |
impl<IO, Service> Future for AcceptFuture<IO, Service> | |
where | |
IO: AsyncRead + AsyncWrite + Unpin, | |
{ | |
type Output = io::Result<(SslStream<IO>, Service)>; | |
fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> { | |
let pin = self.project(); | |
let (tcp_stream, service) = pin.tcp_stream.take().expect("must not be None"); | |
let mut tls_future = Box::pin(accept_stream(&pin.tls_ctx, tcp_stream)); | |
Pin::new(&mut tls_future) | |
.poll(ctx) | |
.map_ok(|stream| (stream, service)) | |
.map_err(|err| io::Error::new(io::ErrorKind::Other, format!("TLS Error: {:?}", err))) | |
} | |
} | |
fn dummy_app() -> Router { | |
Router::new().route("/", get(|| async { "OK." })) | |
} | |
async fn serve_tls_loop(addr: SocketAddr, acceptor: TLSAcceptor) -> Result<()> { | |
let mut listener = AddrIncoming::bind(&addr)?; | |
let http = Http::new(); | |
let service = dummy_app().into_service(); | |
loop { | |
let addr_stream = poll_fn(|ctx| Pin::new(&mut listener).poll_accept(ctx)) | |
.await | |
.unwrap()?; | |
let tls_ctx = acceptor.ctx.clone(); | |
let http = http.clone(); | |
let service = service.clone(); | |
spawn(async move { | |
let tls_stream = match accept_stream(&tls_ctx, addr_stream).await { | |
Ok(tls_stream) => tls_stream, | |
Err(err) => { | |
error!("failed to create TLS stream: {:?}", err); | |
return; | |
} | |
}; | |
if let Err(err) = http.serve_connection(tls_stream, service).await { | |
error!("failed to handle HTTP request: {:?}", err); | |
} | |
}); | |
} | |
} | |
fn serve_tls_axum_server( | |
addr: SocketAddr, | |
acceptor: TLSAcceptor, | |
) -> impl Future<Output = Result<(), io::Error>> { | |
let service = dummy_app().into_make_service(); | |
Server::bind(addr).acceptor(acceptor).serve(service) | |
} | |
#[tokio::main] | |
async fn main() -> Result<()> { | |
if env::var_os("RUST_LOG").is_none() { | |
env::set_var("RUST_LOG", "INFO") | |
} | |
pretty_env_logger::init_timed(); | |
let cert_path = env::var("SERVER_CERT").unwrap_or_else(|_| "server.crt".into()); | |
let key_path = env::var("SERVER_KEY").unwrap_or_else(|_| "server.key".into()); | |
let acceptor = TLSAcceptor::new(cert_path, key_path)?; | |
// `curl -k https://127.0.0.1:8443` works | |
let addr1: SocketAddr = "127.0.0.1:8443".parse()?; | |
let join_handle1 = spawn(serve_tls_loop(addr1, acceptor.clone())); | |
// `curl -k https://127.0.0.1:9443` fails with error "OpenSSL SSL_connect: SSL_ERROR_SYSCALL" | |
// tcpdump shows that the server sends `FIN` before client sends `TLS Client Hello` | |
let addr2: SocketAddr = "127.0.0.1:9443".parse()?; | |
let join_handle2 = spawn(serve_tls_axum_server(addr2, acceptor)); | |
let _ = join!(join_handle1, join_handle2); | |
info!("server has been shut down"); | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment