Created January 1, 2024 14:07
-
-
Save tmokmss/0554831387e4724cda10864ab1e89711 to your computer and use it in GitHub Desktop.
Aurora Data API Proxy with pgwire
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::fmt::Debug; | |
use std::sync::Arc; | |
use async_trait::async_trait; | |
use aws_sdk_rdsdata::error::ProvideErrorMetadata; | |
use futures::{stream, Sink, SinkExt, StreamExt}; | |
use tokio::net::TcpListener; | |
use pgwire::api::auth::noop::NoopStartupHandler; | |
use pgwire::api::portal::Portal; | |
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler, StatementOrPortal}; | |
use pgwire::api::results::{ | |
DataRowEncoder, DescribeResponse, FieldFormat, FieldInfo, QueryResponse, Response, Tag, | |
}; | |
use pgwire::api::stmt::NoopQueryParser; | |
use pgwire::api::{ClientInfo, MakeHandler, StatelessMakeHandler, Type}; | |
use pgwire::error::ErrorInfo; | |
use pgwire::error::{PgWireError, PgWireResult}; | |
use pgwire::messages::response::NoticeResponse; | |
use pgwire::messages::PgWireBackendMessage; | |
use pgwire::tokio::process_socket; | |
use regex::Regex; | |
pub struct DummyProcessor { | |
query_parser: Arc<NoopQueryParser>, | |
rds_data_client: aws_sdk_rdsdata::Client, | |
re_position: regex::Regex, | |
re_code: regex::Regex, | |
} | |
#[async_trait] | |
impl SimpleQueryHandler for DummyProcessor { | |
async fn do_query<'a, C>( | |
&self, | |
client: &mut C, | |
query: &'a str, | |
) -> PgWireResult<Vec<Response<'a>>> | |
where | |
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync, | |
C::Error: Debug, | |
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>, | |
{ | |
client | |
.send(PgWireBackendMessage::NoticeResponse(NoticeResponse::from( | |
ErrorInfo::new( | |
"NOTICE".to_owned(), | |
"01000".to_owned(), | |
format!("Query received {}", query), | |
), | |
))) | |
.await?; | |
println!("simple do_query {:?}", client.metadata()); | |
println!("simple do_query {}", client.socket_addr()); | |
println!("simple do_query {}", query); | |
let resource_arn = "arn:aws:rds:xxxx"; | |
let secret_arn = | |
"arn:aws:secretsmanager:xxxx"; | |
let resp = self | |
.rds_data_client | |
.execute_statement() | |
.resource_arn(resource_arn) | |
.secret_arn(secret_arn) | |
.sql(query) | |
.include_result_metadata(true) | |
.send() | |
.await; | |
if query.to_uppercase().starts_with("SELECT") { | |
match resp { | |
Ok(res) => { | |
println!("{:?}", res.records()); | |
println!("{:?}", res.column_metadata()); | |
let schema = res.column_metadata().into_iter().map(|col| { | |
FieldInfo::new( | |
col.name().unwrap_or("unnamed column").to_string(), | |
None, | |
None, | |
match col.type_name() { | |
Some(data_type) => match data_type { | |
"int2" => Type::INT2, | |
"int4" => Type::INT4, | |
"int8" => Type::INT8, | |
"_int4" => Type::INT4_ARRAY, | |
"varchar" => Type::VARCHAR, | |
"text" => Type::TEXT, | |
"bool" => Type::BOOL, | |
"timestamptz" => Type::TIMESTAMPTZ, | |
"jsonb" => Type::JSONB, | |
"json" => Type::JSON, | |
"blob" => Type::BYTEA, | |
"float4" => Type::FLOAT4, | |
"float8" => Type::FLOAT8, | |
"date" => Type::DATE, | |
"time" => Type::TIME, | |
"timetz" => Type::TIMETZ, | |
"numeric" => Type::NUMERIC, | |
"timestamp" => Type::TIMESTAMP, | |
"bytea" => Type::BYTEA, | |
"cidr" => Type::CIDR, | |
"inet" => Type::INET, | |
"macaddr" => Type::MACADDR, | |
"uuid" => Type::UUID, | |
_ => Type::UNKNOWN, | |
}, | |
None => Type::UNKNOWN, | |
}, | |
match col.type_name() { | |
Some(data_type) => match data_type { | |
"jsonb" => FieldFormat::Binary, | |
"blob" => FieldFormat::Binary, | |
"bytea" => FieldFormat::Binary, | |
_ => FieldFormat::Text, | |
}, | |
None => FieldFormat::Text, | |
}, | |
) | |
}); | |
let schema: Arc<Vec<FieldInfo>> = Arc::new(schema.collect()); | |
let res = Arc::new(res); | |
let records = res.records().to_owned(); | |
let data: Vec<Vec<aws_sdk_rdsdata::types::Field>> = | |
records.into_iter().map(|r| r).collect(); | |
let schema_ref = schema.clone(); | |
let data_row_stream = | |
stream::iter(data.clone()).enumerate().map(move |(i, r)| { | |
// let sc = schema[i]; | |
let mut encoder = DataRowEncoder::new(schema_ref.clone()); | |
r.into_iter().for_each(|col| { | |
if col.is_array_value() { | |
let arr = col.as_array_value().unwrap(); | |
if arr.is_boolean_values() { | |
encoder | |
.encode_field(arr.as_boolean_values().unwrap()) | |
.unwrap(); | |
} else if arr.is_double_values() { | |
encoder | |
.encode_field(arr.as_double_values().unwrap()) | |
.unwrap(); | |
} else if arr.is_long_values() { | |
encoder | |
.encode_field(arr.as_long_values().unwrap()) | |
.unwrap(); | |
} else if arr.is_string_values() { | |
encoder | |
.encode_field(arr.as_string_values().unwrap()) | |
.unwrap(); | |
} else if arr.is_unknown() { | |
encoder | |
.encode_field(arr.as_string_values().unwrap()) | |
.unwrap(); | |
} | |
} else if col.is_string_value() { | |
encoder | |
.encode_field(col.as_string_value().unwrap()) | |
.unwrap(); | |
} else if col.is_double_value() { | |
encoder | |
.encode_field(col.as_double_value().unwrap()) | |
.unwrap(); | |
} else if col.is_boolean_value() { | |
encoder | |
.encode_field(col.as_boolean_value().unwrap()) | |
.unwrap(); | |
} else if col.is_long_value() { | |
encoder.encode_field(col.as_long_value().unwrap()).unwrap(); | |
} else if col.is_is_null() { | |
encoder.encode_field(col.as_is_null().unwrap()).unwrap(); | |
} | |
// else if col.is_blob_value() { | |
// encoder | |
// .encode_field(col.as_blob_value().unwrap().into_inner()) | |
// .unwrap(); | |
// } | |
else { | |
println!("unknown {:?}", col) | |
} | |
}); | |
encoder.finish() | |
}); | |
Ok(vec![Response::Query(QueryResponse::new( | |
schema, | |
data_row_stream, | |
))]) | |
} | |
Err(e) => { | |
println!("{:?}", e); | |
println!("{:?}", e.message()); | |
let error_info = match e.message() { | |
Some(message) => { | |
let pos = self.re_position.find(message).unwrap().as_str(); | |
let code = self.re_code.find(message).unwrap().as_str(); | |
let mut info = ErrorInfo::new( | |
"ERROR".to_string(), | |
code.to_string(), | |
message.to_string(), | |
); | |
info.set_position(Some(pos.to_string())); | |
info | |
} | |
None => ErrorInfo::new( | |
"ERROR".to_string(), | |
"code".to_string(), | |
"unknown error".to_string(), | |
), | |
}; | |
Err(PgWireError::UserError(Box::new(error_info))) | |
} | |
} | |
} else { | |
match resp { | |
Ok(res) => { | |
println!("{:?}", res.records()); | |
println!("{:?}", res.column_metadata()); | |
println!("{:?}", res.number_of_records_updated()); | |
Ok(vec![Response::Execution(Tag::new_for_execution( | |
"OK", | |
Some(res.number_of_records_updated().try_into().unwrap()), | |
))]) | |
} | |
Err(e) => { | |
println!("{:?}", e); | |
println!("{:?}", e.message()); | |
let error_info = match e.message() { | |
Some(message) => { | |
let pos = &self.re_position.captures(message).unwrap()[1]; | |
let code = &self.re_code.captures(message).unwrap()[1]; | |
let mut info = ErrorInfo::new( | |
"ERROR".to_string(), | |
code.to_string(), | |
message.to_string(), | |
); | |
info.set_position(Some(pos.to_string())); | |
info | |
} | |
None => ErrorInfo::new( | |
"ERROR".to_string(), | |
"code".to_string(), | |
"unknown error".to_string(), | |
), | |
}; | |
Err(PgWireError::UserError(Box::new(error_info))) | |
} | |
} | |
} | |
} | |
} | |
#[async_trait] | |
impl ExtendedQueryHandler for DummyProcessor { | |
type Statement = String; | |
type QueryParser = NoopQueryParser; | |
fn query_parser(&self) -> Arc<Self::QueryParser> { | |
self.query_parser.clone() | |
} | |
async fn do_query<'a, C>( | |
&self, | |
_client: &mut C, | |
portal: &'a Portal<Self::Statement>, | |
_max_rows: usize, | |
) -> PgWireResult<Response<'a>> | |
where | |
C: ClientInfo + Unpin + Send + Sync, | |
{ | |
println!("extended do_query {}", _client.socket_addr()); | |
println!("extended do_query {}", portal.statement().statement()); | |
return Err(PgWireError::UserError(Box::new(ErrorInfo::new( | |
"ERROR".to_string(), | |
"code".to_string(), | |
"Extended query is not supported".to_string(), | |
)))); | |
let query = portal.statement().statement(); | |
if query.to_uppercase().starts_with("SELECT") { | |
let f1 = FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text); | |
let f2 = FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text); | |
let schema = Arc::new(vec![f1, f2]); | |
let data = vec![ | |
(Some(0), Some("Tom")), | |
(Some(1), Some("Jerry")), | |
(Some(2), None), | |
]; | |
let schema_ref = schema.clone(); | |
let data_row_stream = stream::iter(data.into_iter()).map(move |r| { | |
let mut encoder = DataRowEncoder::new(schema_ref.clone()); | |
encoder.encode_field(&r.0)?; | |
encoder.encode_field(&r.1)?; | |
encoder.finish() | |
}); | |
Ok(Response::Query(QueryResponse::new(schema, data_row_stream))) | |
} else { | |
Ok(Response::Execution(Tag::new_for_execution("OK", Some(1)))) | |
} | |
} | |
async fn do_describe<C>( | |
&self, | |
_client: &mut C, | |
target: StatementOrPortal<'_, Self::Statement>, | |
) -> PgWireResult<DescribeResponse> | |
where | |
C: ClientInfo + Unpin + Send + Sync, | |
{ | |
println!("extended do_describe {}", _client.socket_addr()); | |
println!("extended do_describe {:?}", target); | |
return Err(PgWireError::UserError(Box::new(ErrorInfo::new( | |
"ERROR".to_string(), | |
"code".to_string(), | |
"Extended query is not supported".to_string(), | |
)))); | |
match target { | |
StatementOrPortal::Statement(stmt) => Ok(DescribeResponse::new(None, [].to_vec())), | |
StatementOrPortal::Portal(portal) => Ok(DescribeResponse::new(None, [].to_vec())), | |
} | |
} | |
} | |
#[tokio::main] | |
pub async fn main() { | |
let config = aws_config::from_env().region("us-east-1").load().await; | |
let client = aws_sdk_rdsdata::Client::new(&config); | |
let processor = Arc::new(StatelessMakeHandler::new(Arc::new(DummyProcessor { | |
query_parser: Arc::new(NoopQueryParser::new()), | |
rds_data_client: client, | |
re_position: Regex::new(r"Position: (\d+)").unwrap(), | |
re_code: Regex::new(r"SQLState: ([0-9A-Z]+)").unwrap(), | |
}))); | |
let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); | |
let server_addr = "127.0.0.1:5432"; | |
let listener = TcpListener::bind(server_addr).await.unwrap(); | |
println!("Listening to {}", server_addr); | |
loop { | |
let incoming_socket = listener.accept().await.unwrap(); | |
let authenticator_ref = authenticator.make(); | |
let processor_ref = processor.make(); | |
tokio::spawn(async move { | |
process_socket( | |
incoming_socket.0, | |
None, | |
authenticator_ref, | |
processor_ref.clone(), | |
processor_ref, | |
) | |
.await | |
}); | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.