Skip to content

Instantly share code, notes, and snippets.

@devsnek
Created October 23, 2023 23:07
Show Gist options
  • Save devsnek/244e9f606bd58175af41a18f1760657a to your computer and use it in GitHub Desktop.
Save devsnek/244e9f606bd58175af41a18f1760657a to your computer and use it in GitHub Desktop.
use std::{
future::Future,
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
};
use tower::{Layer, Service};
/// A [`Layer`] that runs `on_request` on every request before it reaches the inner service,
/// and hands the value it returned to `on_response` once the inner service produces a
/// successful response.
///
/// `State`, `Request` and `Response` only tie down the hook signatures; the layer never stores
/// values of those types.
pub struct StateLayer<OnRequest, OnResponse, State, Request, Response> {
    on_request: OnRequest,
    on_response: OnResponse,
    // `fn() -> (..)` keeps the type parameters covariant (as a plain `PhantomData<T>` would)
    // without claiming ownership of them. The layer's `Send`/`Sync`/`Unpin` therefore depend
    // only on the two hooks. For example, a `!Sync` request body type no longer makes the
    // layer `!Sync`.
    _phantom: PhantomData<fn() -> (State, Request, Response)>,
}
/// Clones the two hooks.
///
/// Written by hand because `#[derive(Clone)]` would also require `State: Clone`,
/// `Request: Clone` and `Response: Clone`. The layer only holds `PhantomData` for those types.
impl<OnRequest, OnResponse, State, Request, Response> Clone
    for StateLayer<OnRequest, OnResponse, State, Request, Response>
where
    OnRequest: Clone,
    OnResponse: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            on_request,
            on_response,
            ..
        } = self;
        Self {
            on_request: on_request.clone(),
            on_response: on_response.clone(),
            _phantom: Default::default(),
        }
    }
}
impl<OnRequest, OnResponse, State, Request, Response>
    StateLayer<OnRequest, OnResponse, State, Request, Response>
where
    OnRequest: Fn(&mut Request) -> State + Clone,
    OnResponse: Fn(&mut Response, State) + Clone,
{
    /// Builds a layer from a request hook and a response hook.
    ///
    /// `on_request` gets mutable access to each request and returns the `State` to carry
    /// across the call. `on_response` gets mutable access to a successful response together
    /// with that `State`.
    ///
    /// The bounds on this impl tie `State`, `Request` and `Response` to the hook signatures.
    /// This lets callers rely on inference instead of naming all five type parameters.
    pub fn new(on_request: OnRequest, on_response: OnResponse) -> Self {
        Self {
            _phantom: Default::default(),
            on_request,
            on_response,
        }
    }
}
/// Wraps an inner service in a [`StateService`] that owns its own clones of both hooks.
impl<S, OnRequest, OnResponse, State, Request, Response> Layer<S>
    for StateLayer<OnRequest, OnResponse, State, Request, Response>
where
    OnRequest: Clone,
    OnResponse: Clone,
{
    type Service = StateService<S, OnRequest, OnResponse>;

    fn layer(&self, inner: S) -> Self::Service {
        let on_request = self.on_request.clone();
        let on_response = self.on_response.clone();
        StateService {
            service: inner,
            on_request,
            on_response,
        }
    }
}
/// Service produced by [`StateLayer`].
///
/// It calls `on_request` before forwarding each request to `service`. If the call succeeds,
/// it passes the resulting state to `on_response`.
#[derive(Clone)]
pub struct StateService<S, OnRequest, OnResponse> {
    /// The wrapped inner service.
    service: S,
    /// Hook run on each request before it is forwarded; its return value is the carried state.
    on_request: OnRequest,
    /// Hook run on each successful response together with the state from `on_request`.
    on_response: OnResponse,
}
/// Runs `on_request` synchronously inside `call`.
///
/// It then returns a [`StateFuture`] that hands the captured state to `on_response` once the
/// inner future resolves successfully.
impl<S, OnRequest, OnResponse, State, Request, Response> Service<Request>
    for StateService<S, OnRequest, OnResponse>
where
    S: Service<Request, Response = Response>,
    OnRequest: Fn(&mut Request) -> State + Clone,
    OnResponse: Fn(&mut Response, State) + Clone,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = StateFuture<State, S::Future, OnResponse>;

    /// Readiness is delegated unchanged to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    fn call(&mut self, mut req: Request) -> Self::Future {
        // The hook sees (and may modify) the request before the inner service receives it.
        let captured = (self.on_request)(&mut req);
        StateFuture {
            state: Some(captured),
            future: self.service.call(req),
            // Each in-flight call gets its own copy of the response hook.
            on_response: self.on_response.clone(),
        }
    }
}
/// Future returned by [`StateService`]'s `call`.
///
/// It resolves to the inner future's output. A successful response is first passed through
/// `on_response` together with the stored state.
#[pin_project::pin_project]
pub struct StateFuture<State, Fut, OnResponse> {
    /// State captured by `on_request`; taken (left as `None`) when the inner future completes.
    state: Option<State>,
    /// The inner service's future; structurally pinned.
    #[pin]
    future: Fut,
    /// Hook applied to a successful response together with `state`.
    on_response: OnResponse,
}
/// Polls the inner future. On `Ok`, runs `on_response` on the response together with the
/// state captured at call time, then yields the (possibly modified) result.
///
/// The stored state is taken as soon as the inner future completes. If the inner call
/// returns `Err`, the state is dropped without running `on_response`.
impl<State, Fut, OnResponse, Response, Error> Future for StateFuture<State, Fut, OnResponse>
where
    Fut: Future<Output = Result<Response, Error>>,
    OnResponse: Fn(&mut Response, State),
{
    type Output = Result<Response, Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let mut result = futures::ready!(this.future.poll(cx));
        // Take the state unconditionally so it is consumed exactly once, even on error.
        // It can only already be `None` if the inner future yields again after completing.
        // In that case the hook is not re-run.
        if let (Ok(response), Some(state)) = (&mut result, this.state.take()) {
            (this.on_response)(response, state);
        }
        Poll::Ready(result)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment