-
-
Save titanous/9f83617b77bcbd4809460702b532977d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use anyhow::{anyhow, Result}; | |
use async_trait::async_trait; | |
use aws_sig_auth::signer::{ | |
self, HttpSignatureType, OperationSigningConfig, RequestConfig, SigningError, | |
}; | |
use aws_smithy_http::body::SdkBody; | |
use aws_types::credentials::ProvideCredentials; | |
use aws_types::SigningService; | |
use lazy_static::lazy_static; | |
use parking_lot::RwLock; | |
use std::error; | |
use std::sync::Arc; | |
use std::time::{Duration, SystemTime, UNIX_EPOCH}; | |
#[derive(Debug)] | |
pub struct ConfigSource { | |
token_generator: TokenGenerator, | |
config_template: tokio_postgres::Config, | |
} | |
impl<'a> ConfigSource { | |
pub fn new( | |
mut config_template: tokio_postgres::Config, | |
aws: aws_config::SdkConfig, | |
) -> Result<Self> { | |
let host = if let Some(tokio_postgres::config::Host::Tcp(h)) = | |
config_template.get_hosts().get(0) | |
{ | |
Some(h.to_owned()) | |
} else { | |
None | |
}; | |
config_template.ssl_mode(tokio_postgres::config::SslMode::Require); | |
Ok(ConfigSource { | |
token_generator: TokenGenerator::new( | |
host.ok_or_else(|| anyhow!("postgres host must be configured for IAM auth"))?, | |
config_template | |
.get_ports() | |
.get(0) | |
.ok_or_else(|| anyhow!("postgres port must be configured for IAM auth"))? | |
.to_owned(), | |
config_template | |
.get_user() | |
.ok_or_else(|| anyhow!("postgres user must be configured for IAM auth"))? | |
.to_string(), | |
aws, | |
), | |
config_template, | |
}) | |
} | |
} | |
#[async_trait] | |
impl deadpool_postgres::ConfigSource for ConfigSource { | |
async fn get_config( | |
&self, | |
) -> Result<Arc<tokio_postgres::Config>, Box<dyn error::Error + Sync + Send>> { | |
let mut config = self.config_template.clone(); | |
config.password(self.token_generator.token().await?); | |
Ok(Arc::new(config)) | |
} | |
} | |
#[derive(Debug)] | |
pub struct TokenGenerator { | |
host: String, | |
port: u16, | |
user: String, | |
config: aws_config::SdkConfig, | |
cache: RwLock<TokenCache>, | |
} | |
/// A generated token together with the time it was generated at.
#[derive(Debug)]
struct TokenCache {
    /// The cached token string.
    token: String,
    /// Timestamp the cached token was generated at.
    ts: SystemTime,
}

impl TokenCache {
    /// Returns a clone of the cached token if it is less than 14 minutes old
    /// at `now`, otherwise `None`.
    ///
    /// A cache timestamp that lies *after* `now` is treated as a miss instead
    /// of panicking. That situation arises when the system clock steps
    /// backwards, and also when a concurrent caller that sampled a later
    /// `now` stored its token before this caller re-checks the cache.
    /// Reporting a miss makes the caller generate a fresh token, which is
    /// always safe.
    fn get(&self, now: SystemTime) -> Option<String> {
        const TOKEN_LIFETIME: Duration = Duration::from_secs(14 * 60);

        match now.duration_since(self.ts) {
            Ok(age) if age < TOKEN_LIFETIME => Some(self.token.clone()),
            // Either the token is too old, or `self.ts` is later than `now`.
            _ => None,
        }
    }
}
impl TokenGenerator { | |
pub fn new( | |
host: String, | |
port: u16, | |
user: String, | |
config: aws_config::SdkConfig, | |
) -> TokenGenerator { | |
TokenGenerator { | |
host, | |
port, | |
user, | |
config, | |
cache: RwLock::new(TokenCache { | |
token: String::default(), | |
ts: UNIX_EPOCH, | |
}), | |
} | |
} | |
pub async fn token(&self) -> Result<String, SigningError> { | |
lazy_static! { | |
static ref SERVICE: SigningService = SigningService::from_static("rds-db"); | |
} | |
let now = SystemTime::now(); | |
if let Some(token) = self.cache.read().get(now) { | |
return Ok(token); | |
} | |
let credentials = self | |
.config | |
.credentials_provider() | |
.ok_or_else(|| anyhow!("missing AWS credentials"))? | |
.provide_credentials() | |
.await?; | |
// lock the cache again and fast path if the token has already been refreshed while we were awaiting credentials | |
let mut cache = self.cache.write(); | |
if let Some(token) = cache.get(now) { | |
return Ok(token); | |
} | |
let signer = signer::SigV4Signer::new(); | |
let mut operation_config = OperationSigningConfig::default_config(); | |
operation_config.signature_type = HttpSignatureType::HttpRequestQueryParams; | |
operation_config.expires_in = Some(Duration::from_secs(15 * 60)); | |
let request_config = RequestConfig { | |
request_ts: now, | |
region: &self | |
.config | |
.region() | |
.ok_or_else(|| anyhow!("missing AWS region in config"))? | |
.to_owned() | |
.into(), | |
service: &SERVICE, | |
payload_override: None, | |
}; | |
let mut request = http::Request::builder() | |
.uri(format!( | |
"http://{hostname}:{port}/?Action=connect&DBUser={username}", | |
hostname = &self.host, | |
port = &self.port, | |
username = &self.user | |
)) | |
.body(SdkBody::empty()) | |
.expect("valid request"); | |
let _signature = signer.sign( | |
&operation_config, | |
&request_config, | |
&credentials, | |
&mut request, | |
)?; | |
let mut uri = request.uri().to_string(); | |
assert!(uri.starts_with("http://")); | |
let uri = uri.split_off("http://".len()); | |
*cache = TokenCache { | |
token: uri.clone(), | |
ts: now, | |
}; | |
Ok(uri) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment