Skip to content

Instantly share code, notes, and snippets.

@jonhere
Created January 17, 2020 13:36
Show Gist options
  • Save jonhere/75fc1930f143889fed02c2bbaccfad2f to your computer and use it in GitHub Desktop.
Save jonhere/75fc1930f143889fed02c2bbaccfad2f to your computer and use it in GitHub Desktop.
use futures::future::FutureExt;
use hyper::server::accept::Accept;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::Mutex;
use std::task::{Context, Poll};
use tokio::net::TcpListener;
use tokio::prelude::*;
use tokio_tls::{TlsAcceptor, TlsStream};
/// Shared, cloneable handle to the server's current TLS certificate.
///
/// All clones share the same state through an `Arc`, so a certificate stored
/// through any handle is visible to every other handle.
#[derive(Clone)]
pub struct CertChanger {
    // `.0`: generation counter, incremented on every `set_unchecked`.
    // `.1`: DER-encoded PKCS#12 archive protected by an empty password.
    // Updates happen while holding the Mutex, so the counter itself can use
    // Ordering::Relaxed: readers re-load it under the lock before trusting it.
    inner: Arc<(AtomicUsize, Mutex<Vec<u8>>)>,
}
/// TLS-terminating listener, intended to be passed to `hyper::Server::builder()`.
///
/// It implements hyper's `Accept`: each accepted TCP connection is yielded as an
/// [`InStream`] whose TLS handshake is completed lazily.
pub struct In {
    // Underlying TCP listener that new connections are accepted from.
    listener: TcpListener,
    // Handle used to detect and load certificate updates.
    cert_changer: CertChanger,
    // Value of the `CertChanger` generation counter when `acceptor` was built.
    cert_used: usize,
    // TLS acceptor built from the certificate generation recorded in `cert_used`.
    acceptor: Arc<TlsAcceptor>,
}
pub enum InStream {
Handshake(
Pin<
Box<
dyn Future<Output = Result<TlsStream<tokio::net::TcpStream>, native_tls::Error>>
+ Send,
>,
>,
),
Stream(TlsStream<tokio::net::TcpStream>),
}
impl CertChanger {
    /// Creates a changer holding `pkcs12_der_empty_password`, a DER-encoded
    /// PKCS#12 archive protected by an empty password.
    ///
    /// The bytes are not validated here. Invalid input causes a panic later,
    /// when an [`In`] is built from this changer.
    pub fn new_unchecked(pkcs12_der_empty_password: Vec<u8>) -> Self {
        CertChanger {
            inner: Arc::new((AtomicUsize::new(0), Mutex::new(pkcs12_der_empty_password))),
        }
    }

    /// Replaces the stored PKCS#12 archive (empty password) and bumps the
    /// generation counter so listeners rebuild their TLS acceptor on their next poll.
    ///
    /// The bytes are not validated here. Invalid input causes a panic in the
    /// next listener poll that picks up the change.
    pub fn set_unchecked(&self, pkcs12_der_empty_password: Vec<u8>) {
        let (generation, cert_bytes) = &*self.inner;
        let mut guard = cert_bytes.lock().unwrap();
        *guard = pkcs12_der_empty_password;
        // Bumped while the lock is still held, so a reader that re-loads the
        // counter under the lock always sees a value matching the bytes.
        generation.fetch_add(1, Ordering::Relaxed);
    }

    // TODO: decide whether this belongs on `CertChanger`.
    /// Returns the number of whole days until the stored certificate's
    /// `notAfter` date, counted from the current time.
    ///
    /// # Panics
    /// If the stored bytes are not a DER-encoded PKCS#12 archive that opens
    /// with an empty password, or if the `notAfter` date cannot be parsed.
    pub fn num_days(&self) -> i64 {
        // Date parsing taken from acme_lib.
        fn parse_date(s: &str) -> time::Tm {
            time::strptime(s, "%h %e %H:%M:%S %Y %Z").expect("strptime")
        }

        // The stored bytes are a PKCS#12 archive (the same bytes handed to
        // `Identity::from_pkcs12`), not a bare X.509 certificate, so the
        // certificate must be extracted from the archive before reading it.
        let cert = {
            let pkcs12_der = self.inner.1.lock().unwrap();
            openssl::pkcs12::Pkcs12::from_der(&pkcs12_der)
                .expect("stored bytes are not a DER-encoded PKCS#12 archive")
                .parse("")
                .expect("stored PKCS#12 archive cannot be opened with an empty password")
                .cert
        };
        let expires = parse_date(&cert.not_after().to_string());
        (expires - time::now()).num_days()
    }
}
impl In {
    /// Wraps `listener` so that accepted connections are served over TLS
    /// using the certificate currently stored in `cert_changer`.
    ///
    /// # Panics
    /// If the stored bytes are not a PKCS#12 archive that opens with an empty
    /// password, or if a TLS acceptor cannot be built from that identity.
    pub fn new(listener: TcpListener, cert_changer: &CertChanger) -> Self {
        let (generation, cert_bytes) = &*cert_changer.inner;
        // The lock is held while the counter is read, so `cert_used` describes
        // exactly the bytes parsed below.
        let pkcs12 = cert_bytes.lock().unwrap();
        let cert_used = generation.load(Ordering::Relaxed);
        let identity = native_tls::Identity::from_pkcs12(&pkcs12, "").unwrap();
        let native_acceptor = native_tls::TlsAcceptor::new(identity).unwrap();
        In {
            listener,
            cert_changer: cert_changer.clone(),
            cert_used,
            acceptor: Arc::new(native_acceptor.into()),
        }
    }
}
impl Accept for In {
    type Conn = InStream;
    type Error = tokio::io::Error;

    /// Polls for the next TCP connection and yields it as an
    /// [`InStream::Handshake`], whose TLS handshake runs when the stream is first used.
    ///
    /// On every poll, the TLS acceptor is rebuilt if the [`CertChanger`]
    /// generation counter has moved since the acceptor was last built. Connections
    /// accepted from then on use the newest certificate.
    ///
    /// # Panics
    /// If a newly stored certificate is not a PKCS#12 archive that opens with an
    /// empty password, or if a TLS acceptor cannot be built from it.
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        let this = self.get_mut();

        // A fresh `accept()` future is created and dropped on every poll; this
        // appears to work (for now).
        // TODO: maybe keep the future across polls instead, e.g.
        // `async move { let r = listener.accept().await; (r, listener) }.boxed()`.
        let accept = this.listener.accept();
        futures::pin_mut!(accept);
        let polled = accept.poll(cx);
        // TODO: an accept error is handed to hyper unchanged, which takes down the server.

        if this.cert_changer.inner.0.load(Ordering::Relaxed) != this.cert_used {
            // A newer certificate has been stored: rebuild the acceptor from it.
            let rebuilt = {
                let pkcs12 = this.cert_changer.inner.1.lock().unwrap();
                // Re-read under the lock so `cert_used` matches the bytes parsed here.
                this.cert_used = this.cert_changer.inner.0.load(Ordering::Relaxed);
                let identity = native_tls::Identity::from_pkcs12(&pkcs12, "").unwrap();
                native_tls::TlsAcceptor::new(identity).unwrap()
            };
            this.acceptor = Arc::new(rebuilt.into());
        }

        let (tcp_stream, addr) = match polled {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
            Poll::Ready(Ok(accepted)) => accepted,
        };
        // TODO: add `addr` to `InStream`.
        tracing::info!("{:?} connected will now check tls", addr);
        let acceptor = Arc::clone(&this.acceptor);
        let handshake = async move { acceptor.accept(tcp_stream).await }.boxed();
        Poll::Ready(Some(Ok(InStream::Handshake(handshake))))
    }
}
impl AsyncRead for InStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<Result<usize, tokio::io::Error>> {
let elf = self.get_mut();
if let InStream::Handshake(f) = elf {
let res = f.poll_unpin(cx);
match res {
Poll::Pending => {
return Poll::Pending;
}
Poll::Ready(Err(e)) => {
return Poll::Ready(Err(tokio::io::Error::new(tokio::io::ErrorKind::Other, e)));
}
Poll::Ready(Ok(tls_stream)) => {
*elf = InStream::Stream(tls_stream);
}
}
}
if let InStream::Stream(s) = elf {
futures::pin_mut!(s);
s.poll_read(cx, buf)
} else {
unreachable!();
}
}
}
impl AsyncWrite for InStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, tokio::io::Error>> {
let elf = self.get_mut();
if let InStream::Handshake(f) = elf {
let res = f.poll_unpin(cx);
match res {
Poll::Pending => {
return Poll::Pending;
}
Poll::Ready(Err(e)) => {
return Poll::Ready(Err(tokio::io::Error::new(tokio::io::ErrorKind::Other, e)));
}
Poll::Ready(Ok(tls_stream)) => {
*elf = InStream::Stream(tls_stream);
}
}
}
if let InStream::Stream(s) = elf {
futures::pin_mut!(s);
s.poll_write(cx, buf)
} else {
unreachable!();
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
let elf = self.get_mut();
if let InStream::Handshake(f) = elf {
let res = f.poll_unpin(cx);
match res {
Poll::Pending => {
return Poll::Pending;
}
Poll::Ready(Err(e)) => {
return Poll::Ready(Err(tokio::io::Error::new(tokio::io::ErrorKind::Other, e)));
}
Poll::Ready(Ok(tls_stream)) => {
*elf = InStream::Stream(tls_stream);
}
}
}
if let InStream::Stream(s) = elf {
futures::pin_mut!(s);
s.poll_flush(cx)
} else {
unreachable!();
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
let elf = self.get_mut();
if let InStream::Stream(s) = elf {
futures::pin_mut!(s);
s.poll_shutdown(cx)
} else {
//todo any way to async shutdown the active future?
// drop the future now (instead of when Self dropped)
*elf = InStream::Handshake(async { panic!() }.boxed());
Poll::Ready(Ok(()))
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment