Created
January 17, 2020 13:36
-
-
Save jonhere/75fc1930f143889fed02c2bbaccfad2f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use futures::future::FutureExt; | |
use hyper::server::accept::Accept; | |
use std::future::Future; | |
use std::pin::Pin; | |
use std::sync::atomic::AtomicUsize; | |
use std::sync::atomic::Ordering; | |
use std::sync::Arc; | |
use std::sync::Mutex; | |
use std::task::{Context, Poll}; | |
use tokio::net::TcpListener; | |
use tokio::prelude::*; | |
use tokio_tls::{TlsAcceptor, TlsStream}; | |
/// Cloneable handle through which the TLS identity used by [`In`] can be swapped at runtime.
///
/// Every clone points at the same shared state: a version counter plus the PKCS#12 bytes
/// (DER, empty password) of the current identity.
#[derive(Clone)]
pub struct CertChanger {
    // (version, pkcs12 bytes). The version is only bumped while the Mutex is held, which is
    // why it can be read and written with Ordering::Relaxed.
    inner: std::sync::Arc<(std::sync::atomic::AtomicUsize, std::sync::Mutex<Vec<u8>>)>,
}
/// structure to be passed to hyper::Server::builder() | |
pub struct In { | |
listener: TcpListener, | |
cert_changer: CertChanger, | |
cert_used: usize, | |
acceptor: Arc<TlsAcceptor>, | |
} | |
pub enum InStream { | |
Handshake( | |
Pin< | |
Box< | |
dyn Future<Output = Result<TlsStream<tokio::net::TcpStream>, native_tls::Error>> | |
+ Send, | |
>, | |
>, | |
), | |
Stream(TlsStream<tokio::net::TcpStream>), | |
} | |
impl CertChanger { | |
/// invalid input will cause panic on first new connection | |
pub fn new_unchecked(pkcs12_der_empty_password: Vec<u8>) -> Self { | |
CertChanger { | |
inner: Arc::new((AtomicUsize::new(0), Mutex::new(pkcs12_der_empty_password))), | |
} | |
} | |
pub fn set_unchecked(&self, pkcs12_der_empty_password: Vec<u8>) { | |
let elf = &*self.inner; | |
let mut guard = elf.1.lock().unwrap(); | |
*guard = pkcs12_der_empty_password; | |
elf.0.fetch_add(1, Ordering::Relaxed); | |
} | |
//todo do i want this here? | |
pub fn num_days(&self) -> i64 { | |
// taken from acme_lib | |
fn parse_date(s: &str) -> time::Tm { | |
//debug!("Parse date/time: {}", s); | |
time::strptime(s, "%h %e %H:%M:%S %Y %Z").expect("strptime") | |
} | |
let x509 = openssl::x509::X509::from_der(&self.inner.1.lock().unwrap()).unwrap(); | |
let not_after = format!("{}", x509.not_after()); | |
let expires = parse_date(¬_after); | |
let dur = expires - time::now(); | |
dur.num_days() | |
} | |
} | |
impl In { | |
pub fn new(listener: TcpListener, cert_changer: &CertChanger) -> Self { | |
let guard = cert_changer.inner.1.lock().unwrap(); | |
In { | |
listener, | |
cert_changer: cert_changer.clone(), | |
cert_used: cert_changer.inner.0.load(Ordering::Relaxed), | |
acceptor: Arc::new({ | |
let identity = native_tls::Identity::from_pkcs12(&guard, "").unwrap(); | |
let acceptor = native_tls::TlsAcceptor::new(identity).unwrap(); | |
acceptor.into() | |
}), | |
} | |
} | |
} | |
impl Accept for In { | |
type Conn = InStream; | |
type Error = tokio::io::Error; | |
fn poll_accept( | |
self: Pin<&mut Self>, | |
cx: &mut Context, | |
) -> Poll<Option<Result<Self::Conn, Self::Error>>> { | |
let elf = self.get_mut(); | |
// this appears to work (for now) | |
// todo maybe need to keep Future (which would meen using eg | |
// async move { let r = listener.accept().await; (r, listener) }.boxed() ) | |
let f = elf.listener.accept(); | |
futures::pin_mut!(f); | |
let res = f.poll(cx); | |
//todo fix (if this res errors it takes down the server) | |
if elf.cert_used != elf.cert_changer.inner.0.load(Ordering::Relaxed) { | |
// get new cert | |
elf.acceptor = Arc::new({ | |
let guard = elf.cert_changer.inner.1.lock().unwrap(); | |
elf.cert_used = elf.cert_changer.inner.0.load(Ordering::Relaxed); | |
let identity = native_tls::Identity::from_pkcs12(&guard, "").unwrap(); | |
let acceptor = native_tls::TlsAcceptor::new(identity).unwrap(); | |
acceptor.into() | |
}); | |
} | |
let acceptor = &elf.acceptor; | |
res.map(|r| { | |
Some(r.map(|(tcp_stream, addr)| { | |
//todo add addr to InStream | |
tracing::info!("{:?} connected will now check tls", addr); | |
let acceptor = acceptor.clone(); | |
InStream::Handshake(async move { acceptor.accept(tcp_stream).await }.boxed()) | |
})) | |
}) | |
} | |
} | |
impl AsyncRead for InStream { | |
fn poll_read( | |
self: Pin<&mut Self>, | |
cx: &mut Context, | |
buf: &mut [u8], | |
) -> Poll<Result<usize, tokio::io::Error>> { | |
let elf = self.get_mut(); | |
if let InStream::Handshake(f) = elf { | |
let res = f.poll_unpin(cx); | |
match res { | |
Poll::Pending => { | |
return Poll::Pending; | |
} | |
Poll::Ready(Err(e)) => { | |
return Poll::Ready(Err(tokio::io::Error::new(tokio::io::ErrorKind::Other, e))); | |
} | |
Poll::Ready(Ok(tls_stream)) => { | |
*elf = InStream::Stream(tls_stream); | |
} | |
} | |
} | |
if let InStream::Stream(s) = elf { | |
futures::pin_mut!(s); | |
s.poll_read(cx, buf) | |
} else { | |
unreachable!(); | |
} | |
} | |
} | |
impl AsyncWrite for InStream { | |
fn poll_write( | |
self: Pin<&mut Self>, | |
cx: &mut Context, | |
buf: &[u8], | |
) -> Poll<Result<usize, tokio::io::Error>> { | |
let elf = self.get_mut(); | |
if let InStream::Handshake(f) = elf { | |
let res = f.poll_unpin(cx); | |
match res { | |
Poll::Pending => { | |
return Poll::Pending; | |
} | |
Poll::Ready(Err(e)) => { | |
return Poll::Ready(Err(tokio::io::Error::new(tokio::io::ErrorKind::Other, e))); | |
} | |
Poll::Ready(Ok(tls_stream)) => { | |
*elf = InStream::Stream(tls_stream); | |
} | |
} | |
} | |
if let InStream::Stream(s) = elf { | |
futures::pin_mut!(s); | |
s.poll_write(cx, buf) | |
} else { | |
unreachable!(); | |
} | |
} | |
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> { | |
let elf = self.get_mut(); | |
if let InStream::Handshake(f) = elf { | |
let res = f.poll_unpin(cx); | |
match res { | |
Poll::Pending => { | |
return Poll::Pending; | |
} | |
Poll::Ready(Err(e)) => { | |
return Poll::Ready(Err(tokio::io::Error::new(tokio::io::ErrorKind::Other, e))); | |
} | |
Poll::Ready(Ok(tls_stream)) => { | |
*elf = InStream::Stream(tls_stream); | |
} | |
} | |
} | |
if let InStream::Stream(s) = elf { | |
futures::pin_mut!(s); | |
s.poll_flush(cx) | |
} else { | |
unreachable!(); | |
} | |
} | |
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> { | |
let elf = self.get_mut(); | |
if let InStream::Stream(s) = elf { | |
futures::pin_mut!(s); | |
s.poll_shutdown(cx) | |
} else { | |
//todo any way to async shutdown the active future? | |
// drop the future now (instead of when Self dropped) | |
*elf = InStream::Handshake(async { panic!() }.boxed()); | |
Poll::Ready(Ok(())) | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment