Last active
January 16, 2023 04:09
-
-
Save juliarose/cec2b4495f7269f3dc6a0c6001c8a333 to your computer and use it in GitHub Desktop.
A way to retry futures whose error type implements the `Retryable` trait, which decides whether a failed request should be retried.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::time::Duration;

use futures_retry::{ErrorHandler, FutureRetry, RetryPolicy};
use rand::Rng;
/// Awaits a future, re-creating and retrying it according to [`RetryHandler`].
///
/// * `$x` — an expression producing the future. It is wrapped in a `move`
///   closure and re-evaluated for every attempt. Its error type must implement
///   [`Retryable`].
/// * `$d` — the attempt limit, passed as `max_attempts` to [`RetryHandler::new`].
/// * `$n` — a `Display` label used in the handler's log messages.
///
/// Evaluates to `Result<T, E>`. The attempt counts reported by `FutureRetry`
/// are discarded. The expansion contains `.await`, so the macro can only be
/// used inside an `async` context.
///
/// Paths are fully qualified so callers do not have to import `FutureRetry`
/// or `RetryHandler`. `$crate::RetryHandler` assumes the handler is defined at
/// the crate root; adjust the path if it moves into a module.
#[macro_export]
macro_rules! retry {
    ( $x:expr, $d:expr, $n:expr ) => {{
        ::futures_retry::FutureRetry::new(move || $x, $crate::RetryHandler::new($d, $n))
            .await
            .map(|(value, _attempts)| value)
            .map_err(|(e, _attempts)| e)
    }};
}
/// Classifies an error as transient (worth retrying) or permanent.
///
/// Implement this for any error type that should be driven through the retry
/// handler: `true` allows another attempt, `false` means give up.
pub trait Retryable {
    /// Reports whether the operation that produced `self` may be attempted again.
    fn retryable(&self) -> bool;
}
/// State for retrying a failing operation: an attempt limit, a counter that
/// drives a capped linear backoff, and a label for log output.
///
/// `D` is only used for display (e.g. the name of the operation being retried).
#[derive(Debug)]
pub struct RetryHandler<D> {
    /// Attempt limit the error handler compares against.
    max_attempts: usize,
    /// Failed-attempt counter used to compute the backoff delay.
    current_attempt: usize,
    /// Label printed in log messages.
    display_name: D,
}

impl<D> RetryHandler<D> {
    /// Creates a handler with the given attempt limit and log label.
    /// The backoff counter starts at zero.
    pub fn new(max_attempts: usize, display_name: D) -> Self {
        RetryHandler {
            max_attempts,
            current_attempt: 0,
            display_name,
        }
    }

    /// Delay before the next attempt: `2 * current_attempt` seconds, capped at
    /// `MAX_BACKOFF_SECS`.
    fn calculate_wait_duration(&self) -> Duration {
        const MAX_BACKOFF_SECS: u64 = 5;
        // Saturate rather than overflow: a plain `2 * n` panics in debug builds
        // for very large counters. The cast is lossless on targets where usize
        // is at most 64 bits.
        let backoff_secs = (self.current_attempt as u64)
            .saturating_mul(2)
            .min(MAX_BACKOFF_SECS);
        Duration::from_secs(backoff_secs)
    }
}
impl<D, T> ErrorHandler<T> for RetryHandler<D> | |
where | |
D: ::std::fmt::Display, | |
T: Retryable, | |
{ | |
type OutError = T; | |
fn handle( | |
&mut self, | |
current_attempt: usize, | |
e: T, | |
) -> RetryPolicy<T> { | |
if current_attempt > self.max_attempts { | |
eprintln!( | |
"[{}] All attempts ({}) have been used up", | |
self.display_name, self.max_attempts | |
); | |
return RetryPolicy::ForwardError(e); | |
} | |
eprintln!( | |
"[{}] Attempt {}/{} has failed", | |
self.display_name, current_attempt, self.max_attempts | |
); | |
self.current_attempt += 1; | |
if e.retryable() { | |
RetryPolicy::WaitRetry(self.calculate_wait_duration()) | |
} else { | |
RetryPolicy::ForwardError(e) | |
} | |
} | |
} | |
#[derive(Debug)] | |
pub enum Error { | |
IsOdd, | |
IsNegative, | |
} | |
impl Retryable for Error { | |
fn retryable(&self) -> bool { | |
match self { | |
Self::IsNegative => true, | |
Self::IsOdd => false, | |
} | |
} | |
} | |
async fn wait_for_number() -> Result<i32, Error> { | |
let mut rng = rand::thread_rng(); | |
let n: i32 = rng.gen(); | |
if n < 0 { | |
Err(Error::IsNegative) | |
} else if n % 2 != 0 { | |
Err(Error::IsOdd) | |
} else { | |
Ok(n) | |
} | |
} | |
#[tokio::main] | |
async fn main() { | |
let n = retry!(wait_for_number(), 3, "Retry...") | |
.unwrap(); | |
println!("The number is {}!", n) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment