Last active
January 26, 2024 08:30
-
-
Save DefectingCat/bafb06920fcd6184599faf93072ba989 to your computer and use it in GitHub Desktop.
axum reverse proxy with cache to redis
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use super::error::{RouteError, RouteResult}; | |
use crate::AppState; | |
use anyhow::{anyhow, Result}; | |
use axum::{ | |
body::Body, | |
extract::{Request, State}, | |
http::{response::Parts, HeaderName, HeaderValue, Uri}, | |
response::{IntoResponse, Response}, | |
}; | |
use http_body_util::BodyExt; | |
use hyper::{body::Bytes, HeaderMap}; | |
use redis::{Client, Commands}; | |
use serde::{Deserialize, Serialize}; | |
use std::str::FromStr; | |
use std::{collections::HashMap, sync::Arc}; | |
use tokio::sync::RwLock; | |
use tracing::error; | |
/// Scheme, host and port of the upstream server that every proxied request is sent to.
/// The incoming request's path and query string are appended to this value.
const BACKEND_URI: &str = "http://192.168.1.13:8086";
#[derive(Debug, Serialize, Deserialize)] | |
struct Cache { | |
headers: HashMap<String, String>, | |
body: Vec<u8>, | |
} | |
fn headermap_from_hashmap<'a, I, S>(headers: I) -> HeaderMap | |
where | |
I: Iterator<Item = (S, S)> + 'a, | |
S: AsRef<str> + 'a, | |
{ | |
headers | |
.map(|(name, val)| { | |
( | |
HeaderName::from_str(name.as_ref()), | |
HeaderValue::from_str(val.as_ref()), | |
) | |
}) | |
// We ignore the errors here. If you want to get a list of failed conversions, you can use Iterator::partition | |
// to help you out here | |
.filter(|(k, v)| k.is_ok() && v.is_ok()) | |
.map(|(k, v)| (k.unwrap(), v.unwrap())) | |
.collect() | |
} | |
fn hashmap_from_headermap(headers: HeaderMap<HeaderValue>) -> HashMap<String, String> { | |
headers | |
.iter() | |
.map(|(k, v)| { | |
( | |
k.as_str().to_owned(), | |
String::from_utf8_lossy(v.as_bytes()).into_owned(), | |
) | |
}) | |
.collect() | |
} | |
pub async fn proxy(State(state): State<AppState>, req: Request) -> RouteResult<impl IntoResponse> { | |
let (parts, body) = req.into_parts(); | |
let body = body.collect().await?.to_bytes(); | |
let req_body = std::str::from_utf8(&body).map_err(|err| anyhow!(err))?; | |
let req_parts = format!("{:?}", parts.headers); | |
let cache_key = format!("{}{}{}", parts.uri, req_parts, req_body); | |
let cache = get_cache(&cache_key, &state.rdb).await; | |
if let Ok(cache) = cache { | |
return Ok(cache); | |
}; | |
// build the request and body for the real server | |
let mut proxy_req = Request::from_parts(parts, body.into()); | |
let path = proxy_req.uri().path(); | |
let path_query = proxy_req | |
.uri() | |
.path_and_query() | |
.map(|v| v.as_str()) | |
.unwrap_or(path); | |
let uri = format!("{}{}", BACKEND_URI, path_query); | |
*proxy_req.uri_mut() = Uri::try_from(uri).map_err(|err| anyhow!(err))?; | |
let name_version = format!("{}-{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")); | |
let mut resp = state | |
.client | |
.request(proxy_req) | |
.await | |
.map_err(|_| RouteError::BadRequest())?; | |
let resp_headers = resp.headers_mut(); | |
resp_headers.insert( | |
"X-Proxy", | |
HeaderValue::from_str(&name_version).map_err(|err| anyhow!("{err}"))?, | |
); | |
let (parts, body) = resp.into_parts(); | |
let body = body | |
.collect() | |
.await | |
.map_err(|err| anyhow!("{err}"))? | |
.to_bytes(); | |
let set_parts = parts.clone(); | |
let set_body = body.clone(); | |
let rdb = state.rdb; | |
tokio::spawn(async move { | |
let _ = set_cache(cache_key, rdb, set_parts, set_body) | |
.await | |
.map_err(|err| error!("set cache to redis failed {}", err)); | |
}); | |
// response for client | |
let mut response: Response<Body> = Response::from_parts(parts, body.into()); | |
response.headers_mut().remove("transfer-encoding"); | |
Ok(response) | |
} | |
async fn get_cache(key: &str, rdb: &Arc<RwLock<Client>>) -> Result<Response> { | |
let cache: String = rdb.write().await.get(key)?; | |
let cache: Cache = serde_json::from_str(&cache).map_err(|err| anyhow!(err))?; | |
let mut response: Response<Body> = Response::new(cache.body.into()); | |
let headers = headermap_from_hashmap(cache.headers.iter()); | |
*response.headers_mut() = headers; | |
response.headers_mut().remove("transfer-encoding"); | |
Ok(response) | |
} | |
async fn set_cache(key: String, rdb: Arc<RwLock<Client>>, parts: Parts, body: Bytes) -> Result<()> { | |
let headers = hashmap_from_headermap(parts.headers); | |
let cache = Cache { | |
headers, | |
body: body.to_vec(), | |
}; | |
let cache = serde_json::to_string(&cache).map_err(|err| anyhow!(err))?; | |
rdb.write() | |
.await | |
.set_ex::<String, String, String>(key, cache, 60 * 60) | |
.map_err(|err| anyhow!(err))?; | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment