Skip to content

Instantly share code, notes, and snippets.

@kotobukid
Created April 10, 2024 16:55
Show Gist options
  • Save kotobukid/98507bf6ee4bf389cb1c2458cfad0bf6 to your computer and use it in GitHub Desktop.
Save kotobukid/98507bf6ee4bf389cb1c2458cfad0bf6 to your computer and use it in GitHub Desktop.
Axumでミドルウェアでセットした値をリクエストハンドラで取得する
use std::net::SocketAddr;
use std::sync::Arc;
use axum;
use axum::response::IntoResponse;
use axum::{routing::{get, post}, Router, response::Html, Json, http::StatusCode, extract::Form, Extension};
use tower_http::trace::TraceLayer;
use axum::{
body::{Body, Bytes},
extract::Request,
middleware::{self, Next},
response::{Response},
};
use http_body_util::BodyExt;
use tokio::sync::Mutex;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
type Counter = Arc<Mutex<i64>>;
#[tokio::main]
async fn main() {
let counter: Counter = Arc::new(Mutex::new(0_i64));
let counter_layer: Extension<Counter> = Extension(counter);
let app = Router::new().nest("/private", Router::new()
.route("/", get(hello_handler))
.layer(middleware::from_fn(print_request_response))
.layer(counter_layer),
);
let port = 3000;
let addr = SocketAddr::from(([127, 0, 0, 1], port));
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
tracing::debug!("listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app.layer(TraceLayer::new_for_http()))
.await
.unwrap();
}
/// Handler for the nested `/private` route.
///
/// Reads the `Counter` found in the request extensions, prints its value to stdout and
/// answers `200 OK` with a fixed text body.
///
/// NOTE(review): `print_request_response` re-inserts a fresh `Counter` holding the
/// incremented value. Request extensions are keyed by type, so this handler most likely
/// sees that per-request copy rather than the shared counter.
async fn hello_handler(Extension(counter): Extension<Counter>) -> impl IntoResponse {
    println!("hello handler");
    // Copy the value out; the guard is a temporary and is released at the end of this statement.
    let snapshot: i64 = *counter.lock().await;
    println!("count is {}", snapshot);
    (StatusCode::OK, "hello hello!!")
}
/// Middleware that logs request/response bodies and bumps the shared request counter.
///
/// For every request it:
/// 1. buffers the request body via `buffer_and_print` and rebuilds the request around the bytes;
/// 2. increments the shared `Counter`, then inserts a *new* `Counter` holding the incremented
///    value into the request extensions. Extensions are keyed by type, so this replaces the
///    shared counter for everything downstream (e.g. `hello_handler`);
/// 3. runs the rest of the stack, prints the response with `{:?}`, buffers the response body
///    via `buffer_and_print` and returns a response rebuilt around those bytes.
///
/// # Errors
/// Propagates the `(StatusCode, String)` error from `buffer_and_print` if either body
/// cannot be collected.
async fn print_request_response(
    Extension(counter): Extension<Counter>,
    req: Request,
    next: Next,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    let (parts, body) = req.into_parts();
    let bytes = buffer_and_print("request", body).await?;
    let mut req = Request::from_parts(parts, Body::from(bytes));

    // Increment inside a short scope so the guard is dropped before `next.run(..)`.
    // Previously the guard lived until the end of the function, i.e. across the whole
    // downstream handler and the response-body buffering, which serialized all requests
    // passing through this middleware.
    let count = {
        let mut guard = counter.lock().await;
        *guard += 1;
        *guard
    };
    req.extensions_mut()
        .insert::<Counter>(Arc::new(Mutex::new(count)));

    let res = next.run(req).await;
    println!("{:?}", res);

    let (parts, body) = res.into_parts();
    let bytes = buffer_and_print("response", body).await?;
    Ok(Response::from_parts(parts, Body::from(bytes)))
}
/// Collects `body` into a single [`Bytes`] buffer and logs it.
///
/// `direction` is used only in messages (the call sites pass `"request"` / `"response"`).
/// When the collected bytes are valid UTF-8 they are logged at `debug` level through
/// `tracing`. Other payloads are returned without being logged.
///
/// # Errors
/// If collecting the body fails, returns `StatusCode::BAD_REQUEST` with a message that names
/// `direction` and includes the underlying error. The same status is used for the
/// `"response"` direction as well.
async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, (StatusCode, String)>
where
    B: axum::body::HttpBody<Data = Bytes>,
    B::Error: std::fmt::Display,
{
    let collected = body.collect().await.map_err(|err| {
        (
            StatusCode::BAD_REQUEST,
            format!("failed to read {direction} body: {err}"),
        )
    })?;
    let bytes = collected.to_bytes();

    // Only text payloads are logged; binary bodies pass through silently.
    if let Some(body) = std::str::from_utf8(&bytes).ok() {
        tracing::debug!("{direction} body = {body:?}");
    }

    Ok(bytes)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment