@sdbondi
Last active January 30, 2020 07:18
Generic retry stream
use futures::{ready, stream::FusedStream, task::Context, Future, Stream};
use pin_project::{pin_project, project};
use std::{marker::PhantomData, pin::Pin, task::Poll, time::Duration};
use tokio::time::{delay_for, Delay};

/// Calculates how long to wait before a given retry attempt.
pub trait Backoff {
    fn calculate_backoff(&self, attempts: usize) -> Duration;
}

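// Not part of the original gist: a minimal sketch of what a `Backoff` implementation might
// look like, assuming a simple capped exponential strategy. `ExponentialBackoff`, its fields
// and the 60-second cap are made up for illustration; the tests below use a `ConstantBackoff`
// from `crate::backoff`, which is not included here.
pub struct ExponentialBackoff {
    base: Duration,
    max: Duration,
}

impl ExponentialBackoff {
    pub fn new(base: Duration) -> Self {
        Self {
            base,
            max: Duration::from_secs(60),
        }
    }
}

impl Backoff for ExponentialBackoff {
    fn calculate_backoff(&self, attempts: usize) -> Duration {
        // Double the base delay for every attempt already made, capped at `max`.
        let factor = 2u32.saturating_pow(attempts as u32);
        self.base
            .checked_mul(factor)
            .map(|d| std::cmp::min(d, self.max))
            .unwrap_or(self.max)
    }
}
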
/// Internal state machine for `DelayedRetry`.
#[pin_project]
enum State<TFut> {
    /// No attempt has been made yet.
    Initial,
    /// Waiting out the backoff delay before the next attempt.
    Waiting(#[pin] Delay),
    /// An attempt is in flight.
    Running(#[pin] TFut),
    /// The stream has finished, either with a success or after exhausting all attempts.
    Complete,
}

/// Stream which repeatedly runs a future produced by `future_factory` until it succeeds or the
/// maximum number of attempts is reached, emitting the `Result` of each attempt.
#[pin_project]
pub struct DelayedRetry<'a, TFutFactory, TFut, TBackoff> {
    future_factory: TFutFactory,
    backoff: TBackoff,
    attempts: usize,
    max_attempts: usize,
    #[pin]
    state: State<TFut>,
    _lifetime: PhantomData<&'a ()>,
}

impl<'a, TFutFactory, TFut, TBackoff, T, E> DelayedRetry<'a, TFutFactory, TFut, TBackoff>
where
    TFutFactory: FnMut(usize) -> TFut,
    TFut: Future<Output = Result<T, E>>,
    TBackoff: Backoff,
{
    pub fn new(future_factory: TFutFactory, backoff: TBackoff, max_attempts: usize) -> Self {
        Self {
            future_factory,
            backoff,
            max_attempts,
            attempts: 0,
            state: State::Initial,
            _lifetime: PhantomData,
        }
    }
}

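// Not part of the original gist: an illustrative sketch of how `DelayedRetry` could be driven,
// assuming the hypothetical `ExponentialBackoff` above. Each item yielded by the stream is the
// `Result` of one attempt; the stream ends after the first `Ok` or once `max_attempts` errors
// have been emitted, so the last item reflects the overall outcome.
#[allow(dead_code)]
async fn example_usage() -> Option<Result<&'static str, ()>> {
    use futures::{future, StreamExt};

    let retry = DelayedRetry::new(
        |attempt| {
            // Fail the first two attempts, then succeed.
            if attempt >= 2 {
                future::ready(Ok("connected"))
            } else {
                future::ready(Err(()))
            }
        },
        ExponentialBackoff::new(Duration::from_millis(10)),
        5,
    );

    // Collect every per-attempt result; the final item is `Ok("connected")`.
    retry.collect::<Vec<_>>().await.pop()
}
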
impl<'a, TFutFactory, TFut, TBackoff, T, E> Stream for DelayedRetry<'a, TFutFactory, TFut, TBackoff>
where
    TFutFactory: FnMut(usize) -> TFut,
    TFut: Future<Output = Result<T, E>> + 'a,
    TBackoff: Backoff,
{
    type Item = Result<T, E>;

    #[project]
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        // Holds the result of the attempt that just completed until the new state has been stored.
        let mut current_result = None;
        loop {
            #[project]
            let next_state = match this.state.as_mut().project() {
                State::Initial => {
                    let backoff_time = this.backoff.calculate_backoff(*this.attempts);
                    if backoff_time.as_micros() > 0 {
                        State::Waiting(delay_for(backoff_time))
                    } else {
                        // No initial backoff: start the first attempt straight away.
                        State::Running((this.future_factory)(*this.attempts))
                    }
                },
                State::Waiting(delay_fut) => {
                    ready!(delay_fut.poll(cx));
                    State::Running((this.future_factory)(*this.attempts))
                },
                State::Running(fut) => match ready!(fut.poll(cx)) {
                    Ok(v) => {
                        current_result = Some(Ok(v));
                        State::Complete
                    },
                    Err(err) => {
                        current_result = Some(Err(err));
                        *this.attempts += 1;
                        if this.attempts >= this.max_attempts {
                            // After we emit the final error, the stream ends because we've already
                            // reached the maximum attempts
                            State::Complete
                        } else {
                            let backoff_time = this.backoff.calculate_backoff(*this.attempts);
                            State::Waiting(delay_for(backoff_time))
                        }
                    },
                },
                State::Complete => {
                    return Poll::Ready(None);
                },
            };
            this.state.set(next_state);
            if let Some(result) = current_result.take() {
                return Poll::Ready(Some(result));
            }
        }
    }
}

impl<T, E, TFutFactory, TFut, TBackoff> FusedStream for DelayedRetry<'_, TFutFactory, TFut, TBackoff>
where
    TFutFactory: FnMut(usize) -> TFut,
    TFut: Future<Output = Result<T, E>>,
    TBackoff: Backoff,
{
    fn is_terminated(&self) -> bool {
        match self.state {
            State::Complete => true,
            _ => false,
        }
    }
}

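// Not part of the original gist: a sketch of why the `FusedStream` impl above is useful. It lets
// the retry stream be polled with `futures::select!`, for example to race each attempt against a
// shutdown signal. The oneshot shutdown receiver and the reuse of the hypothetical
// `ExponentialBackoff` are assumptions made for this illustration.
#[allow(dead_code)]
async fn example_with_shutdown(mut shutdown_rx: futures::channel::oneshot::Receiver<()>) {
    use futures::{future, StreamExt};

    let retry = DelayedRetry::new(
        |_| future::ready(Result::<(), ()>::Err(())),
        ExponentialBackoff::new(Duration::from_millis(100)),
        3,
    );
    // Pin the stream on the stack so the `StreamExt` combinators (which require `Unpin`) can be used.
    futures::pin_mut!(retry);

    // `is_terminated` reports when the stream has either succeeded or given up.
    while !retry.is_terminated() {
        futures::select! {
            // `select_next_some` requires a FusedStream and yields only actual items.
            result = retry.select_next_some() => {
                if result.is_ok() {
                    break;
                }
                // On error, loop around; the stream schedules the next attempt itself.
            },
            // Shutdown requested while retrying: stop polling the stream.
            _ = shutdown_rx => break,
        }
    }
}
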
#[cfg(test)]
mod test {
    use super::*;
    use crate::backoff::ConstantBackoff;
    use futures::{future, StreamExt};
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    #[tokio_macros::test_basic]
    async fn never_succeeds() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        let retry = DelayedRetry::new(
            |_| {
                call_count_clone.fetch_add(1, Ordering::Relaxed);
                future::ready(Result::<(), _>::Err(()))
            },
            ConstantBackoff::new(Duration::from_millis(1)),
            3,
        );
        let results = retry.collect::<Vec<_>>().await;
        // 3 results emitted
        assert_eq!(results.len(), 3);
        // ... all of them errors
        assert_eq!(results.into_iter().filter(Result::is_err).count(), 3);
        // ... from exactly 3 attempts
        assert_eq!(call_count.load(Ordering::Relaxed), 3);
    }

    #[tokio_macros::test_basic]
    async fn succeeds_later() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        let retry = DelayedRetry::new(
            |_| match call_count_clone.fetch_add(1, Ordering::Relaxed) {
                2 => future::ready(Ok("Works!")),
                _ => future::ready(Err(())),
            },
            ConstantBackoff::new(Duration::from_millis(1)),
            3,
        );
        let results = retry.collect::<Vec<_>>().await;
        // 3 results emitted
        assert_eq!(results.len(), 3);
        // ... 2 errors
        assert_eq!(results.iter().filter(|r| r.is_err()).count(), 2);
        // ... ending with a success
        assert_eq!(results.last().unwrap(), &Ok("Works!"));
        // ... after exactly 3 attempts
        assert_eq!(call_count.load(Ordering::Relaxed), 3);
    }

    #[tokio_macros::test_basic]
    async fn succeeds_immediately() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let call_count_clone = call_count.clone();
        let retry = DelayedRetry::new(
            |_| {
                call_count_clone.fetch_add(1, Ordering::Relaxed);
                future::ready(Result::<_, ()>::Ok("Works!"))
            },
            ConstantBackoff::new(Duration::from_millis(1)),
            3,
        );
        let results = retry.collect::<Vec<_>>().await;
        // A single result emitted
        assert_eq!(results.len(), 1);
        // ... which is the success
        assert_eq!(results.get(0).unwrap(), &Ok("Works!"));
        // ... after exactly 1 attempt
        assert_eq!(call_count.load(Ordering::Relaxed), 1);
    }
}