Skip to content

Instantly share code, notes, and snippets.

@afpro
Last active April 17, 2024 12:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save afpro/7c76ab62e68d73150d9edcc9344c4ece to your computer and use it in GitHub Desktop.
Save afpro/7c76ab62e68d73150d9edcc9344c4ece to your computer and use it in GitHub Desktop.
rust http server
use std::{
borrow::Cow,
fmt::Debug,
future::Future,
net::SocketAddr,
pin::Pin,
task::{ready, Context, Poll},
time::Instant,
};
use axum::{
extract::{ConnectInfo, Request},
response::Response,
};
use pin_project::{pin_project, pinned_drop};
use tower::Service;
use tower_layer::Layer;
use tracing::{error, info, span, warn, Level, Span};
use uuid::Uuid;
/// Tower layer that wraps a service in [`AccessLogService`], which logs the start and the
/// end of every request under the `"request"` tracing target.
#[derive(Copy, Clone)]
pub struct AccessLog;

impl<S> Layer<S> for AccessLog {
    type Service = AccessLogService<S>;

    /// Wraps `service` without any additional configuration.
    fn layer(&self, service: S) -> Self::Service {
        AccessLogService { inner: service }
    }
}
/// Service produced by [`AccessLog`]: forwards every request to `inner` and logs it.
#[derive(Clone)]
pub struct AccessLogService<S> {
    inner: S,
}

impl<S> AccessLogService<S> {
    /// Returns the peer address stored in the `ConnectInfo<SocketAddr>` request extension.
    ///
    /// Falls back to `"unknown"` when the extension is absent, for example when the router
    /// was not served with `into_make_service_with_connect_info::<SocketAddr>()`.
    fn extract_remote<B>(req: &Request<B>) -> Cow<'static, str> {
        req.extensions()
            .get::<ConnectInfo<SocketAddr>>()
            .map_or(Cow::Borrowed("unknown"), |ConnectInfo(peer)| {
                Cow::Owned(peer.to_string())
            })
    }
}
impl<S, Req, Resp> Service<Request<Req>> for AccessLogService<S>
where
    S: Service<Request<Req>, Response = Response<Resp>>,
    S::Future: Future<Output = Result<Response<Resp>, S::Error>>,
    S::Error: Debug,
{
    type Response = Response<Resp>;
    type Error = S::Error;
    type Future = AccessLogServiceFuture<S::Future>;

    /// Readiness is delegated unchanged to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Opens a `request` span tagged with a fresh UUID and logs a `"begin"` event inside it
    /// (remote address, method, URI and headers). It then forwards the request; the returned
    /// future logs completion within the same span.
    fn call(&mut self, req: Request<Req>) -> Self::Future {
        let request_span = span!(Level::INFO, "request", id = %Uuid::new_v4().simple());
        request_span.in_scope(|| {
            info!(
                target: "request",
                remote = %Self::extract_remote(&req),
                method = %req.method(),
                uri = %req.uri(),
                headers = ?req.headers(),
                "begin",
            );
        });
        let pending = self.inner.call(req);
        AccessLogServiceFuture::new(request_span, pending)
    }
}
/// Future returned by [`AccessLogService::call`].
///
/// It drives the inner service future inside the request span and logs how the request
/// ended, either on completion or on drop.
#[pin_project(PinnedDrop)]
pub struct AccessLogServiceFuture<F> {
    // Span created in `call`; entered on every poll and in `drop`.
    span: Span,
    // Becomes true once completion has been logged, so `drop` stays quiet afterwards.
    done: bool,
    // Construction time; the logged `cost` is measured from here.
    start: Instant,
    // The wrapped service future.
    #[pin]
    inner: F,
}

impl<F> AccessLogServiceFuture<F> {
    /// Wraps `inner` and starts the clock used for the `cost` field.
    fn new(span: Span, inner: F) -> Self {
        let start = Instant::now();
        Self {
            span,
            done: false,
            start,
            inner,
        }
    }
}
impl<F, B, E> Future for AccessLogServiceFuture<F>
where
    F: Future<Output = Result<Response<B>, E>>,
    E: Debug,
{
    type Output = Result<Response<B>, E>;

    /// Polls the inner future inside the request span.
    ///
    /// When the inner future completes, it logs the outcome exactly once:
    /// - `info` for non-error statuses;
    /// - `error` for 4xx/5xx statuses and for service errors.
    ///
    /// It also marks the future as done so `drop` does not report a premature disconnect.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let _guard = this.span.enter();
        let result = ready!(this.inner.poll(cx));
        if !*this.done {
            *this.done = true;
            let cost_ms = this.start.elapsed().as_millis();
            match &result {
                Ok(response) => {
                    let status = response.status();
                    // 4xx and 5xx responses are reported at error level.
                    if status.is_client_error() || status.is_server_error() {
                        error!(
                            target: "request",
                            status = status.as_u16(),
                            cost = cost_ms,
                            "end with error status"
                        );
                    } else {
                        info!(
                            target: "request",
                            status = status.as_u16(),
                            cost = cost_ms,
                            "end ok"
                        );
                    }
                }
                Err(err) => {
                    error!(
                        target: "request",
                        cost = cost_ms,
                        "end with uncaught error {:?}", err
                    );
                }
            }
        }
        Poll::Ready(result)
    }
}
/// Logs a warning when the future is dropped before its response was produced.
///
/// This happens, for example, when the client disconnects mid-request.
#[pinned_drop]
impl<F> PinnedDrop for AccessLogServiceFuture<F> {
    fn drop(self: Pin<&mut Self>) {
        if self.done {
            return;
        }
        let _entered = self.span.enter();
        warn!(
            target: "request",
            cost = self.start.elapsed().as_millis(),
            "request connection dropped before finish",
        );
    }
}
use std::convert::Infallible;
use anyhow::{Context, Result};
use clap::Args;
use mysql_async::{Conn, Opts, OptsBuilder, Pool};
use tracing::info;
// Command-line options (a clap `Args` group) describing the MySQL connection.
// They are converted into `mysql_async::Opts` by the `TryFrom<&DataMysqlOpts> for Opts` impl.
// NOTE: plain `//` comments are used on purpose. `///` doc comments would be read by the
// clap derive and could alter the generated help/about text.
// NOTE(review): `--mysql-pass` puts the password on the command line, where it may be
// visible in process listings; consider reading it from an environment variable.
#[derive(Args, Clone)]
pub struct DataMysqlOpts {
// `--mysql-host`: server IP or host name (default "127.0.0.1").
#[clap(
name = "mysql-host",
long = "mysql-host",
default_value = "127.0.0.1",
help = "mysql ip or host"
)]
pub host: String,
// `--mysql-port`: TCP port (default 3306).
#[clap(
name = "mysql-port",
long = "mysql-port",
default_value = "3306",
help = "mysql port"
)]
pub port: u16,
// `--mysql-user` (arg id "mysql-username"): login user (default "root").
#[clap(
name = "mysql-username",
long = "mysql-user",
default_value = "root",
help = "mysql username"
)]
pub user: String,
// `--mysql-pass` (arg id "mysql-password"): login password (default empty string, which is
// still passed to the driver as `Some("")`).
#[clap(
name = "mysql-password",
long = "mysql-pass",
default_value = "",
help = "mysql password"
)]
pub pass: String,
// `--mysql-db-name`: database selected on connect (default "archive").
#[clap(
name = "mysql-db-name",
long = "mysql-db-name",
default_value = "archive",
help = "mysql database name"
)]
pub name: String,
}
/// MySQL access backed by a `mysql_async` connection pool.
#[derive(Clone)]
pub struct DataMysql {
    pool: Pool,
}

impl DataMysql {
    /// Builds the connection pool from `opts`.
    ///
    /// Always returns `Ok`; nothing is awaited here. The conversion into `Opts` is the
    /// infallible `TryFrom<&DataMysqlOpts>` impl.
    pub async fn create_by_opts(opts: &DataMysqlOpts) -> Result<Self> {
        let pool = Pool::new(opts);
        Ok(Self { pool })
    }

    /// Checks out a connection from the pool.
    ///
    /// # Errors
    /// Fails with context "obtain mysql connection" when the pool cannot provide one.
    pub async fn get_conn(&self) -> Result<Conn> {
        let checkout = self.pool.get_conn().await;
        checkout.context("obtain mysql connection")
    }

    /// Obtains a connection and logs the server version at info level.
    ///
    /// # Errors
    /// Propagates failures from [`DataMysql::get_conn`].
    pub async fn dump_info(&self) -> Result<()> {
        let conn = self.get_conn().await?;
        let version = conn.server_version();
        info!(
            "connected to mysql {}.{}.{}",
            version.0, version.1, version.2
        );
        Ok(())
    }
}
impl TryFrom<&DataMysqlOpts> for Opts {
type Error = Infallible;
fn try_from(value: &DataMysqlOpts) -> Result<Self, Self::Error> {
Ok(Opts::from(
OptsBuilder::default()
.ip_or_hostname(&value.host)
.tcp_port(value.port)
.user(Some(&value.user))
.pass(Some(&value.pass))
.db_name(Some(&value.name)),
))
}
}
use anyhow::{Context, Result};
use clap::Args;
use redis::{
aio::{MultiplexedConnection, PubSub},
cmd, Client, ConnectionAddr, ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo,
RedisResult,
};
use tracing::info;
// Command-line options (a clap `Args` group) describing the Redis connection.
// They are turned into `redis::ConnectionInfo` by the `IntoConnectionInfo for &DataRedisOpts` impl.
// NOTE: plain `//` comments are used on purpose. `///` doc comments would be read by the
// clap derive and could alter the generated help/about text.
#[derive(Args, Clone)]
pub struct DataRedisOpts {
// `--redis-host`: server IP or host name (default "127.0.0.1").
#[clap(
name = "redis-host",
long = "redis-host",
default_value = "127.0.0.1",
help = "redis ip or host"
)]
pub host: String,
// `--redis-port`: TCP port (default 6379).
#[clap(
name = "redis-port",
long = "redis-port",
default_value = "6379",
help = "redis port"
)]
pub port: u16,
// `--redis-user`: optional ACL user name; `None` when the flag is omitted.
#[clap(name = "redis-username", long = "redis-user", help = "redis username")]
pub user: Option<String>,
// `--redis-pass`: optional password; `None` when the flag is omitted.
#[clap(name = "redis-password", long = "redis-pass", help = "redis password")]
pub pass: Option<String>,
// `--redis-index`: logical database number (default 0).
#[clap(
name = "redis-index",
long = "redis-index",
default_value = "0",
help = "redis database index"
)]
pub index: i64,
}
/// Redis access built around a `redis::Client`.
#[derive(Clone)]
pub struct DataRedis {
    client: Client,
}

impl DataRedis {
    /// Builds a client from `opts` via the `IntoConnectionInfo` impl for `&DataRedisOpts`.
    ///
    /// # Errors
    /// Fails with context "open redis client" if the client cannot be created.
    pub async fn create_by_opts(opts: &DataRedisOpts) -> Result<Self> {
        let client = Client::open(opts).context("open redis client")?;
        Ok(Self { client })
    }

    /// Opens a new multiplexed async connection.
    ///
    /// # Errors
    /// Fails with context "obtain redis connection" when connecting fails.
    pub async fn get_conn(&self) -> Result<MultiplexedConnection> {
        self.client
            .get_multiplexed_async_connection()
            .await
            .context("obtain redis connection")
    }

    /// Opens a new async pub/sub connection.
    ///
    /// # Errors
    /// Fails with context "obtain redis pub sub" when connecting fails.
    pub async fn get_pub_sub(&self) -> Result<PubSub> {
        self.client
            .get_async_pubsub()
            .await
            .context("obtain redis pub sub")
    }

    /// Runs `INFO server` and logs the reported `redis_version` at info level.
    ///
    /// # Errors
    /// Fails if connecting or the `INFO` command fails, or if the output contains no
    /// `redis_version:` line.
    pub async fn dump_info(&self) -> Result<()> {
        let mut conn = self.get_conn().await?;
        let info = cmd("info")
            .arg("server")
            .query_async::<_, String>(&mut conn)
            .await
            .context("dump redis server info")?;
        // `find_map` stops at the first matching line (same result as filter_map + next).
        let version = info
            .lines()
            .find_map(|line| line.trim().strip_prefix("redis_version:"))
            .context("can't extract redis_version from info dump")?;
        info!("connected to redis {}", version);
        Ok(())
    }
}
/// Lets `Client::open` accept the CLI options directly.
///
/// The result is a plain TCP connection to `host:port` using database `index` and the
/// optional credentials.
impl IntoConnectionInfo for &DataRedisOpts {
    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
        let addr = ConnectionAddr::Tcp(self.host.clone(), self.port);
        let redis = RedisConnectionInfo {
            db: self.index,
            username: self.user.clone(),
            password: self.pass.clone(),
        };
        Ok(ConnectionInfo { addr, redis })
    }
}
use std::{future::Future, time::Duration};
use anyhow::{Context, Error, Result};
use async_trait::async_trait;
use indoc::indoc;
use lazy_static::lazy_static;
use redis::{aio::ConnectionLike, AsyncCommands, RedisError, Script};
use tokio::{select, time::sleep};
use uuid::Uuid;
lazy_static! {
    /// Creates `KEYS[1] = ARGV[1]` with a 2 second TTL only if the key does not exist yet.
    /// Returns true (1) when the lock was acquired, false (nil) otherwise.
    static ref SCRIPT_INIT_LOCK: Script = Script::new(indoc! {r#"
        if (redis.call('SET', KEYS[1], ARGV[1], 'NX', 'EX', '2'))
        then
            return true;
        else
            return false;
        end
    "#});
    /// Renews the 2 second TTL of `KEYS[1]`, but only while it still holds `ARGV[1]`.
    /// Returns true (1) when renewed, false (nil) when the lock is no longer ours.
    ///
    /// EXPIRE returns an integer; in Lua every number (including 0) is truthy, so the
    /// result must be compared against 1 explicitly.
    static ref SCRIPT_UPDATE: Script = Script::new(indoc! {r#"
        local current = redis.call('GET', KEYS[1]);
        if (current ~= ARGV[1])
        then
            return false;
        end
        if (redis.call('EXPIRE', KEYS[1], '2') == 1)
        then
            return true;
        else
            return false;
        end
    "#});
}
/// Errors returned by [`DistributionLock::run_task_with_lock`].
///
/// NOTE(review): the `{0}` in the display strings repeats the message of the `#[source]`
/// error, so reporters that walk the error chain will print it twice.
#[derive(Debug, thiserror::Error)]
pub enum DistributionLockError {
/// A Redis call made while acquiring the lock failed (carries `anyhow` context).
#[error("redis error {0}")]
RedisError(#[source] Error),
/// The task run under the lock returned an error.
#[error("task error {0}")]
TaskError(#[source] Error),
/// The lock key already exists, i.e. another holder owns the lock.
#[error("lock current already acquired by others")]
LockCurrentAcquired,
/// Renewal found the lock key missing or holding another holder's token.
#[error("lock update during task failed")]
LockUpdateFailed,
/// Renewal failed with a Redis error.
#[error("lock update during task failed by error {0}")]
LockUpdateError(#[source] RedisError),
/// Removing the lock key after the task finished failed.
#[error("lock release failed by error {0}")]
LockReleaseError(#[source] RedisError),
}
/// Runs a future while holding a named, Redis-backed lock.
///
/// Implemented in this file for every async Redis connection (`ConnectionLike + Send`).
#[async_trait]
pub trait DistributionLock {
/// Acquires the lock named `lock`, runs `task` while keeping the lock alive, and returns the
/// task's output.
///
/// # Errors
/// - [`DistributionLockError::LockCurrentAcquired`] when another holder owns the lock;
/// - [`DistributionLockError::TaskError`] when `task` fails;
/// - the remaining variants when acquiring, renewing or releasing the lock fails.
async fn run_task_with_lock<R: Send, F: Future<Output = Result<R>> + Send>(
&mut self,
lock: &str,
task: F,
) -> Result<R, DistributionLockError>;
}
#[async_trait]
impl<C> DistributionLock for C
where
C: ConnectionLike + Send,
{
async fn run_task_with_lock<R: Send, F: Future<Output = Result<R>> + Send>(
&mut self,
lock: &str,
task: F,
) -> Result<R, DistributionLockError> {
let key = lock_cache_key(lock);
let value = Uuid::new_v4();
let init_lock: bool = SCRIPT_INIT_LOCK
.prepare_invoke()
.key(&key)
.arg(value)
.invoke_async(self)
.await
.context("init lock")
.map_err(DistributionLockError::RedisError)?;
if !init_lock {
return Err(DistributionLockError::LockCurrentAcquired);
}
let update_lock = async {
loop {
match SCRIPT_UPDATE
.prepare_invoke()
.key(&key)
.arg(value)
.invoke_async::<_, bool>(self)
.await
{
Ok(true) => {
sleep(Duration::from_secs(1)).await;
}
Ok(false) => {
return DistributionLockError::LockUpdateFailed;
}
Err(err) => {
return DistributionLockError::LockUpdateError(err);
}
}
}
};
// run task & periodically update redis
let ret = select! {
task_ret = task => {
task_ret.map_err(DistributionLockError::TaskError)?
}
update_ret = update_lock => {
return Err(update_ret);
}
};
// remove redis key
self.del::<_, ()>(&key)
.await
.map_err(DistributionLockError::LockReleaseError)?;
Ok(ret)
}
}
/// Returns the Redis key under which the lock called `name` is stored (`LOCK:<name>`).
fn lock_cache_key(name: &str) -> String {
    format!("LOCK:{name}")
}
#[cfg(test)]
mod test {
    use std::time::Duration;
    use redis::{cmd, Client};
    use tokio::time::sleep;
    use uuid::Uuid;
    use super::{lock_cache_key, DistributionLock};

    /// Holds the lock for longer than its 2 second TTL and checks that the stored token
    /// survives, i.e. the renewal loop keeps the key alive.
    ///
    /// Requires a Redis server on 127.0.0.1 and takes about 10 seconds.
    #[tokio::test]
    async fn run_lock() {
        let lock_name = "test_lock";
        let lock_key = lock_cache_key(lock_name);
        let client = Client::open("redis://127.0.0.1").expect("connect redis");
        let mut probe = client
            .get_multiplexed_async_connection()
            .await
            .expect("obtain redis connection");
        // A separate handle owns the lock so `probe` stays free for the task's reads.
        let mut locker = probe.clone();
        locker
            .run_task_with_lock(lock_name, async {
                let before: Uuid = cmd("get").arg(&lock_key).query_async(&mut probe).await?;
                sleep(Duration::from_secs(10)).await;
                let after: Uuid = cmd("get").arg(&lock_key).query_async(&mut probe).await?;
                assert_eq!(before, after);
                Ok(())
            })
            .await
            .expect("run with lock");
    }
}
use async_trait::async_trait;
use axum::{
extract::{FromRequest, Request},
response::{IntoResponse, Response},
Json,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tracing::warn;
/// Declares a `pub` enum named `$enum_name` usable as an axum handler return type.
///
/// Each listed `Name(Type),` entry must end with a comma and becomes a variant. The enum also
/// always gets:
/// - `Status(ApiStatus)`: rendered as an `ApiResponse<()>` carrying only that status;
/// - `Raw(axum::response::Response)`: returned to the client unchanged.
///
/// `From` is implemented for `ApiStatus`, `Response` and every listed type, so values can be
/// converted with `.into()`. Consequently the listed types must be distinct from each other
/// and from those two types (otherwise the `From` impls conflict), and each must implement
/// `IntoResponse`.
///
/// The expansion refers to `$crate::rest::ApiStatus` / `$crate::rest::ApiResponse`, so those
/// types must live in the crate's `rest` module.
#[macro_export]
macro_rules! api_response_combine {
($enum_name:ident { $($name:ident($type:ty),)+ }) => {
pub enum $enum_name {
Status($crate::rest::ApiStatus),
Raw(axum::response::Response),
$($name($type),)+
}
impl From<$crate::rest::ApiStatus> for $enum_name {
fn from(value: $crate::rest::ApiStatus) -> Self {
Self::Status(value)
}
}
impl From<axum::response::Response> for $enum_name {
fn from(value: axum::response::Response) -> Self {
Self::Raw(value)
}
}
$(
impl From<$type> for $enum_name {
fn from(value: $type) -> Self {
Self::$name(value)
}
}
)+
impl axum::response::IntoResponse for $enum_name {
fn into_response(self) -> axum::response::Response {
match self {
Self::Status(status) => $crate::rest::ApiResponse::<()>::from(status).into_response(),
Self::Raw(response) => response,
$(Self::$name(v) => v.into_response(),)+
}
}
}
};
}
/// Evaluates to the `token` of a request value, or returns early when it is missing.
///
/// - `api_require_token!(req)` matches on `&req.token` and yields a reference to the token.
/// - `api_require_token!(take req)` matches on `req.token` by value and moves the token out.
///
/// When the token is `None`, it logs `"token missing"` at warn level and `return`s
/// `ApiResponse::fail(ApiStatus::InvalidRequest, Some("token missing"))` converted with
/// `.into()`. The enclosing function's return type must therefore be obtainable from an
/// `ApiResponse<_>` via `Into`.
#[macro_export]
macro_rules! api_require_token {
($request:ident) => {
match &$request.token {
Some(token) => token,
None => {
tracing::warn!("token missing");
return $crate::rest::ApiResponse::fail(
$crate::rest::ApiStatus::InvalidRequest,
Some("token missing".to_string()),
)
.into();
}
}
};
// `take` form: consumes the field instead of borrowing it.
(take $request:ident) => {
match $request.token {
Some(token) => token,
None => {
tracing::warn!("token missing");
return $crate::rest::ApiResponse::fail(
$crate::rest::ApiStatus::InvalidRequest,
Some("token missing".to_string()),
)
.into();
}
}
};
}
/// Evaluates to the `data` payload of a request value, or returns early when it is missing.
///
/// - `api_require_data!(req)` matches on `&req.data` and yields a reference to the payload.
/// - `api_require_data!(take req)` matches on `req.data` by value and moves the payload out.
///
/// When the payload is `None`, it logs `"data missing"` at warn level and `return`s
/// `ApiResponse::fail(ApiStatus::InvalidRequest, Some("data missing"))` converted with
/// `.into()`. The enclosing function's return type must therefore be obtainable from an
/// `ApiResponse<_>` via `Into`.
#[macro_export]
macro_rules! api_require_data {
($request:ident) => {
match &$request.data {
Some(data) => data,
None => {
tracing::warn!("data missing");
return $crate::rest::ApiResponse::fail(
$crate::rest::ApiStatus::InvalidRequest,
Some("data missing".to_string()),
)
.into();
}
}
};
// `take` form: consumes the field instead of borrowing it.
(take $request:ident) => {
match $request.data {
Some(data) => data,
None => {
tracing::warn!("data missing");
return $crate::rest::ApiResponse::fail(
$crate::rest::ApiStatus::InvalidRequest,
Some("data missing".to_string()),
)
.into();
}
}
};
}
/// JSON request envelope: an optional `token` plus an optional `data` payload of type `T`.
///
/// Both fields deserialize to `None` when absent and are skipped when serializing a `None`.
#[derive(Serialize, Deserialize)]
pub struct ApiRequest<T = ()> {
/// Optional token; `api_require_token!` extracts it or rejects the request.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub token: Option<String>,
/// Optional payload; `api_require_data!` extracts it or rejects the request.
// The `bound` attribute replaces the trait bounds serde would otherwise infer for `T`.
#[serde(
default = "Option::default",
skip_serializing_if = "Option::is_none",
bound(
serialize = "Option<T>: Serialize",
deserialize = "Option<T>: Deserialize<'de>",
)
)]
pub data: Option<T>,
}
/// Extracts an [`ApiRequest`] from a JSON body.
///
/// If axum's `Json` extractor rejects the request, the rejection is logged at warn level and
/// the client receives an `InvalidRequest` [`ApiResponse`] without a message.
#[async_trait]
impl<T, S> FromRequest<S> for ApiRequest<T>
where
    ApiRequest<T>: DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = ApiResponse;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        Json::<ApiRequest<T>>::from_request(req, state)
            .await
            .map(|Json(payload)| payload)
            .map_err(|rejection| {
                warn!("decode request failed: {:?}", rejection);
                ApiResponse::fail(ApiStatus::InvalidRequest, None)
            })
    }
}
/// JSON response envelope: a mandatory `status`, an optional human-readable `message` and an
/// optional `data` payload of type `T`. `None` fields are skipped when serializing.
#[derive(Serialize, Deserialize)]
pub struct ApiResponse<T = ()> {
/// Outcome of the request; always serialized.
pub status: ApiStatus,
/// Optional explanation, e.g. "token missing" / "data missing" from the request macros.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
/// Optional payload.
// The `bound` attribute replaces the trait bounds serde would otherwise infer for `T`.
#[serde(
default = "Option::default",
skip_serializing_if = "Option::is_none",
bound(
serialize = "Option<T>: Serialize",
deserialize = "Option<T>: Deserialize<'de>",
)
)]
pub data: Option<T>,
}
/// Builds a response carrying only `status`, with no message and no data.
///
/// Unlike [`ApiResponse::fail`], this accepts any status, including `Success`.
impl<T> From<ApiStatus> for ApiResponse<T> {
    fn from(status: ApiStatus) -> Self {
        Self {
            data: None,
            message: None,
            status,
        }
    }
}
impl<T> ApiResponse<T> {
    /// Successful response with an optional payload and no message.
    pub fn success(data: Option<T>) -> Self {
        Self {
            data,
            message: None,
            status: ApiStatus::Success,
        }
    }

    /// Failed response with `status` and an optional message; never carries data.
    ///
    /// # Panics
    /// Panics if `status` is [`ApiStatus::Success`].
    pub fn fail(status: ApiStatus, message: Option<String>) -> Self {
        assert!(
            !matches!(status, ApiStatus::Success),
            "fail with success status"
        );
        Self {
            data: None,
            message,
            status,
        }
    }
}
/// Renders the envelope as a JSON body via axum's `Json`.
///
/// The HTTP status is whatever `Json` produces; the API outcome lives in the `status` field.
impl<T> IntoResponse for ApiResponse<T>
where
    ApiResponse<T>: Serialize,
{
    fn into_response(self) -> Response {
        let body = Json(self);
        body.into_response()
    }
}
/// Outcome code carried by every [`ApiResponse`].
///
/// Serialized in snake_case: `"success"`, `"invalid_request"`, `"server_error"`.
#[derive(Serialize, Deserialize, Copy, Clone, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ApiStatus {
/// The request was handled successfully.
Success,
/// The request body could not be decoded, or a required token/data field was missing.
InvalidRequest,
/// Presumably a server-side failure; not produced anywhere in this file.
ServerError,
}
use std::net::SocketAddr;
use anyhow::{Context, Result};
use axum::Router;
use tokio::{net::TcpListener, signal::ctrl_c};
use tracing::{error, info, instrument};
/// Binds a TCP listener on `bind` and serves `router` until `graceful_shutdown` resolves.
///
/// The router is served with `ConnectInfo<SocketAddr>`, so handlers and middleware can read
/// the peer address.
///
/// # Errors
/// Fails with "can't bind tcp socket <addr>" if binding fails, or with "run http" if the
/// server returns an I/O error.
pub async fn run_http_server(bind: SocketAddr, router: Router<()>) -> Result<()> {
    let listener = TcpListener::bind(bind)
        .await
        .with_context(|| format!("can't bind tcp socket {}", bind))?;
    info!("bound at {}", bind);
    let service = router.into_make_service_with_connect_info::<SocketAddr>();
    axum::serve(listener, service)
        .with_graceful_shutdown(graceful_shutdown())
        .await
        .context("run http")
}
/// Resolves when Ctrl+C is received, which triggers the server's graceful shutdown.
///
/// If the signal handler cannot be installed, the error is logged and this future never
/// resolves. Resolving there would make `axum::serve` stop immediately after startup.
#[instrument("graceful-shutdown")]
async fn graceful_shutdown() {
    match ctrl_c().await {
        Ok(()) => {
            info!("CTRL+C pressed, quitting");
        }
        Err(err) => {
            error!("tokio CTRL+C signal handler error {}", err);
            // No shutdown trigger is available, so keep serving instead of exiting at once.
            std::future::pending::<()>().await;
        }
    }
}
use std::borrow::Cow;
use clap::{Args, ValueEnum};
use tracing::level_filters::LevelFilter;
use tracing_appender::{non_blocking::WorkerGuard, rolling::Rotation};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
// Command-line options controlling tracing output: an env-filter directive string shared by
// console and file output, plus the rolling JSON log file settings.
// Plain `//` comments are used because clap's derive turns `///` doc comments into help text.
#[derive(Args)]
pub struct TracingSetupOpts {
    // `--tracing`: env-filter directives (default "debug").
    #[clap(long = "tracing", default_value = "debug", help = "tracing env filter")]
    tracing: Cow<'static, str>,
    // `--tracing-output-dir`: directory receiving the rolling log files.
    #[clap(
        long = "tracing-output-dir",
        default_value = "./log",
        help = "tracing output dir"
    )]
    tracing_output_dir: Cow<'static, str>,
    // `--tracing-output-prefix`: log file name prefix.
    #[clap(
        long = "tracing-output-prefix",
        default_value = "host",
        help = "tracing output filename prefix"
    )]
    tracing_output_prefix: Cow<'static, str>,
    // `--tracing-output-suffix`: log file name suffix.
    #[clap(
        long = "tracing-output-suffix",
        default_value = "log",
        help = "tracing output filename suffix"
    )]
    tracing_output_suffix: Cow<'static, str>,
    // `--tracing-output-rotate`: how often a new log file is started.
    #[clap(
        long = "tracing-output-rotate",
        default_value = "day",
        help = "tracing output file rotation period"
    )]
    tracing_output_rotate: TracingRotateType,
    // `--tracing-output-files`: maximum number of log files kept by the appender.
    #[clap(
        long = "tracing-output-files",
        default_value = "7",
        help = "tracing output file count"
    )]
    tracing_output_files: usize,
}
// Log file rotation period selectable via `--tracing-output-rotate`.
// Clap's `ValueEnum` derive exposes the variants as "minute", "hour", "day" and "never".
// The enum is mapped onto `tracing_appender::rolling::Rotation` by its `From` impl.
// Plain `//` comments are used because `///` on variants would become clap value help text.
#[derive(ValueEnum, Clone, Copy)]
pub enum TracingRotateType {
Minute,
Hour,
Day,
Never,
}
/// Maps the CLI rotation choice onto the corresponding `tracing_appender` rotation.
impl From<TracingRotateType> for Rotation {
    fn from(kind: TracingRotateType) -> Self {
        match kind {
            TracingRotateType::Never => Self::NEVER,
            TracingRotateType::Day => Self::DAILY,
            TracingRotateType::Hour => Self::HOURLY,
            TracingRotateType::Minute => Self::MINUTELY,
        }
    }
}
impl TracingSetupOpts {
    /// Installs the global tracing subscriber. It has two layers:
    /// - a human-readable console layer;
    /// - a JSON layer writing, through a non-blocking writer, to a rolling file.
    ///
    /// Both layers use an `EnvFilter` parsed from `--tracing` (invalid directives are ignored),
    /// with DEBUG as the default directive.
    ///
    /// The returned guard must be kept alive; it flushes the non-blocking file writer when
    /// dropped.
    ///
    /// # Panics
    /// Panics if the rolling file appender cannot be created or a global subscriber is
    /// already installed.
    pub fn setup(&self) -> WorkerGuard {
        // Keeps at most `tracing_output_files` files; how much time that covers depends on the
        // chosen rotation period.
        let appender = tracing_appender::rolling::Builder::new()
            .filename_prefix(self.tracing_output_prefix.as_ref())
            .filename_suffix(self.tracing_output_suffix.as_ref())
            .rotation(self.tracing_output_rotate.into())
            .max_log_files(self.tracing_output_files)
            .build(self.tracing_output_dir.as_ref())
            .expect("create tracing file rolling output");
        let (writer, guard) = tracing_appender::non_blocking(appender);

        // Each layer needs its own filter instance, built from the same directives.
        let env_filter = || {
            EnvFilter::builder()
                .with_default_directive(LevelFilter::DEBUG.into())
                .parse_lossy(self.tracing.as_ref())
        };
        let console_layer = tracing_subscriber::fmt::layer().with_filter(env_filter());
        let file_layer = tracing_subscriber::fmt::layer()
            .json()
            .with_ansi(false)
            .with_writer(writer)
            .with_filter(env_filter());

        tracing_subscriber::registry()
            .with(console_layer)
            .with(file_layer)
            .try_init()
            .expect("setup tracing output");
        guard
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment