Skip to content

Instantly share code, notes, and snippets.

@Swoorup
Created October 2, 2022 16:28
Show Gist options
  • Save Swoorup/41dfcba8d712fb5fa03a38119a1d7088 to your computer and use it in GitHub Desktop.
Save Swoorup/41dfcba8d712fb5fa03a38119a1d7088 to your computer and use it in GitHub Desktop.
Restartable stream: re-runs a user-supplied connect function whenever the inner stream completes; a failed connect attempt ends the stream.
use std::task::Poll;
use futures::{Future, FutureExt, Stream};
use pin_project::pin_project;
use tracing::debug;
/// Internal lifecycle of a [`Restartable`] stream.
#[pin_project(project = StreamStateProj)]
#[derive(Debug, Clone)]
enum StreamState<F: Future, S> {
    /// Nothing in flight: either never connected, or the previous inner
    /// stream has finished.
    NotConnected,
    /// A future returned by the connect function is still being awaited.
    Connecting {
        #[pin]
        reconnect_attempt: F,
    },
    /// The inner stream is live and is polled for items.
    Connected(#[pin] S),
}
impl<F, S> Default for StreamState<F, S>
where
F: Future,
{
fn default() -> Self {
Self::NotConnected
}
}
/// A stream that (re)connects on demand using a connect function.
///
/// Starting from the not-connected state, polling calls
/// `create(param.clone())` and awaits the returned future. An `Ok` stream is
/// then polled for items. Each time that inner stream completes, a new
/// connection is made with another `create(param.clone())` call.
///
/// If a connection attempt resolves to `Err`, this stream ends (yields
/// `None`) and the error value is discarded. Wrap the connect function in a
/// retry policy if connect errors should be retried.
#[pin_project]
#[derive(Debug, Clone)]
pub struct Restartable<Param, F, S, E, Cons>
where
    Param: Clone,
    F: Future<Output = Result<S, E>>,
    S: Stream,
    Cons: Fn(Param) -> F,
{
    /// Argument cloned and passed to `create` on every connection attempt.
    param: Param,
    /// Connect function; its future resolves to the inner stream or an error.
    create: Cons,
    /// Current connection state (not connected / connecting / connected).
    #[pin]
    stream: StreamState<F, S>,
}
impl<Param, F, S, E, Cons> Restartable<Param, F, S, E, Cons>
where
Param: Clone,
F: Future<Output = Result<S, E>>,
S: Stream,
Cons: Fn(Param) -> F,
{
pub fn of(establish: Cons, param: Param) -> Self {
Self {
param: param,
create: establish,
stream: Default::default(),
}
}
}
impl<Param, F, S, E, Cons> Stream for Restartable<Param, F, S, E, Cons>
where
    Param: Clone,
    F: Future<Output = Result<S, E>>,
    S: Stream,
    Cons: Fn(Param) -> F,
{
    type Item = S::Item;

    /// Drives the connect -> stream -> reconnect state machine.
    ///
    /// Transitions that make forward progress are followed within the same
    /// call, with no self-wake round trip:
    /// - starting a connection attempt, whose future is polled immediately;
    /// - receiving a connected stream, which is polled immediately.
    ///
    /// When the inner stream completes, the state is reset and the task wakes
    /// itself and returns `Pending`. Reconnecting therefore happens on the next
    /// poll, so an inner stream that finishes instantly cannot spin forever
    /// inside one `poll_next` call.
    ///
    /// A connection attempt that resolves to `Err` ends the stream with
    /// `Ready(None)`. The state is reset to `NotConnected` first, so a caller
    /// that polls again starts a fresh attempt instead of re-polling the
    /// already-finished connect future.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let mut me = self.project();
        loop {
            match me.stream.as_mut().project() {
                StreamStateProj::NotConnected => {
                    debug!("Establishing connection...");
                    let reconnect_attempt = (me.create)(me.param.clone());
                    me.stream.set(StreamState::Connecting { reconnect_attempt });
                    // Loop around to poll the new attempt right away.
                }
                StreamStateProj::Connecting { ref mut reconnect_attempt } => {
                    debug!("Connecting");
                    match reconnect_attempt.poll_unpin(cx) {
                        // Loop around to poll the freshly connected stream.
                        Poll::Ready(Ok(stream)) => me.stream.set(StreamState::Connected(stream)),
                        Poll::Ready(Err(_)) => {
                            // Drop the finished connect future before ending.
                            me.stream.set(StreamState::NotConnected);
                            return Poll::Ready(None);
                        }
                        Poll::Pending => return Poll::Pending,
                    }
                }
                StreamStateProj::Connected(stream) => match stream.poll_next(cx) {
                    // Inner stream finished: reset and reconnect on the next poll.
                    Poll::Ready(None) => {
                        me.stream.set(StreamState::NotConnected);
                        cx.waker().wake_by_ref();
                        return Poll::Pending;
                    }
                    // Items and `Pending` from the inner stream pass straight through.
                    other => return other,
                },
            }
        }
    }

    /// Reports the items guaranteed by the live inner stream, if any.
    ///
    /// There is never an upper bound, because a completed inner stream is
    /// replaced by a freshly connected one.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match &self.stream {
            StreamState::Connected(stream) => (stream.size_hint().0, None),
            StreamState::NotConnected | StreamState::Connecting { .. } => (0, None),
        }
    }
}
@Swoorup
Copy link
Author

Swoorup commented Oct 3, 2022

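Example usage. The inner stream yields 0, 1, 2 and then completes, so `take(4)` forces one restart and prints 0, 1, 2, 0: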

#![feature(async_closure)]
#![feature(let_chains)]
#![feature(trait_alias)]
use std::time::Duration;

use my_lib::util::stream::Restartable;
use futures::{pin_mut, Stream};
use tokio_retry::strategy::{jitter, ExponentialBackoff};
use tokio_retry::Retry;
use tokio_stream::StreamExt;

/// Builds a demo inner stream that yields 0, 1, 2 and then completes.
///
/// Prints the current time on every call, which makes each reconnect visible.
async fn create_stream() -> Result<impl Stream<Item = u32>, ()> {
  use async_stream::stream;
  // Standard-library clock: no extra crate or import needed for the demo.
  let ts = std::time::SystemTime::now();
  println!("Current timestamp: {ts:#?}");

  Ok(stream! {
      for i in 0..3 {
          yield i;
      }
  })
}

/// Connect function for `Restartable`: calls `create_stream`, retrying with
/// jittered exponential backoff (at most 3 retries, each delay capped at 2 s).
///
/// The `()` argument is the `param` that `Restartable` clones into every call.
async fn create_with_retry(_: ()) -> Result<impl Stream<Item = u32>, ()> {
  let backoff = ExponentialBackoff::from_millis(2)
    .max_delay(Duration::from_secs(2))
    .map(jitter)
    .take(3); // at most 3 retries
  Retry::spawn(backoff, create_stream).await
}

#[tokio::main]
async fn main() {
  // Each inner stream yields 0, 1, 2 then completes, so four items
  // require one restart of the underlying stream.
  let restarted = Restartable::of(create_with_retry, ()).take(4);
  pin_mut!(restarted);
  loop {
    match restarted.next().await {
      Some(item) => println!("got {}", item),
      None => break,
    }
  }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment