Skip to content

Instantly share code, notes, and snippets.

@titanous
Created September 11, 2022 18:08
Show Gist options
  • Save titanous/9f83617b77bcbd4809460702b532977d to your computer and use it in GitHub Desktop.
Save titanous/9f83617b77bcbd4809460702b532977d to your computer and use it in GitHub Desktop.
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use aws_sig_auth::signer::{
self, HttpSignatureType, OperationSigningConfig, RequestConfig, SigningError,
};
use aws_smithy_http::body::SdkBody;
use aws_types::credentials::ProvideCredentials;
use aws_types::SigningService;
use lazy_static::lazy_static;
use parking_lot::RwLock;
use std::error;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// A Postgres connection-config source that uses an RDS IAM auth token as the password.
///
/// Every config handed out is a clone of `config_template` whose password has been set to a
/// token produced by `token_generator` (see the `deadpool_postgres::ConfigSource` impl).
#[derive(Debug)]
pub struct ConfigSource {
    /// Signs (and caches) the IAM auth tokens used as passwords.
    token_generator: TokenGenerator,
    /// Base connection settings; cloned for each config handed out.
    config_template: tokio_postgres::Config,
}
impl ConfigSource {
    /// Builds a config source that supplies an RDS IAM auth token as the password.
    ///
    /// The first configured host (which must be a TCP host), the first configured port and the
    /// configured user are taken from `config_template` to scope the generated tokens. The
    /// template's `ssl_mode` is forced to `Require` before it is stored.
    ///
    /// # Errors
    ///
    /// Returns an error if the template's first host is missing or not a TCP host, or if no
    /// port or no user is configured.
    pub fn new(
        mut config_template: tokio_postgres::Config,
        aws: aws_config::SdkConfig,
    ) -> Result<Self> {
        let host = match config_template.get_hosts().first() {
            Some(tokio_postgres::config::Host::Tcp(h)) => h.to_owned(),
            // Unix-socket hosts (or no host at all) cannot be used to scope an IAM token.
            _ => return Err(anyhow!("postgres host must be configured for IAM auth")),
        };
        let port = config_template
            .get_ports()
            .first()
            .copied()
            .ok_or_else(|| anyhow!("postgres port must be configured for IAM auth"))?;
        let user = config_template
            .get_user()
            .ok_or_else(|| anyhow!("postgres user must be configured for IAM auth"))?
            .to_string();

        // The token travels as the password, so never send it in cleartext
        // (RDS also requires SSL for IAM database authentication).
        config_template.ssl_mode(tokio_postgres::config::SslMode::Require);

        Ok(ConfigSource {
            token_generator: TokenGenerator::new(host, port, user, aws),
            config_template,
        })
    }
}
#[async_trait]
impl deadpool_postgres::ConfigSource for ConfigSource {
    /// Produces a copy of the connection template with a current IAM auth token as password.
    ///
    /// Any error from token generation is returned to the pool as a boxed error.
    async fn get_config(
        &self,
    ) -> Result<Arc<tokio_postgres::Config>, Box<dyn error::Error + Sync + Send>> {
        let password = self.token_generator.token().await?;
        let mut with_password = tokio_postgres::Config::clone(&self.config_template);
        with_password.password(password);
        Ok(Arc::new(with_password))
    }
}
/// Generates RDS IAM authentication tokens for one host/port/user and caches the latest one.
#[derive(Debug)]
pub struct TokenGenerator {
    /// Database hostname the token is scoped to.
    host: String,
    /// Database port the token is scoped to.
    port: u16,
    /// Database user (`DBUser`) the token authenticates as.
    user: String,
    /// AWS SDK config supplying the credentials provider and region used for signing.
    config: aws_config::SdkConfig,
    /// Most recently generated token together with the time it was generated.
    cache: RwLock<TokenCache>,
}
/// The most recently generated auth token and the wall-clock time at which it was generated.
#[derive(Debug)]
struct TokenCache {
    /// Cached token (empty before the first token has been generated).
    token: String,
    /// Generation time of `token`.
    ts: SystemTime,
}

impl TokenCache {
    /// Returns a copy of the cached token if it was generated less than 14 minutes before `now`.
    ///
    /// Tokens are signed to expire after 15 minutes, so the 14-minute reuse window leaves a
    /// one-minute safety margin.
    ///
    /// Returns `None` when the token is too old, and also when `ts` is later than `now`. That
    /// happens if the wall clock stepped backwards, or if `now` was sampled before another task
    /// refreshed the cache. In either case the token's real age is unknown, so it is treated as
    /// stale instead of panicking.
    fn get(&self, now: SystemTime) -> Option<String> {
        const TOKEN_LIFETIME: Duration = Duration::from_secs(14 * 60);
        match now.duration_since(self.ts) {
            Ok(age) if age < TOKEN_LIFETIME => Some(self.token.clone()),
            _ => None,
        }
    }
}
impl TokenGenerator {
    /// Creates a generator of RDS IAM auth tokens for `user` on the database at `host:port`.
    ///
    /// Credentials and the region are read from `config` whenever a new token is signed. The
    /// cache starts out empty with a `UNIX_EPOCH` timestamp, so the first call to
    /// [`TokenGenerator::token`] always signs a new token.
    pub fn new(
        host: String,
        port: u16,
        user: String,
        config: aws_config::SdkConfig,
    ) -> TokenGenerator {
        TokenGenerator {
            host,
            port,
            user,
            config,
            cache: RwLock::new(TokenCache {
                token: String::default(),
                ts: UNIX_EPOCH,
            }),
        }
    }

    /// Returns an RDS IAM auth token, reusing the cached one while it is still fresh.
    ///
    /// A new token is a SigV4 query-string-presigned `Action=connect` request for `user`
    /// against `host:port`, with a 15-minute expiry. The leading `http://` is stripped.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the SDK config has no credentials provider or no region;
    /// - fetching credentials fails;
    /// - the request URI cannot be built;
    /// - signing fails.
    pub async fn token(&self) -> Result<String, SigningError> {
        lazy_static! {
            static ref SERVICE: SigningService = SigningService::from_static("rds-db");
        }

        // Fast path: a fresh token is already cached.
        if let Some(token) = self.cache.read().get(SystemTime::now()) {
            return Ok(token);
        }

        let credentials = self
            .config
            .credentials_provider()
            .ok_or_else(|| anyhow!("missing AWS credentials"))?
            .provide_credentials()
            .await?;

        // Take the write lock first, then sample the clock. Every cached `ts` was written by an
        // earlier holder of this lock, so `now` is not behind it (barring a wall-clock step
        // backwards). The signature timestamp is also not skewed by the credentials fetch above.
        let mut cache = self.cache.write();
        let now = SystemTime::now();
        // Another task may have refreshed the token while we were awaiting credentials.
        if let Some(token) = cache.get(now) {
            return Ok(token);
        }

        let signer = signer::SigV4Signer::new();
        let mut operation_config = OperationSigningConfig::default_config();
        // Presign into the query string: the signed URI itself is the token.
        operation_config.signature_type = HttpSignatureType::HttpRequestQueryParams;
        operation_config.expires_in = Some(Duration::from_secs(15 * 60));
        let request_config = RequestConfig {
            request_ts: now,
            region: &self
                .config
                .region()
                .ok_or_else(|| anyhow!("missing AWS region in config"))?
                .to_owned()
                .into(),
            service: &SERVICE,
            payload_override: None,
        };

        // Percent-encode the user so characters such as spaces, `&`, `+` or `=` neither make the
        // URI unparsable nor corrupt the `DBUser` query parameter.
        let mut request = http::Request::builder()
            .uri(format!(
                "http://{hostname}:{port}/?Action=connect&DBUser={username}",
                hostname = self.host,
                port = self.port,
                username = percent_encode_query_value(&self.user)
            ))
            .body(SdkBody::empty())?;
        // The signature is added to the request URI's query; the returned value is not needed.
        let _signature = signer.sign(
            &operation_config,
            &request_config,
            &credentials,
            &mut request,
        )?;

        let token = request
            .uri()
            .to_string()
            .strip_prefix("http://")
            .expect("request URI is built with an http:// scheme")
            .to_owned();
        *cache = TokenCache {
            token: token.clone(),
            ts: now,
        };
        Ok(token)
    }
}

/// Percent-encodes `value` for use as a URI query-parameter value.
///
/// Every byte outside the RFC 3986 unreserved set (`A-Z a-z 0-9 - . _ ~`) is written as `%XX`
/// with uppercase hex digits, matching the encoding SigV4 uses for canonical query strings.
fn percent_encode_query_value(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(value.len());
    for byte in value.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => {
                encoded.push(char::from(byte));
            }
            _ => {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    encoded
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment