Created
October 25, 2019 19:24
-
-
Save sfackler/e3b532f67476f4a9408819e24561cda4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use futures::task::{self, Task}; | |
use futures::{Async, Future, Poll}; | |
use hyper::body::{Body, Payload}; | |
use hyper::server::conn::Connection; | |
use hyper::service::Service; | |
use parking_lot::Mutex; | |
use std::collections::HashMap; | |
use std::error::Error; | |
use std::sync::Arc; | |
use tokio_io::{AsyncRead, AsyncWrite}; | |
pub trait Shutdown: Future { | |
fn shutdown(&mut self); | |
} | |
impl<I, S, B> Shutdown for Connection<I, S> | |
where | |
S: Service<ReqBody = Body, ResBody = B> + 'static, | |
S::Error: Into<Box<dyn Error + Sync + Send>>, | |
S::Future: Send, | |
B: Payload + 'static, | |
I: AsyncRead + AsyncWrite + 'static, | |
{ | |
fn shutdown(&mut self) { | |
self.graceful_shutdown(); | |
} | |
} | |
struct Inner { | |
shutdown: bool, | |
awaiting: HashMap<u64, Task>, | |
next_id: u64, | |
active: usize, | |
blocker: Option<Task>, | |
} | |
#[derive(Clone)] | |
pub struct ShutdownState(Arc<Mutex<Inner>>); | |
impl ShutdownState { | |
pub fn new() -> ShutdownState { | |
ShutdownState(Arc::new(Mutex::new(Inner { | |
shutdown: false, | |
awaiting: HashMap::new(), | |
next_id: 0, | |
active: 0, | |
blocker: None, | |
}))) | |
} | |
pub fn wrap_connection<F>(&self, conn: F) -> ConnectionFuture<F> | |
where | |
F: Shutdown, | |
{ | |
let mut state = self.0.lock(); | |
state.active += 1; | |
let id = state.next_id; | |
state.next_id += 1; | |
ConnectionFuture { | |
conn, | |
state: self.0.clone(), | |
id, | |
shutdown: false, | |
} | |
} | |
pub fn has_shutdown(&self) -> bool { | |
self.0.lock().shutdown | |
} | |
pub fn shutdown(self) -> ShutdownFuture { | |
{ | |
let mut state = self.0.lock(); | |
if state.shutdown { | |
panic!("already shutting down"); | |
} | |
state.shutdown = true; | |
for task in state.awaiting.values() { | |
task.notify(); | |
} | |
} | |
ShutdownFuture(self.0) | |
} | |
} | |
pub struct ConnectionFuture<F> { | |
conn: F, | |
state: Arc<Mutex<Inner>>, | |
id: u64, | |
shutdown: bool, | |
} | |
impl<F> Drop for ConnectionFuture<F> { | |
fn drop(&mut self) { | |
let mut state = self.state.lock(); | |
state.awaiting.remove(&self.id); | |
state.active -= 1; | |
if state.active == 0 && state.shutdown { | |
if let Some(ref task) = state.blocker { | |
task.notify(); | |
} | |
} | |
} | |
} | |
impl<F> Future for ConnectionFuture<F> | |
where | |
F: Shutdown, | |
{ | |
type Item = F::Item; | |
type Error = F::Error; | |
fn poll(&mut self) -> Poll<F::Item, F::Error> { | |
if !self.shutdown { | |
let mut state = self.state.lock(); | |
state.awaiting.insert(self.id, task::current()); | |
if state.shutdown { | |
self.conn.shutdown(); | |
self.shutdown = true; | |
} | |
} | |
self.conn.poll() | |
} | |
} | |
pub struct ShutdownFuture(Arc<Mutex<Inner>>); | |
impl Future for ShutdownFuture { | |
type Item = (); | |
type Error = (); | |
fn poll(&mut self) -> Poll<(), ()> { | |
let mut state = self.0.lock(); | |
state.blocker = Some(task::current()); | |
if state.active == 0 { | |
Ok(Async::Ready(())) | |
} else { | |
Ok(Async::NotReady) | |
} | |
} | |
} | |
#[cfg(test)]
mod test {
    use futures::task::{self, Task};
    use futures::{Async, Poll};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{Duration, Instant};
    use tokio::runtime::current_thread::Runtime;
    use tokio::timer::Delay;

    use super::*;

    /// Wraps five idle connections, then shuts down. `Shutdown::shutdown` must be called
    /// five times in total, and the shutdown future must resolve.
    #[test]
    fn shutdown() {
        // Total number of `TestFuture::shutdown` calls across all instances.
        static SHUTDOWNS: AtomicUsize = AtomicUsize::new(0);

        /// Stand-in connection that stays pending on a very long timer until told to
        /// shut down.
        struct TestFuture {
            shutdown: bool,
            task: Option<Task>,
            delay: Delay,
        }

        impl TestFuture {
            fn idle() -> TestFuture {
                TestFuture {
                    shutdown: false,
                    task: None,
                    delay: Delay::new(Instant::now() + Duration::from_secs(10_000)),
                }
            }
        }

        impl Future for TestFuture {
            type Item = ();
            type Error = ();

            fn poll(&mut self) -> Poll<(), ()> {
                if self.shutdown {
                    Ok(Async::Ready(()))
                } else {
                    self.task = Some(task::current());
                    self.delay.poll().map_err(|_| ())
                }
            }
        }

        impl Shutdown for TestFuture {
            fn shutdown(&mut self) {
                self.shutdown = true;
                SHUTDOWNS.fetch_add(1, Ordering::SeqCst);
                if let Some(ref waiter) = self.task {
                    waiter.notify();
                }
            }
        }

        let mut runtime = Runtime::new().unwrap();
        let state = ShutdownState::new();

        for _ in 0..5 {
            runtime.spawn(state.wrap_connection(TestFuture::idle()));
        }

        // Let the spawned connections run (and register their tasks) before shutting down.
        let settle = Delay::new(Instant::now() + Duration::from_millis(100));
        runtime.block_on(settle).unwrap();
        assert_eq!(SHUTDOWNS.load(Ordering::SeqCst), 0);

        runtime.block_on(state.shutdown()).unwrap();
        assert_eq!(SHUTDOWNS.load(Ordering::SeqCst), 5);

        runtime.run().unwrap();
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment