Skip to content

Instantly share code, notes, and snippets.

@DefectingCat
Last active January 26, 2024 08:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DefectingCat/bafb06920fcd6184599faf93072ba989 to your computer and use it in GitHub Desktop.
Save DefectingCat/bafb06920fcd6184599faf93072ba989 to your computer and use it in GitHub Desktop.
Axum reverse proxy that caches upstream responses in Redis
use super::error::{RouteError, RouteResult};
use crate::AppState;
use anyhow::{anyhow, Result};
use axum::{
body::Body,
extract::{Request, State},
http::{response::Parts, HeaderName, HeaderValue, Uri},
response::{IntoResponse, Response},
};
use http_body_util::BodyExt;
use hyper::{body::Bytes, HeaderMap};
use redis::{Client, Commands};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
use tracing::error;
/// Base URL of the upstream server; the incoming request's path and query
/// string are appended to it when the request is forwarded.
const BACKEND_URI: &str = "http://192.168.1.13:8086";
/// An upstream response as stored (JSON-encoded via `serde_json`) in Redis.
///
/// Only the headers and the body are kept; the HTTP status code is not part
/// of the entry.
#[derive(Serialize, Deserialize, Debug)]
struct Cache {
    /// Response headers, flattened to `name -> value` strings.
    headers: HashMap<String, String>,
    /// Raw response body bytes.
    body: Vec<u8>,
}
/// Rebuilds a [`HeaderMap`] from `(name, value)` string pairs, such as the
/// headers stored in a [`Cache`] entry.
///
/// Pairs are skipped (not reported) when the name is not a valid header name
/// or the value contains bytes that are illegal in a header value.
///
/// Values are converted from their raw bytes with `HeaderValue::from_bytes`
/// instead of `HeaderValue::from_str`. `from_str` only accepts visible ASCII,
/// so any cached header with a non-ASCII UTF-8 value would otherwise be
/// dropped when the response is replayed.
fn headermap_from_hashmap<I, K, V>(headers: I) -> HeaderMap
where
    I: IntoIterator<Item = (K, V)>,
    K: AsRef<str>,
    V: AsRef<str>,
{
    headers
        .into_iter()
        .filter_map(|(name, val)| {
            // Invalid pairs are ignored deliberately; use `Iterator::partition`
            // instead if the failed conversions are ever needed.
            let name = HeaderName::from_str(name.as_ref()).ok()?;
            let val = HeaderValue::from_bytes(val.as_ref().as_bytes()).ok()?;
            Some((name, val))
        })
        .collect()
}
/// Flattens a [`HeaderMap`] into owned `name -> value` strings so it can be
/// stored in a [`Cache`] entry.
///
/// Values are decoded lossily: bytes that are not valid UTF-8 become
/// `U+FFFD`. When a header name appears more than once, each insert
/// overwrites the previous value, so only the last value survives.
fn hashmap_from_headermap(headers: HeaderMap<HeaderValue>) -> HashMap<String, String> {
    let mut flat = HashMap::new();
    for (name, value) in &headers {
        let text = String::from_utf8_lossy(value.as_bytes());
        flat.insert(name.as_str().to_owned(), text.into_owned());
    }
    flat
}
pub async fn proxy(State(state): State<AppState>, req: Request) -> RouteResult<impl IntoResponse> {
let (parts, body) = req.into_parts();
let body = body.collect().await?.to_bytes();
let req_body = std::str::from_utf8(&body).map_err(|err| anyhow!(err))?;
let req_parts = format!("{:?}", parts.headers);
let cache_key = format!("{}{}{}", parts.uri, req_parts, req_body);
let cache = get_cache(&cache_key, &state.rdb).await;
if let Ok(cache) = cache {
return Ok(cache);
};
// build the request and body for the real server
let mut proxy_req = Request::from_parts(parts, body.into());
let path = proxy_req.uri().path();
let path_query = proxy_req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);
let uri = format!("{}{}", BACKEND_URI, path_query);
*proxy_req.uri_mut() = Uri::try_from(uri).map_err(|err| anyhow!(err))?;
let name_version = format!("{}-{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
let mut resp = state
.client
.request(proxy_req)
.await
.map_err(|_| RouteError::BadRequest())?;
let resp_headers = resp.headers_mut();
resp_headers.insert(
"X-Proxy",
HeaderValue::from_str(&name_version).map_err(|err| anyhow!("{err}"))?,
);
let (parts, body) = resp.into_parts();
let body = body
.collect()
.await
.map_err(|err| anyhow!("{err}"))?
.to_bytes();
let set_parts = parts.clone();
let set_body = body.clone();
let rdb = state.rdb;
tokio::spawn(async move {
let _ = set_cache(cache_key, rdb, set_parts, set_body)
.await
.map_err(|err| error!("set cache to redis failed {}", err));
});
// response for client
let mut response: Response<Body> = Response::from_parts(parts, body.into());
response.headers_mut().remove("transfer-encoding");
Ok(response)
}
async fn get_cache(key: &str, rdb: &Arc<RwLock<Client>>) -> Result<Response> {
let cache: String = rdb.write().await.get(key)?;
let cache: Cache = serde_json::from_str(&cache).map_err(|err| anyhow!(err))?;
let mut response: Response<Body> = Response::new(cache.body.into());
let headers = headermap_from_hashmap(cache.headers.iter());
*response.headers_mut() = headers;
response.headers_mut().remove("transfer-encoding");
Ok(response)
}
async fn set_cache(key: String, rdb: Arc<RwLock<Client>>, parts: Parts, body: Bytes) -> Result<()> {
let headers = hashmap_from_headermap(parts.headers);
let cache = Cache {
headers,
body: body.to_vec(),
};
let cache = serde_json::to_string(&cache).map_err(|err| anyhow!(err))?;
rdb.write()
.await
.set_ex::<String, String, String>(key, cache, 60 * 60)
.map_err(|err| anyhow!(err))?;
Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment