Skip to content

Instantly share code, notes, and snippets.

@konstin
Created November 8, 2023 20:31
Show Gist options
  • Save konstin/54b983e7f0f4f77d38b4151e6a9f295c to your computer and use it in GitHub Desktop.
Save konstin/54b983e7f0f4f77d38b4151e6a9f295c to your computer and use it in GitHub Desktop.
use std::future::Future;
use std::path::Path;
use std::time::SystemTime;
use http_cache_semantics::{AfterResponse, BeforeRequest, CachePolicy};
use reqwest::{Client, Request, Response};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tracing::{trace, warn};
use url::Url;
use crate::error::Error;
#[derive(Debug)]
enum CachedResponse<Payload: Serialize> {
/// The cached response is fresh without an HTTP request (e.g. immutable)
FreshCache(Payload),
/// The cached response is fresh after an HTTP request (e.g. 304 not modified)
NotModified(DataWithCachePolicy<Payload>),
/// There was no prior cached response or the cache was outdated
///
/// The cache policy is `None` if it isn't storable
ModifiedOrNew(Response, Option<CachePolicy>),
}
/// Serialize the actual payload together with its caching information
#[derive(Debug, Deserialize, Serialize)]
struct DataWithCachePolicy<Payload: Serialize> {
data: Payload,
cache_policy: CachePolicy,
}
/// A wrapper around a [`Client`] whose GET requests can be answered from, revalidated
/// against, and stored into a JSON cache file on disk, using `http_cache_semantics` cache
/// policies (see `get_transformed_cached`).
#[derive(Debug, Clone)]
pub(crate) struct CachedClient(Client);
impl CachedClient {
pub(crate) fn new(client: Client) -> Self {
Self(client)
}
/// Makes a cached request with a custom data transformation
///
/// If a new response was received (no prior cached response or modified on the serde), the
/// response through `transform_response` and only the result is cached and returned
pub(crate) async fn get_transformed_cached<
Payload: Serialize + DeserializeOwned,
Callback,
CallbackReturn,
>(
&self,
url: Url,
cache_file: &Path,
transform_response: Callback,
) -> Result<Payload, Error>
where
Callback: FnOnce(Response) -> CallbackReturn,
CallbackReturn: Future<Output = Result<Payload, Error>>,
{
let cached = if let Ok(cached) = fs_err::tokio::read(&cache_file).await {
match serde_json::from_slice::<DataWithCachePolicy<Payload>>(&cached) {
Ok(data) => Some(data),
Err(err) => {
warn!(
"Broken cache entry at {}, removing: {err}",
cache_file.display()
);
let _ = fs_err::tokio::remove_file(&cache_file).await;
None
}
}
} else {
None
};
let req = self.0.get(url.clone()).build()?;
let cached_response = self.send_cached(req, cached).await?;
match cached_response {
CachedResponse::FreshCache(data) => Ok(data),
CachedResponse::NotModified(data_with_cache_policy) => {
fs_err::tokio::write(cache_file, &serde_json::to_vec(&data_with_cache_policy)?)
.await?;
Ok(data_with_cache_policy.data)
}
CachedResponse::ModifiedOrNew(res, cache_policy) => {
let data = transform_response(res).await?;
if let Some(cache_policy) = cache_policy {
let data_with_cache_policy = DataWithCachePolicy { data, cache_policy };
if let Some(parent) = cache_file.parent() {
fs_err::tokio::create_dir_all(parent).await?;
}
fs_err::tokio::write(cache_file, &serde_json::to_vec(&data_with_cache_policy)?)
.await?;
Ok(data_with_cache_policy.data)
} else {
Ok(data)
}
}
}
}
async fn send_cached<T: Serialize + DeserializeOwned>(
&self,
mut req: Request,
cached: Option<DataWithCachePolicy<T>>,
) -> Result<CachedResponse<T>, Error> {
// The converted types are from the specific `reqwest` types to the more generic `http`
// types
let mut converted_req = http::Request::try_from(
req.try_clone()
.expect("You can't use streaming request bodies with this function"),
)?;
let cached_response = if let Some(cached) = cached {
match cached
.cache_policy
.before_request(&converted_req, SystemTime::now())
{
BeforeRequest::Fresh(_) => {
trace!("Fresh cache for {}", req.url());
CachedResponse::FreshCache(cached.data)
}
BeforeRequest::Stale { request, matches } => {
debug_assert!(
matches,
"cache doesn't match previous request for {}",
req.url()
);
trace!("Revalidation request for {}", req.url());
for header in &request.headers {
req.headers_mut().insert(header.0.clone(), header.1.clone());
converted_req
.headers_mut()
.insert(header.0.clone(), header.1.clone());
}
let res = self.0.execute(req).await?.error_for_status()?;
let mut converted_res = http::Response::new(());
*converted_res.status_mut() = res.status();
for header in res.headers() {
converted_res.headers_mut().insert(
http::HeaderName::from(header.0),
http::HeaderValue::from(header.1),
);
}
let after_response = cached.cache_policy.after_response(
&converted_req,
&converted_res,
SystemTime::now(),
);
match after_response {
AfterResponse::NotModified(new_policy, _parts) => {
CachedResponse::NotModified(DataWithCachePolicy {
data: cached.data,
cache_policy: new_policy,
})
}
AfterResponse::Modified(new_policy, _parts) => {
CachedResponse::ModifiedOrNew(
res,
new_policy.is_storable().then_some(new_policy),
)
}
}
}
}
} else {
// No reusable cache
trace!("{} {}", req.method(), req.url());
let res = self.0.execute(req).await?.error_for_status()?;
let mut converted_res = http::Response::new(());
*converted_res.status_mut() = res.status();
for header in res.headers() {
converted_res.headers_mut().insert(
http::HeaderName::from(header.0),
http::HeaderValue::from(header.1),
);
}
let cache_policy =
CachePolicy::new(&converted_req.into_parts().0, &converted_res.into_parts().0);
CachedResponse::ModifiedOrNew(res, cache_policy.is_storable().then_some(cache_policy))
};
Ok(cached_response)
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment