Skip to content

Instantly share code, notes, and snippets.

@cletustboone
Last active July 25, 2022 21:23
Show Gist options
  • Save cletustboone/036409522c609f460004d6708c6f4062 to your computer and use it in GitHub Desktop.
Save cletustboone/036409522c609f460004d6708c6f4062 to your computer and use it in GitHub Desktop.
use std::{error::Error, fmt::Display, task::Poll, time::Duration};
use futures::Future;
use pin_project::pin_project;
use tokio::time::{sleep, Sleep};
use tower::{BoxError, Layer, Service};
use tracing::{instrument, warn};
use crate::gateway::GatewayRequest;
/// Tower middleware wrapping an `inner` service.
///
/// Its `Service` impl adds an artificial delay and rejects requests whose `req`
/// field is `"foo"`; other requests get the inner service's response.
#[derive(Clone, Debug)]
pub struct CapabilityCheck<S> {
    /// The wrapped service whose response is returned for requests that pass.
    pub inner: S,
}

impl<S> CapabilityCheck<S> {
    /// Wraps `inner` in the capability-check middleware.
    pub fn new(inner: S) -> Self {
        CapabilityCheck { inner }
    }
}
/// Layer that wraps a service in a `CapabilityCheck`.
///
/// The layer carries no state, so it derives `Default` (satisfying clippy's
/// `new_without_default`) as well as `Clone`/`Copy`/`Debug`, which builders and
/// logging commonly expect of layers.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapabilityCheckLayer {}

impl CapabilityCheckLayer {
    /// Creates the layer; equivalent to `CapabilityCheckLayer::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
impl<S> Layer<S> for CapabilityCheckLayer {
type Service = CapabilityCheck<S>;
fn layer(&self, inner: S) -> Self::Service {
CapabilityCheck::new(inner)
}
}
// Implement tower::Service for CapabilityCheck.
// It introduces a fake async delay, then checks the request, and only if the
// request passes the capability check does it poll the inner service's response.
impl<S> Service<GatewayRequest> for CapabilityCheck<S>
where
    S: Service<GatewayRequest> + std::fmt::Debug,
    GatewayRequest: std::fmt::Debug,
    S::Error: Into<BoxError>,
{
    type Response = S::Response;
    type Error = BoxError;
    type Future = CapabilityCheckFuture<S::Future>;

    /// Reports ready only when the inner service is ready.
    ///
    /// `call` invokes `inner.call` synchronously, so the inner service's readiness
    /// (and any readiness error, converted to `BoxError`) must be forwarded rather
    /// than unconditionally claiming `Ready(Ok(()))`.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    /// Starts the inner request and returns a future that first waits 200 ms,
    /// then checks `req`, and only then yields the inner response.
    #[instrument(name = "cap check")]
    fn call(&mut self, req: GatewayRequest) -> Self::Future {
        let delay = sleep(Duration::from_millis(200));
        // NOTE(review): `inner.call` runs here, before the delay and the check; for a
        // rejected request the inner response future is simply never polled. Deferring
        // the call itself would require storing a clone of `inner` in the future.
        let response_future = self.inner.call(req.clone());
        CapabilityCheckFuture {
            response_future,
            delay,
            req,
        }
    }
}
/// Future returned by `CapabilityCheck::call`.
///
/// When polled it waits on `delay`, then checks `req`; `response_future` is only
/// polled if the check passes.
#[pin_project]
pub struct CapabilityCheckFuture<F> {
// The inner service's response future, created eagerly in `call`.
#[pin]
response_future: F,
// Artificial delay that must elapse before the check runs.
#[pin]
delay: Sleep,
// The request being checked; unpinned, so the projection yields a plain `&mut`.
req: GatewayRequest,
}
impl<F, Response, Error> Future for CapabilityCheckFuture<F>
where
    F: Future<Output = Result<Response, Error>>,
    Error: Into<BoxError>,
{
    type Output = Result<Response, BoxError>;

    /// Waits for the delay, rejects `"foo"` requests with `CapabilityError`, and
    /// otherwise forwards the inner future's result (errors converted to `BoxError`).
    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = self.project();

        // Nothing happens until the artificial delay has elapsed. On later wake-ups
        // (while the inner future is still pending) the delay is polled again, which
        // relies on `Sleep` continuing to report `Ready` once its deadline has passed.
        if this.delay.poll(cx).is_pending() {
            return Poll::Pending;
        }

        // Now that we've waited, let's check the request. No "foo" for you!
        // Comparing against the literal directly avoids allocating a `String` per poll.
        if this.req.req == "foo" {
            warn!(request = %this.req.req, "failed capability check");
            let error = Box::new(CapabilityError(()));
            return Poll::Ready(Err(error));
        }

        // The request passed: the inner future's readiness is now our readiness.
        this.response_future.poll(cx).map_err(Into::into)
    }
}
/// Error produced when a request fails the capability check.
///
/// Its `Display` text is what gets given back to the client.
#[derive(Default, Debug)]
pub struct CapabilityError(());

impl Display for CapabilityError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegating to `str`'s `Display` pads the message, so width/alignment
        // flags from the format string are honored.
        Display::fmt("incapable", f)
    }
}

/// Every `Error` method keeps its default. Implementing the trait is what lets
/// this type be boxed into an outer service's `BoxError`.
impl Error for CapabilityError {}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment