Skip to content

Instantly share code, notes, and snippets.

@kyle-mccarthy
Created July 19, 2023 01:26
Show Gist options
  • Save kyle-mccarthy/73ab6c78e6d3bf0819fc7c00b90161f4 to your computer and use it in GitHub Desktop.
Save kyle-mccarthy/73ab6c78e6d3bf0819fc7c00b90161f4 to your computer and use it in GitHub Desktop.
Tonic bidirectional stream client utility
[package]
name = "tonic-stream-wrapper"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["sync"] }
tonic = { version = "0.9" }
futures = { version = "0.3" }
tracing = { version = "0.1" }
thiserror = "1.0"
use std::future::Future;
use futures::{future, Stream, StreamExt, TryStreamExt};
use tokio::{
spawn,
sync::{broadcast, mpsc, oneshot},
task::{JoinError, JoinHandle},
};
use tokio_stream::wrappers::{BroadcastStream, ReceiverStream};
use tonic::{Response, Status, Streaming};
use tracing::{error, warn};
#[derive(Debug, thiserror::Error, Clone)]
pub enum Error {
#[error("gRPC error: {0}")]
Status(#[from] tonic::Status),
#[error("Stream closed")]
StreamClosed,
}
struct RequestWrapper<T> {
inner: T,
request_id_sender: oneshot::Sender<usize>,
}
/// An incoming item tagged with the id of the request it corresponds to.
///
/// `Clone` is required because values of this type travel over a `broadcast` channel, which
/// hands a copy to every subscriber.
#[derive(Clone)]
struct ResponseWrapper<T> {
    /// The payload, as received from the incoming stream.
    inner: T,
    /// Position of this item in the incoming stream; matched against the request's id.
    request_id: usize,
}
pub struct OrderedStream<Req, Res> {
request_sender: mpsc::Sender<RequestWrapper<Req>>,
response_sender: broadcast::Sender<ResponseWrapper<Result<Res, Status>>>,
background_pipeline: JoinHandle<()>,
}
impl<Req, Res> OrderedStream<Req, Res>
where
Res: Send + Clone + 'static,
Req: Send + 'static,
{
pub fn new<F, O>(request_fn: F) -> Self
where
F: FnOnce(Box<dyn Stream<Item = Req> + Send>) -> O,
F: Send + 'static,
O: Future<Output = Result<Response<Streaming<Res>>, Status>> + Send,
{
let (request_sender, request_receiver) = mpsc::channel::<RequestWrapper<Req>>(32);
let (response_sender, _response_receiver) =
broadcast::channel::<ResponseWrapper<Result<Res, Status>>>(32);
let background_pipeline = spawn({
let response_sender = response_sender.clone();
async move {
let outgoing = ReceiverStream::new(request_receiver).enumerate().map(
|(
request_id,
RequestWrapper {
inner,
request_id_sender,
},
)| {
// We don't care about the result here, an error just means that the receiver side
// of the oneshot is gone.
let _ = request_id_sender.send(request_id);
inner
},
);
let incoming = match request_fn(Box::new(outgoing)).await {
Err(status) => {
error!(
"Stream initialization returned an error, exiting early. (status = {:?})",
status
);
return;
}
Ok(response) => response.into_inner(),
};
incoming
.enumerate()
.map(|(request_id, inner)| ResponseWrapper { request_id, inner })
.for_each(|response| {
if response_sender.send(response).is_err() {
warn!("Response received but no receivers on the response channel");
}
future::ready(())
})
.await;
}
});
Self {
request_sender,
response_sender,
background_pipeline,
}
}
pub async fn send(&self, message: Req) -> Result<Res, Error> {
let (sender, receiver) = oneshot::channel::<usize>();
let response_receiver = self.response_sender.subscribe();
self.request_sender
.send(RequestWrapper {
inner: message,
request_id_sender: sender,
})
.await
.map_err(|_| Error::StreamClosed)?;
let request_id = receiver.await.map_err(|_| Error::StreamClosed)?;
let mut response_stream = BroadcastStream::new(response_receiver)
.filter_map(|result| {
let value = match result {
Ok(wrapper) if wrapper.request_id == request_id => Some(wrapper.inner),
_ => None,
};
future::ready(value)
})
.map_err(Error::Status)
.take(1)
.fuse();
response_stream
.next()
.await
.expect("Every request has a response")
}
pub async fn close(self) -> Result<(), JoinError> {
let Self {
background_pipeline,
request_sender,
..
} = self;
// need to drop the request sender or awaiting the join handle will hang
drop(request_sender);
background_pipeline.await
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment