Skip to content

Instantly share code, notes, and snippets.

@Ploppz
Last active August 30, 2023 13:03
Show Gist options
  • Save Ploppz/54e65eea580f17764c7560a3e7190141 to your computer and use it in GitHub Desktop.
Save Ploppz/54e65eea580f17764c7560a3e7190141 to your computer and use it in GitHub Desktop.
use std::task::{Context, Poll};
use futures_util::future::{ok, FutureExt, LocalBoxFuture, Ready};
use actix_web::{
body::{Body, ResponseBody},
dev::{Service, ServiceRequest, ServiceResponse, Transform},
error::{Error, Result},
};
use serde_json::json;
/// Middleware factory for [`ErrorHandlerMiddleware`], which rewrites the bodies of
/// error responses produced by the wrapped service as JSON.
///
/// The factory carries no configuration; `ErrorHandler::new()` and
/// `ErrorHandler::default()` produce the same value.
#[derive(Default)]
pub struct ErrorHandler {}

impl ErrorHandler {
    /// Creates the middleware factory (same value as `ErrorHandler::default()`).
    pub fn new() -> Self {
        Self {}
    }
}
impl<S> Transform<S> for ErrorHandler
where
S: Service<Request = ServiceRequest, Response = ServiceResponse, Error = Error>,
S::Future: 'static,
{
type Request = ServiceRequest;
type Response = ServiceResponse;
type Error = Error;
type InitError = ();
type Transform = ErrorHandlerMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(ErrorHandlerMiddleware { service })
}
}
/// Service produced by [`ErrorHandler`]: forwards every request to the wrapped
/// service `S` and rewrites the body of 4xx/5xx responses as JSON.
#[doc(hidden)]
pub struct ErrorHandlerMiddleware<S> {
    /// The inner (wrapped) service.
    service: S,
}

impl<S> Service for ErrorHandlerMiddleware<S>
where
    S: Service<Request = ServiceRequest, Response = ServiceResponse, Error = Error>,
    S::Future: 'static,
{
    type Request = ServiceRequest;
    type Response = ServiceResponse;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Readiness is delegated unchanged to the inner service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }

    /// Calls the inner service and post-processes its response.
    ///
    /// * An `Err(_)` from the inner service is propagated unchanged.
    /// * Responses whose status is not 4xx/5xx are returned untouched.
    /// * For 4xx/5xx responses the body is replaced by
    ///   `{"message": <text>, "code": <numeric status>}` and `Content-Type` is set
    ///   to `application/json`, where `<text>` is:
    ///   - the original body, if it was a byte buffer (`"<no message>"` if those
    ///     bytes are not valid UTF-8);
    ///   - the status line (e.g. `"404 Not Found"`), if the body was empty/absent.
    /// * `Body::Message` bodies are not converted: they are passed through with
    ///   headers untouched and a warning is emitted.
    fn call(&mut self, req: ServiceRequest) -> Self::Future {
        use actix_web::http::{header, HeaderValue};

        // Looked up before `req` is moved into the inner service below.
        let log = req.app_data::<slog::Logger>();
        let fut = self.service.call(req);

        async move {
            let res = fut.await?;
            let status = res.status();

            // Only rewrite genuine errors. 1xx/3xx responses (e.g. 101 Switching
            // Protocols, redirects, 304 Not Modified) must keep their bodies and
            // headers; a 304 in particular must not carry a body at all.
            if !(status.is_client_error() || status.is_server_error()) {
                return Ok(res);
            }

            Ok(res.map_body(move |head, old_body| {
                // `ResponseBody<Body>` holds a `Body` in both variants, so inspect
                // both rather than treating `Other` as "no body".
                let message = match &old_body {
                    ResponseBody::Body(body) | ResponseBody::Other(body) => match body {
                        Body::Bytes(bytes) => Some(
                            std::str::from_utf8(bytes)
                                .unwrap_or("<no message>")
                                .to_string(),
                        ),
                        Body::Empty | Body::None => Some(status.to_string()),
                        Body::Message(_) => None,
                    },
                };

                match message {
                    Some(message) => {
                        // Label the response as JSON only when the body really is JSON.
                        head.headers_mut().insert(
                            header::CONTENT_TYPE,
                            HeaderValue::from_static("application/json"),
                        );
                        ResponseBody::Body(Body::from(json!({
                            "message": message,
                            "code": status.as_u16(),
                        })))
                    }
                    None => {
                        // Body::Message conversion is not implemented; pass it through.
                        if let Some(log) = log {
                            slog::warn!(log, "ErrorHandler middleware: Body::Message case unimplemented (pass-through)");
                        } else {
                            // Warnings belong on stderr, not in the program's stdout.
                            eprintln!("WARNING: ErrorHandler middleware: Body::Message case unimplemented (pass-through)");
                        }
                        old_body
                    }
                }
            }))
        }
        .boxed_local()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment