Skip to content

Instantly share code, notes, and snippets.

@DefectingCat
Last active June 5, 2024 14:40
Show Gist options
  • Save DefectingCat/871f8887b23bace80d248b38c82e3aa8 to your computer and use it in GitHub Desktop.
Save DefectingCat/871f8887b23bace80d248b38c82e3aa8 to your computer and use it in GitHub Desktop.
axum routes nest and trace layer
use std::{collections::HashMap, time::Duration};
use axum::{
async_trait,
body::Bytes,
extract::{FromRequestParts, MatchedPath, Path, Request},
http::{request::Parts, HeaderMap, StatusCode, Uri},
middleware,
response::{IntoResponse, Response},
routing::get,
RequestPartsExt, Router,
};
use tower::ServiceBuilder;
use tower_http::{
classify::ServerErrorsFailureClass, compression::CompressionLayer, cors::CorsLayer,
timeout::TimeoutLayer, trace::TraceLayer,
};
use tracing::{error, info, info_span, Span};
use crate::{middlewares::add_version, AppState};
pub mod users;
/// 注册所有的路由以及中间件
pub fn routes() -> Router<AppState> {
Router::new()
.route("/", get(hello).post(hello))
.nest("/:version/", users::routes())
.fallback(fallback)
.layer(
ServiceBuilder::new()
.layer(middleware::from_fn(add_version))
.layer(CorsLayer::permissive())
.layer(TimeoutLayer::new(Duration::from_secs(15)))
.layer(CompressionLayer::new()),
)
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &Request<_>| {
let matched_path = request
.extensions()
.get::<MatchedPath>()
.map(MatchedPath::as_str);
info_span!(
"http_request",
method = ?request.method(),
matched_path,
some_other_field = tracing::field::Empty,
)
})
.on_request(|req: &Request<_>, _span: &Span| {
info!("{} {}", req.method(), req.uri());
})
.on_response(|res: &Response, latency: Duration, _span: &Span| {
info!("{} {}ms", res.status(), latency.as_millis());
})
.on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {})
.on_eos(
|_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| {},
)
.on_failure(
|error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
error!("{}", error);
},
),
)
}
/// Greeting handler ("hello world").
///
/// Responds with `hello <package name>`. The package name comes from `CARGO_PKG_NAME`, which
/// `env!` reads at compile time.
pub async fn hello() -> String {
    let package_name = env!("CARGO_PKG_NAME");
    format!("hello {}", package_name)
}
/// Catch-all handler for requests that match no registered route.
///
/// Logs the unmatched URI at `info` level and replies with status `404 Not Found` and the
/// plain body `"Not found"`.
///
/// # Arguments
///
/// - `uri`: the request URI that did not match any route.
///
/// # Returns
///
/// The tuple `(StatusCode::NOT_FOUND, "Not found")`, which axum turns into the response via
/// `IntoResponse`.
pub async fn fallback(uri: Uri) -> impl IntoResponse {
    info!("route {} not found", uri);
    let body = "Not found";
    (StatusCode::NOT_FOUND, body)
}
/// API version selected by the `version` path parameter.
///
/// The `FromRequestParts` impl in this file maps the exact, case-sensitive strings `v1`, `v2`
/// and `v3` to these variants.
///
/// `Clone`/`Copy`/`PartialEq`/`Eq`/`Hash` are derived so handlers can compare versions
/// (`version == Version::V2`), pass them by value, and use them as map/set keys. Previously
/// only `Debug` was available, so the only way to inspect a version was a `match`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Version {
    /// Path segment `v1`.
    V1,
    /// Path segment `v2`.
    V2,
    /// Path segment `v3`.
    V3,
}
#[async_trait]
impl<S> FromRequestParts<S> for Version
where
S: Send + Sync,
{
type Rejection = Response;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let params: Path<HashMap<String, String>> =
parts.extract().await.map_err(IntoResponse::into_response)?;
let version = params
.get("version")
.ok_or_else(|| (StatusCode::NOT_FOUND, "version param missing").into_response())?;
match version.as_str() {
"v1" => Ok(Version::V1),
"v2" => Ok(Version::V2),
"v3" => Ok(Version::V3),
_ => Err((StatusCode::NOT_FOUND, "unknown version").into_response()),
}
}
}
pub fn routes() -> Router<AppState> {
Router::new()
.route("/", get(hello).post(hello))
.nest("/:version/", users::routes())
.fallback(fallback)
.layer(
ServiceBuilder::new()
.layer(middleware::from_fn(add_version))
.layer(CorsLayer::permissive())
.layer(TimeoutLayer::new(Duration::from_secs(15)))
.layer(CompressionLayer::new()),
)
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &Request<_>| {
let matched_path = request
.extensions()
.get::<MatchedPath>()
.map(MatchedPath::as_str);
info_span!(
"http_request",
method = ?request.method(),
matched_path,
some_other_field = tracing::field::Empty,
)
})
.on_request(|req: &Request<_>, _span: &Span| {
info!("{} {}", req.method(), req.uri());
})
.on_response(|res: &Response, latency: Duration, _span: &Span| {
info!("{} {}ms", res.status(), latency.as_millis());
})
.on_body_chunk(|_chunk: &Bytes, _latency: Duration, _span: &Span| {})
.on_eos(
|_trailers: Option<&HeaderMap>, _stream_duration: Duration, _span: &Span| {},
)
.on_failure(
|error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
error!("{}", error);
},
),
)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment