-
-
Save konstin/54b983e7f0f4f77d38b4151e6a9f295c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::future::Future; | |
use std::path::Path; | |
use std::time::SystemTime; | |
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy}; | |
use reqwest::{Client, Request, Response}; | |
use serde::de::DeserializeOwned; | |
use serde::{Deserialize, Serialize}; | |
use tracing::{trace, warn}; | |
use url::Url; | |
use crate::error::Error; | |
#[derive(Debug)] | |
enum CachedResponse<Payload: Serialize> { | |
/// The cached response is fresh without an HTTP request (e.g. immutable) | |
FreshCache(Payload), | |
/// The cached response is fresh after an HTTP request (e.g. 304 not modified) | |
NotModified(DataWithCachePolicy<Payload>), | |
/// There was no prior cached response or the cache was outdated | |
/// | |
/// The cache policy is `None` if it isn't storable | |
ModifiedOrNew(Response, Option<CachePolicy>), | |
} | |
/// Serialize the actual payload together with its caching information | |
#[derive(Debug, Deserialize, Serialize)] | |
struct DataWithCachePolicy<Payload: Serialize> { | |
data: Payload, | |
cache_policy: CachePolicy, | |
} | |
#[derive(Debug, Clone)] | |
pub(crate) struct CachedClient(Client); | |
impl CachedClient { | |
pub(crate) fn new(client: Client) -> Self { | |
Self(client) | |
} | |
/// Makes a cached request with a custom data transformation | |
/// | |
/// If a new response was received (no prior cached response or modified on the serde), the | |
/// response through `transform_response` and only the result is cached and returned | |
pub(crate) async fn get_transformed_cached< | |
Payload: Serialize + DeserializeOwned, | |
Callback, | |
CallbackReturn, | |
>( | |
&self, | |
url: Url, | |
cache_file: &Path, | |
transform_response: Callback, | |
) -> Result<Payload, Error> | |
where | |
Callback: FnOnce(Response) -> CallbackReturn, | |
CallbackReturn: Future<Output = Result<Payload, Error>>, | |
{ | |
let cached = if let Ok(cached) = fs_err::tokio::read(&cache_file).await { | |
match serde_json::from_slice::<DataWithCachePolicy<Payload>>(&cached) { | |
Ok(data) => Some(data), | |
Err(err) => { | |
warn!( | |
"Broken cache entry at {}, removing: {err}", | |
cache_file.display() | |
); | |
let _ = fs_err::tokio::remove_file(&cache_file).await; | |
None | |
} | |
} | |
} else { | |
None | |
}; | |
let req = self.0.get(url.clone()).build()?; | |
let cached_response = self.send_cached(req, cached).await?; | |
match cached_response { | |
CachedResponse::FreshCache(data) => Ok(data), | |
CachedResponse::NotModified(data_with_cache_policy) => { | |
fs_err::tokio::write(cache_file, &serde_json::to_vec(&data_with_cache_policy)?) | |
.await?; | |
Ok(data_with_cache_policy.data) | |
} | |
CachedResponse::ModifiedOrNew(res, cache_policy) => { | |
let data = transform_response(res).await?; | |
if let Some(cache_policy) = cache_policy { | |
let data_with_cache_policy = DataWithCachePolicy { data, cache_policy }; | |
if let Some(parent) = cache_file.parent() { | |
fs_err::tokio::create_dir_all(parent).await?; | |
} | |
fs_err::tokio::write(cache_file, &serde_json::to_vec(&data_with_cache_policy)?) | |
.await?; | |
Ok(data_with_cache_policy.data) | |
} else { | |
Ok(data) | |
} | |
} | |
} | |
} | |
async fn send_cached<T: Serialize + DeserializeOwned>( | |
&self, | |
mut req: Request, | |
cached: Option<DataWithCachePolicy<T>>, | |
) -> Result<CachedResponse<T>, Error> { | |
// The converted types are from the specific `reqwest` types to the more generic `http` | |
// types | |
let mut converted_req = http::Request::try_from( | |
req.try_clone() | |
.expect("You can't use streaming request bodies with this function"), | |
)?; | |
let cached_response = if let Some(cached) = cached { | |
match cached | |
.cache_policy | |
.before_request(&converted_req, SystemTime::now()) | |
{ | |
BeforeRequest::Fresh(_) => { | |
trace!("Fresh cache for {}", req.url()); | |
CachedResponse::FreshCache(cached.data) | |
} | |
BeforeRequest::Stale { request, matches } => { | |
debug_assert!( | |
matches, | |
"cache doesn't match previous request for {}", | |
req.url() | |
); | |
trace!("Revalidation request for {}", req.url()); | |
for header in &request.headers { | |
req.headers_mut().insert(header.0.clone(), header.1.clone()); | |
converted_req | |
.headers_mut() | |
.insert(header.0.clone(), header.1.clone()); | |
} | |
let res = self.0.execute(req).await?.error_for_status()?; | |
let mut converted_res = http::Response::new(()); | |
*converted_res.status_mut() = res.status(); | |
for header in res.headers() { | |
converted_res.headers_mut().insert( | |
http::HeaderName::from(header.0), | |
http::HeaderValue::from(header.1), | |
); | |
} | |
let after_response = cached.cache_policy.after_response( | |
&converted_req, | |
&converted_res, | |
SystemTime::now(), | |
); | |
match after_response { | |
AfterResponse::NotModified(new_policy, _parts) => { | |
CachedResponse::NotModified(DataWithCachePolicy { | |
data: cached.data, | |
cache_policy: new_policy, | |
}) | |
} | |
AfterResponse::Modified(new_policy, _parts) => { | |
CachedResponse::ModifiedOrNew( | |
res, | |
new_policy.is_storable().then_some(new_policy), | |
) | |
} | |
} | |
} | |
} | |
} else { | |
// No reusable cache | |
trace!("{} {}", req.method(), req.url()); | |
let res = self.0.execute(req).await?.error_for_status()?; | |
let mut converted_res = http::Response::new(()); | |
*converted_res.status_mut() = res.status(); | |
for header in res.headers() { | |
converted_res.headers_mut().insert( | |
http::HeaderName::from(header.0), | |
http::HeaderValue::from(header.1), | |
); | |
} | |
let cache_policy = | |
CachePolicy::new(&converted_req.into_parts().0, &converted_res.into_parts().0); | |
CachedResponse::ModifiedOrNew(res, cache_policy.is_storable().then_some(cache_policy)) | |
}; | |
Ok(cached_response) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment