Created
April 10, 2024 16:55
-
-
Save kotobukid/98507bf6ee4bf389cb1c2458cfad0bf6 to your computer and use it in GitHub Desktop.
Axumでミドルウェアでセットした値をリクエストハンドラで取得する
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::net::SocketAddr; | |
use std::sync::Arc; | |
use axum; | |
use axum::response::IntoResponse; | |
use axum::{routing::{get, post}, Router, response::Html, Json, http::StatusCode, extract::Form, Extension}; | |
use tower_http::trace::TraceLayer; | |
use axum::{ | |
body::{Body, Bytes}, | |
extract::Request, | |
middleware::{self, Next}, | |
response::{Response}, | |
}; | |
use http_body_util::BodyExt; | |
use tokio::sync::Mutex; | |
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; | |
type Counter = Arc<Mutex<i64>>; | |
#[tokio::main] | |
async fn main() { | |
let counter: Counter = Arc::new(Mutex::new(0_i64)); | |
let counter_layer: Extension<Counter> = Extension(counter); | |
let app = Router::new().nest("/private", Router::new() | |
.route("/", get(hello_handler)) | |
.layer(middleware::from_fn(print_request_response)) | |
.layer(counter_layer), | |
); | |
let port = 3000; | |
let addr = SocketAddr::from(([127, 0, 0, 1], port)); | |
let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); | |
tracing::debug!("listening on {}", listener.local_addr().unwrap()); | |
axum::serve(listener, app.layer(TraceLayer::new_for_http())) | |
.await | |
.unwrap(); | |
} | |
async fn hello_handler(Extension(counter): Extension<Counter>) -> impl IntoResponse { | |
println!("hello handler"); | |
let c = counter.lock().await; | |
println!("count is {}", *c); | |
(StatusCode::OK, "hello hello!!") | |
} | |
async fn print_request_response( | |
Extension(counter): Extension<Counter>, | |
req: Request, | |
next: Next, | |
) -> Result<impl IntoResponse, (StatusCode, String)> { | |
let (parts, body) = req.into_parts(); | |
let bytes = buffer_and_print("request", body).await?; | |
let mut req = Request::from_parts(parts, Body::from(bytes)); | |
let mut count = counter.lock().await; | |
*count += 1; | |
req.extensions_mut().insert(Arc::new(Mutex::new(*count))); | |
let res = next.run(req).await; | |
println!("{:?}", res); | |
let (parts, body) = res.into_parts(); | |
let bytes = buffer_and_print("response", body).await?; | |
let res = Response::from_parts(parts, Body::from(bytes)); | |
Ok(res) | |
} | |
async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, (StatusCode, String)> | |
where | |
B: axum::body::HttpBody<Data=Bytes>, | |
B::Error: std::fmt::Display, | |
{ | |
let bytes = match body.collect().await { | |
Ok(collected) => collected.to_bytes(), | |
Err(err) => { | |
return Err(( | |
StatusCode::BAD_REQUEST, | |
format!("failed to read {direction} body: {err}"), | |
)); | |
} | |
}; | |
if let Ok(body) = std::str::from_utf8(&bytes) { | |
tracing::debug!("{direction} body = {body:?}"); | |
} | |
Ok(bytes) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment