Skip to content

Instantly share code, notes, and snippets.

@sfackler
Created October 25, 2019 19:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sfackler/e3b532f67476f4a9408819e24561cda4 to your computer and use it in GitHub Desktop.
Save sfackler/e3b532f67476f4a9408819e24561cda4 to your computer and use it in GitHub Desktop.
use futures::task::{self, Task};
use futures::{Async, Future, Poll};
use hyper::body::{Body, Payload};
use hyper::server::conn::Connection;
use hyper::service::Service;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::error::Error;
use std::sync::Arc;
use tokio_io::{AsyncRead, AsyncWrite};
/// A connection future that can be asked to wind down gracefully.
///
/// Implementors are wrapped via `ShutdownState::wrap_connection`; the wrapper
/// calls [`Shutdown::shutdown`] at most once, the first time it is polled after
/// `ShutdownState::shutdown` has been invoked, and then keeps polling the future.
pub trait Shutdown: Future {
    /// Initiates a graceful shutdown of this connection.
    ///
    /// NOTE(review): the future is presumably expected to resolve once the
    /// shutdown has completed — confirm against implementors.
    fn shutdown(&mut self);
}
/// Hyper server connections shut down by delegating to hyper's own
/// `Connection::graceful_shutdown`.
impl<I, S, B> Shutdown for Connection<I, S>
where
    I: AsyncRead + AsyncWrite + 'static,
    S: Service<ReqBody = Body, ResBody = B> + 'static,
    S::Error: Into<Box<dyn Error + Sync + Send>>,
    S::Future: Send,
    B: Payload + 'static,
{
    /// Forwards to hyper's `graceful_shutdown` for this connection.
    fn shutdown(&mut self) {
        self.graceful_shutdown()
    }
}
/// Mutable state shared (behind an `Arc<Mutex<_>>`) by a `ShutdownState`,
/// every `ConnectionFuture` it creates, and its `ShutdownFuture`.
struct Inner {
    /// Set by `ShutdownState::shutdown`; nothing ever resets it to `false`.
    shutdown: bool,
    /// Task registered by `ShutdownFuture::poll`; woken when the last active
    /// connection is dropped after shutdown has begun.
    blocker: Option<Task>,
    /// Number of wrapped connections whose futures have not yet been dropped.
    active: usize,
    /// Id handed to the next connection passed to `wrap_connection`.
    next_id: u64,
    /// Tasks of polled connections keyed by connection id, woken when shutdown
    /// starts; an entry is removed when its connection future is dropped.
    awaiting: HashMap<u64, Task>,
}
/// Cloneable handle that tracks wrapped connections and coordinates a
/// graceful shutdown of all of them.
///
/// All clones share the same underlying state.
#[derive(Clone)]
pub struct ShutdownState(Arc<Mutex<Inner>>);

impl ShutdownState {
    /// Creates a state with no tracked connections that has not been shut down.
    pub fn new() -> ShutdownState {
        ShutdownState(Arc::new(Mutex::new(Inner {
            shutdown: false,
            awaiting: HashMap::new(),
            next_id: 0,
            active: 0,
            blocker: None,
        })))
    }

    /// Wraps `conn` so it is counted as an active connection and receives a
    /// [`Shutdown::shutdown`] call the first time it is polled after
    /// [`ShutdownState::shutdown`] has been called.
    ///
    /// The connection stops counting as active when the returned future is
    /// dropped.
    pub fn wrap_connection<F>(&self, conn: F) -> ConnectionFuture<F>
    where
        F: Shutdown,
    {
        let mut state = self.0.lock();
        state.active += 1;
        let id = state.next_id;
        state.next_id += 1;
        ConnectionFuture {
            conn,
            state: Arc::clone(&self.0),
            id,
            shutdown: false,
        }
    }

    /// Returns `true` once [`ShutdownState::shutdown`] has been called on any
    /// clone of this state.
    pub fn has_shutdown(&self) -> bool {
        self.0.lock().shutdown
    }

    /// Begins a graceful shutdown: marks the state as shutting down and wakes
    /// every registered connection task so it can observe the flag.
    ///
    /// Returns a future that resolves once no wrapped connections remain.
    ///
    /// # Panics
    ///
    /// Panics if shutdown has already been initiated through any clone of
    /// this state.
    pub fn shutdown(self) -> ShutdownFuture {
        {
            let mut state = self.0.lock();
            if state.shutdown {
                panic!("already shutting down");
            }
            state.shutdown = true;
            for task in state.awaiting.values() {
                task.notify();
            }
        }
        ShutdownFuture(self.0)
    }
}

/// Equivalent to [`ShutdownState::new`]; provided so the type works with
/// `Default`-based construction (and to satisfy clippy's `new_without_default`).
impl Default for ShutdownState {
    fn default() -> ShutdownState {
        ShutdownState::new()
    }
}
/// Future returned by `ShutdownState::wrap_connection`.
///
/// Polls the wrapped connection and, the first time it is polled after
/// `ShutdownState::shutdown` has been called, asks it to shut down gracefully.
/// Dropping it removes the connection from the active count.
#[must_use = "futures do nothing unless polled"]
pub struct ConnectionFuture<F> {
    /// The wrapped connection.
    conn: F,
    /// State shared with the `ShutdownState` that created this future.
    state: Arc<Mutex<Inner>>,
    /// This connection's key in `Inner::awaiting`.
    id: u64,
    /// Whether `conn.shutdown()` has already been called.
    shutdown: bool,
}
impl<F> Drop for ConnectionFuture<F> {
fn drop(&mut self) {
let mut state = self.state.lock();
state.awaiting.remove(&self.id);
state.active -= 1;
if state.active == 0 && state.shutdown {
if let Some(ref task) = state.blocker {
task.notify();
}
}
}
}
impl<F> Future for ConnectionFuture<F>
where
    F: Shutdown,
{
    type Item = F::Item;
    type Error = F::Error;

    /// Polls the wrapped connection, first initiating its graceful shutdown if
    /// the shared state has been marked as shutting down.
    fn poll(&mut self) -> Poll<F::Item, F::Error> {
        if !self.shutdown {
            let shutting_down = {
                let mut state = self.state.lock();
                if state.shutdown {
                    // `ShutdownState::shutdown` cannot fire again (it panics on a
                    // second call), so this task no longer needs waking by it.
                    state.awaiting.remove(&self.id);
                    true
                } else {
                    // Register (or refresh) this task under the same lock that
                    // guards the flag, so a concurrent shutdown cannot be missed.
                    state.awaiting.insert(self.id, task::current());
                    false
                }
            };
            // The guard is dropped before calling into the connection: the
            // mutex is not reentrant, so a `Shutdown` impl that touches the
            // shared state (e.g. via `has_shutdown`) would otherwise deadlock.
            if shutting_down {
                self.conn.shutdown();
                self.shutdown = true;
            }
        }
        self.conn.poll()
    }
}
/// Future returned by `ShutdownState::shutdown`; resolves once every wrapped
/// connection future has been dropped.
pub struct ShutdownFuture(Arc<Mutex<Inner>>);

impl Future for ShutdownFuture {
    type Item = ();
    type Error = ();

    fn poll(&mut self) -> Poll<(), ()> {
        let mut inner = self.0.lock();
        // Always (re)register so the last `ConnectionFuture` to drop can wake us.
        inner.blocker = Some(task::current());
        let status = if inner.active > 0 { Async::NotReady } else { Async::Ready(()) };
        Ok(status)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use futures::task::{self, Task};
    use futures::{Async, Poll};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};
    use tokio::runtime::current_thread::Runtime;
    use tokio::timer::Delay;

    /// Spawns five long-lived wrapped connections, checks that none is shut
    /// down before `ShutdownState::shutdown`, then checks that each receives a
    /// shutdown call and that the returned `ShutdownFuture` completes.
    #[test]
    fn shutdown() {
        static SHUTDOWNS: AtomicUsize = AtomicUsize::new(0);

        /// Fake connection that stays pending on a far-future timer until it
        /// is told to shut down.
        struct TestFuture {
            shutdown: bool,
            task: Option<Task>,
            delay: Delay,
        }

        impl Future for TestFuture {
            type Item = ();
            type Error = ();

            fn poll(&mut self) -> Poll<(), ()> {
                if !self.shutdown {
                    self.task = Some(task::current());
                    self.delay.poll().map_err(|_| ())
                } else {
                    Ok(Async::Ready(()))
                }
            }
        }

        impl Shutdown for TestFuture {
            fn shutdown(&mut self) {
                self.shutdown = true;
                SHUTDOWNS.fetch_add(1, Ordering::SeqCst);
                if let Some(parked) = self.task.as_ref() {
                    parked.notify();
                }
            }
        }

        let mut runtime = Runtime::new().unwrap();
        let state = ShutdownState::new();
        for _ in 0..5 {
            let deadline = Instant::now() + Duration::from_secs(10_000);
            let conn = TestFuture {
                shutdown: false,
                task: None,
                delay: Delay::new(deadline),
            };
            runtime.spawn(state.wrap_connection(conn));
        }

        let settle = Delay::new(Instant::now() + Duration::from_millis(100));
        runtime.block_on(settle).unwrap();
        assert_eq!(SHUTDOWNS.load(Ordering::SeqCst), 0);

        runtime.block_on(state.shutdown()).unwrap();
        assert_eq!(SHUTDOWNS.load(Ordering::SeqCst), 5);

        runtime.run().unwrap();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment