Skip to content

Instantly share code, notes, and snippets.

@gardnervickers
Last active December 15, 2019 05:10
Show Gist options
  • Save gardnervickers/db90ed7cfb6e3653bab9154f5a32c73f to your computer and use it in GitHub Desktop.
Save gardnervickers/db90ed7cfb6e3653bab9154f5a32c73f to your computer and use it in GitHub Desktop.
//! Shim that lets Rusoto, whose request-dispatch and credential traits still use futures 0.1, run on top of the new async/await-based Hyper client.
use bytes::Bytes;
use futures::future::TryFutureExt;
use futures::TryStreamExt;
use futures::{compat::*, lock::Mutex};
use futures01;
use http::header::{HeaderName, HeaderValue};
use http::{HeaderMap, Method};
use hyper;
use rusoto_core::{
self, request::HttpResponse, signature::SignedRequestPayload, HttpDispatchError,
};
use std::pin::Pin;
use std::time::Duration;
use tokio::prelude::*;
/// Rusoto request dispatcher that sends signed requests through a Hyper
/// client using the `hyper_rustls` HTTPS connector.
pub struct ShimmedRequestDispatcher {
    client: hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,
}

impl ShimmedRequestDispatcher {
    /// Creates a dispatcher whose Hyper client has keep-alive enabled and
    /// automatic retrying of canceled requests disabled.
    pub fn new() -> Self {
        // Builder setters take `&mut self`, so they can be chained on the
        // temporary builder; `build` borrows it and returns an owned client.
        let client = hyper::Client::builder()
            .keep_alive(true)
            .retry_canceled_requests(false)
            .build(hyper_rustls::HttpsConnector::new());
        Self { client }
    }
}

impl Default for ShimmedRequestDispatcher {
    /// Equivalent to [`ShimmedRequestDispatcher::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Converts a Hyper response into Rusoto's `HttpResponse`.
///
/// The status code is copied as-is, and each header value is converted to a
/// `String`. The body is adapted into a futures-0.1 `ByteStream`, with Hyper
/// errors surfaced as `std::io::Error` of kind `Other`.
///
/// Header values are decoded lossily: bytes that are not valid UTF-8 are
/// replaced with U+FFFD rather than causing a panic.
fn from_hyper(hyper_response: hyper::Response<hyper::Body>) -> HttpResponse {
    let status = hyper_response.status();
    let headers = hyper_response
        .headers()
        .iter()
        .map(|(name, value)| {
            // `HeaderValue::to_str` fails on any byte outside visible ASCII,
            // and such bytes come from the remote server. Unwrapping it would
            // let a response crash the client. For plain ASCII values the
            // lossy conversion yields the same string `to_str` would.
            let value = String::from_utf8_lossy(value.as_bytes()).into_owned();
            (name.clone(), value)
        })
        .collect();
    let body = hyper_response
        .into_body()
        .map_ok(hyper::Chunk::into_bytes)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err));
    HttpResponse {
        status,
        headers,
        // Bridge the futures-0.3 stream back to the futures-0.1 world Rusoto expects.
        body: rusoto_core::ByteStream::new(body.compat()),
    }
}
/// Sends a Rusoto `SignedRequest` through `client` and converts the reply
/// into Rusoto's `HttpResponse`.
///
/// # Errors
///
/// Returns `HttpDispatchError` when:
/// - the request uses an HTTP method this shim does not map;
/// - a header name or value is rejected by the `http` crate;
/// - the `http::Request` cannot be built;
/// - Hyper fails to send the request.
async fn dispatch(
    client: hyper::Client<hyper_rustls::HttpsConnector<hyper::client::HttpConnector>>,
    request: rusoto_core::signature::SignedRequest,
) -> Result<HttpResponse, HttpDispatchError> {
    let hyper_method = match request.method().as_ref() {
        "POST" => Method::POST,
        "PUT" => Method::PUT,
        "DELETE" => Method::DELETE,
        "GET" => Method::GET,
        "HEAD" => Method::HEAD,
        // Some AWS REST APIs (e.g. API Gateway update operations) use PATCH.
        "PATCH" => Method::PATCH,
        v => {
            return Err(HttpDispatchError::new(format!(
                "Unsupported HTTP verb {}",
                v
            )));
        }
    };

    // Each Rusoto header is a name plus a list of raw byte values. Every
    // value is appended as a separate entry under the same name.
    let mut hyper_headers = HeaderMap::new();
    for (name, values) in request.headers().iter() {
        let header_name = name.parse::<HeaderName>().map_err(|err| {
            HttpDispatchError::new(format!("error parsing header name: {}", err))
        })?;
        for value in values.iter() {
            let header_value = HeaderValue::from_bytes(value).map_err(|err| {
                HttpDispatchError::new(format!("error parsing header value: {}", err))
            })?;
            hyper_headers.append(&header_name, header_value);
        }
    }
    // Add a default user-agent header if one is not already present.
    if !hyper_headers.contains_key("user-agent") {
        hyper_headers.insert("user-agent", HeaderValue::from_static("rad"));
    }

    let mut final_uri = format!(
        "{}://{}{}",
        request.scheme(),
        request.hostname(),
        request.canonical_path()
    );
    let query = request.canonical_query_string();
    if !query.is_empty() {
        final_uri.push('?');
        final_uri.push_str(&query);
    }

    let mut http_request_builder = hyper::Request::builder();
    http_request_builder.method(hyper_method);
    http_request_builder.uri(final_uri);

    /// Adapts Rusoto's futures-0.1 `ByteStream` into a futures-0.3 stream
    /// usable as a Hyper body.
    ///
    /// NOTE(review): the `Mutex` appears to exist only to make the wrapper
    /// `Sync` for `hyper::Body::wrap_stream`. This wrapper is the mutex's
    /// only user, so the lock is never contended and resolves on first poll.
    struct LockStream {
        inner: Mutex<futures::compat::Compat01As03<rusoto_core::ByteStream>>,
    }
    impl Stream for LockStream {
        type Item = Result<Bytes, std::io::Error>;
        fn poll_next(
            self: Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Option<Self::Item>> {
            let lock = self.inner.lock();
            futures::pin_mut!(lock);
            let mut guard = futures::ready!(lock.poll(cx));
            // The compat adapter is `Unpin`, so it can be re-pinned directly.
            Pin::new(&mut *guard).poll_next(cx)
        }
    }

    let body = match request.payload {
        None => hyper::Body::empty(),
        Some(SignedRequestPayload::Stream(s)) => hyper::Body::wrap_stream(LockStream {
            inner: Mutex::new(s.compat()),
        }),
        Some(SignedRequestPayload::Buffer(buf)) => {
            let once_stream = futures::stream::once(futures::future::ready(Result::<
                Bytes,
                std::io::Error,
            >::Ok(buf)));
            hyper::Body::wrap_stream(once_stream)
        }
    };

    let mut http_request = http_request_builder
        .body(body)
        .map_err(|err| HttpDispatchError::new(err.to_string()))?;
    *http_request.headers_mut() = hyper_headers;

    let response = client
        .request(http_request)
        .await
        .map_err(|e| HttpDispatchError::new(e.to_string()))?;
    Ok(from_hyper(response))
}
impl rusoto_core::DispatchSignedRequest for ShimmedRequestDispatcher {
type Future =
Box<dyn futures01::Future<Item = HttpResponse, Error = HttpDispatchError> + Send + 'static>;
fn dispatch(
&self,
request: rusoto_core::signature::SignedRequest,
timeout: Option<Duration>,
) -> Self::Future {
let client = self.client.clone();
let fut = async move {
if let Some(timeout) = timeout {
match dispatch(client, request).timeout(timeout).await {
Err(timeout_err) => Err(rusoto_core::request::HttpDispatchError::new(format!(
"timeout sending request: {:?}",
timeout_err
))),
Ok(inner) => inner,
}
} else {
dispatch(client, request).await
}
};
Box::new(Box::pin(fut).compat())
}
}
/// Credential provider that delegates to Rusoto's `ChainProvider` and boxes
/// its future as a `Send + 'static` futures-0.1 trait object.
pub struct ShimmedCredentialProvider {
    inner: rusoto_core::credential::ChainProvider,
}

impl ShimmedCredentialProvider {
    /// Creates a provider wrapping `ChainProvider::new()`.
    pub fn new() -> Self {
        Self {
            inner: rusoto_core::credential::ChainProvider::new(),
        }
    }
}

impl Default for ShimmedCredentialProvider {
    /// Equivalent to [`ShimmedCredentialProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl rusoto_core::credential::ProvideAwsCredentials for ShimmedCredentialProvider {
    type Future = Box<
        dyn futures01::Future<
                Item = rusoto_core::credential::AwsCredentials,
                Error = rusoto_core::credential::CredentialsError,
            > + Send
            + 'static,
    >;

    /// Fetches credentials from the wrapped `ChainProvider`.
    fn credentials(&self) -> Self::Future {
        Box::new(self.inner.credentials())
    }
}
/// usage
fn foo () {
let provider = ShimmedCredentialProvider::new();
let dispatcher = ShimmedRequestDispatcher::new();
let region = rusoto_core::Region::UsWest2;
let logger = logger.new(o!("region" => region.name().to_owned()));
let client = S3Client::new_with(dispatcher, provider, region);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment