Skip to content

Instantly share code, notes, and snippets.

@moshec2
Created March 23, 2023 11:51
Show Gist options
  • Save moshec2/8b1c99396f61e4b132412468fdb5e2c2 to your computer and use it in GitHub Desktop.
Save moshec2/8b1c99396f61e4b132412468fdb5e2c2 to your computer and use it in GitHub Desktop.
//! Simple HTTPS echo service based on hyper-rustls
//!
//! First parameter is the optional port to use (defaults to 1337).
//! Certificate and private key are hardcoded to sample files.
//! hyper will automatically use HTTP/2 if a client starts talking HTTP/2,
//! otherwise HTTP/1.1 will be used.
use core::task::{Context, Poll};
use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode};
use rustls::server::AllowAnyAuthenticatedClient;
use rustls::{RootCertStore, ServerConnection};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::vec::Vec;
use std::{env, fs, io, sync};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::ServerConfig;
/// Program entry point: runs the HTTPS echo server and converts a failure
/// into a message on stderr plus exit status 1.
fn main() {
    match run_server() {
        Ok(()) => {}
        Err(failure) => {
            eprintln!("FAILED: {}", failure);
            std::process::exit(1);
        }
    }
}
/// Wraps a plain message in an `io::Error` of kind `Other`.
fn error(err: String) -> io::Error {
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, err)
}
#[tokio::main]
async fn run_server() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// First parameter is port number (optional, defaults to 1337)
let port = match env::args().nth(1) {
Some(ref p) => p.to_owned(),
None => "1337".to_owned(),
};
let addr = format!("127.0.0.1:{}", port).parse()?;
// Build TLS configuration.
let tls_cfg = {
// Load public certificate.
let certs = load_certs("certs/server.crt")?;
// Load private key.
let key = load_private_key("certs/server.key")?;
let mut ca = RootCertStore::empty();
let ca_cert = load_certs("certs/ca.crt")?;
ca.add(&ca_cert[0])?;
let verifier = AllowAnyAuthenticatedClient::new(ca);
let mut cfg = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(verifier)
.with_single_cert(certs, key)
.map_err(|e| error(format!("{}", e)))?;
// Configure ALPN to accept HTTP/2, HTTP/1.1, and HTTP/1.0 in that order.
cfg.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()];
sync::Arc::new(cfg)
};
// Create a TCP listener via tokio.
let incoming = AddrIncoming::bind(&addr)?;
let service = make_service_fn(|socket: &TlsStream| {
let client_cert = socket
.session()
.and_then(|s| s.peer_certificates())
.and_then(|certs| certs.first());
dbg!(client_cert); // Always None
async { Ok::<_, io::Error>(service_fn(echo)) }
});
let server = Server::builder(TlsAcceptor::new(tls_cfg, incoming)).serve(service);
// Run the future, keep going until an error occurs.
println!("Starting to serve on https://{}.", addr);
server.await?;
Ok(())
}
/// Progress of a single server-side TLS connection.
enum State {
/// The TLS handshake future returned by `tokio_rustls::TlsAcceptor::accept`
/// has not yet resolved.
Handshaking(tokio_rustls::Accept<AddrStream>),
/// The handshake succeeded; holds the established TLS stream.
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to call TlsAcceptor::accept and complete the handshake to
// obtain one. This wrapper implements AsyncRead/AsyncWrite and drives the
// pending tokio_rustls::Accept handshake on the first read or write.
pub struct TlsStream {
// Either the pending handshake future or the established TLS stream.
state: State,
}
impl TlsStream {
    /// Starts a server-side TLS handshake over `stream` using `config`.
    ///
    /// The handshake is only created here; it is driven later by the
    /// `AsyncRead`/`AsyncWrite` implementations.
    fn new(stream: AddrStream, config: Arc<ServerConfig>) -> Self {
        let acceptor = tokio_rustls::TlsAcceptor::from(config);
        Self {
            state: State::Handshaking(acceptor.accept(stream)),
        }
    }

    /// Returns the rustls server connection once the handshake has completed,
    /// or `None` while the stream is still handshaking.
    pub fn session(&self) -> Option<&ServerConnection> {
        if let State::Streaming(tls) = &self.state {
            let (_io, connection) = tls.get_ref();
            Some(connection)
        } else {
            None
        }
    }
}
impl AsyncRead for TlsStream {
    /// Reads decrypted data into `buf`.
    ///
    /// If the handshake is still pending it is polled first; once it
    /// resolves, the read is attempted on the new TLS stream and the state
    /// switches to `Streaming`. A handshake failure is returned as the read
    /// error.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match &mut this.state {
            State::Streaming(tls) => Pin::new(tls).poll_read(cx, buf),
            State::Handshaking(handshake) => {
                let mut tls = match ready!(Pin::new(handshake).poll(cx)) {
                    Ok(tls) => tls,
                    Err(e) => return Poll::Ready(Err(e)),
                };
                let outcome = Pin::new(&mut tls).poll_read(cx, buf);
                this.state = State::Streaming(tls);
                outcome
            }
        }
    }
}
impl AsyncWrite for TlsStream {
    /// Writes `buf` to the TLS stream.
    ///
    /// If the handshake is still pending it is polled first; once it
    /// resolves, the write is attempted on the new TLS stream and the state
    /// switches to `Streaming`. A handshake failure is returned as the write
    /// error.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        match &mut this.state {
            State::Streaming(tls) => Pin::new(tls).poll_write(cx, buf),
            State::Handshaking(handshake) => {
                let mut tls = match ready!(Pin::new(handshake).poll(cx)) {
                    Ok(tls) => tls,
                    Err(e) => return Poll::Ready(Err(e)),
                };
                let written = Pin::new(&mut tls).poll_write(cx, buf);
                this.state = State::Streaming(tls);
                written
            }
        }
    }

    /// Flushes the established TLS stream; while still handshaking this
    /// reports success immediately without polling the handshake.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if let State::Streaming(tls) = &mut self.get_mut().state {
            Pin::new(tls).poll_flush(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    }

    /// Shuts down the established TLS stream; while still handshaking this
    /// reports success immediately without polling the handshake.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if let State::Streaming(tls) = &mut self.get_mut().state {
            Pin::new(tls).poll_shutdown(cx)
        } else {
            Poll::Ready(Ok(()))
        }
    }
}
/// Connection acceptor that wraps every TCP connection produced by
/// `incoming` in a [`TlsStream`] configured with `config`.
pub struct TlsAcceptor {
// TLS server configuration, cloned (by `Arc`) into each accepted connection.
config: Arc<ServerConfig>,
// Underlying TCP listener yielding plain `AddrStream` connections.
incoming: AddrIncoming,
}
impl TlsAcceptor {
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}
impl Accept for TlsAcceptor {
    type Conn = TlsStream;
    type Error = io::Error;

    /// Polls the TCP listener and wraps each accepted socket in a
    /// [`TlsStream`] whose handshake has not started yet. Listener errors and
    /// end-of-stream are passed through unchanged.
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        let this = self.get_mut();
        let next = ready!(Pin::new(&mut this.incoming).poll_accept(cx));
        let config = &this.config;
        Poll::Ready(next.map(|accepted| {
            accepted.map(|sock| TlsStream::new(sock, config.clone()))
        }))
    }
}
// Echo service handler.
//
// Routes:
//   GET  /      -> short help text
//   POST /echo  -> responds with the request body
//   otherwise   -> empty 404 response
async fn echo(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
    let response = match (req.method(), req.uri().path()) {
        // Help route.
        (&Method::GET, "/") => Response::new(Body::from("Try POST /echo\n")),
        // Echo service route: hand the incoming body straight back.
        (&Method::POST, "/echo") => Response::new(req.into_body()),
        // Catch-all 404.
        _ => {
            let mut not_found = Response::new(Body::empty());
            *not_found.status_mut() = StatusCode::NOT_FOUND;
            not_found
        }
    };
    Ok(response)
}
// Load public certificate from file.
fn load_certs(filename: &str) -> io::Result<Vec<rustls::Certificate>> {
// Open certificate file.
let certfile = fs::File::open(filename)
.map_err(|e| error(format!("failed to open {}: {}", filename, e)))?;
let mut reader = io::BufReader::new(certfile);
// Load and return certificate.
let certs = rustls_pemfile::certs(&mut reader)
.map_err(|_| error("failed to load certificate".into()))?;
Ok(certs.into_iter().map(rustls::Certificate).collect())
}
// Load private key from file.
fn load_private_key(filename: &str) -> io::Result<rustls::PrivateKey> {
let rsa_keys = {
let keyfile = fs::File::open(filename).expect("cannot open private key file");
let mut reader = io::BufReader::new(keyfile);
rustls_pemfile::rsa_private_keys(&mut reader)
.expect("file contains invalid rsa private key")
};
let pkcs8_keys = {
let keyfile = fs::File::open(filename).expect("cannot open private key file");
let mut reader = io::BufReader::new(keyfile);
rustls_pemfile::pkcs8_private_keys(&mut reader)
.expect("file contains invalid pkcs8 private key (encrypted keys not supported)")
};
let key = if !pkcs8_keys.is_empty() {
pkcs8_keys[0].clone()
} else {
assert!(!rsa_keys.is_empty());
rsa_keys[0].clone()
};
Ok(rustls::PrivateKey(key))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment