Skip to content

Instantly share code, notes, and snippets.

@tmokmss
Created January 1, 2024 14:07
Show Gist options
  • Save tmokmss/0554831387e4724cda10864ab1e89711 to your computer and use it in GitHub Desktop.
Save tmokmss/0554831387e4724cda10864ab1e89711 to your computer and use it in GitHub Desktop.
Aurora Data API Proxy with pgwire
use std::fmt::Debug;
use std::sync::Arc;
use async_trait::async_trait;
use aws_sdk_rdsdata::error::ProvideErrorMetadata;
use futures::{stream, Sink, SinkExt, StreamExt};
use tokio::net::TcpListener;
use pgwire::api::auth::noop::NoopStartupHandler;
use pgwire::api::portal::Portal;
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler, StatementOrPortal};
use pgwire::api::results::{
DataRowEncoder, DescribeResponse, FieldFormat, FieldInfo, QueryResponse, Response, Tag,
};
use pgwire::api::stmt::NoopQueryParser;
use pgwire::api::{ClientInfo, MakeHandler, StatelessMakeHandler, Type};
use pgwire::error::ErrorInfo;
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::response::NoticeResponse;
use pgwire::messages::PgWireBackendMessage;
use pgwire::tokio::process_socket;
use regex::Regex;
pub struct DummyProcessor {
query_parser: Arc<NoopQueryParser>,
rds_data_client: aws_sdk_rdsdata::Client,
re_position: regex::Regex,
re_code: regex::Regex,
}
#[async_trait]
impl SimpleQueryHandler for DummyProcessor {
async fn do_query<'a, C>(
&self,
client: &mut C,
query: &'a str,
) -> PgWireResult<Vec<Response<'a>>>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
client
.send(PgWireBackendMessage::NoticeResponse(NoticeResponse::from(
ErrorInfo::new(
"NOTICE".to_owned(),
"01000".to_owned(),
format!("Query received {}", query),
),
)))
.await?;
println!("simple do_query {:?}", client.metadata());
println!("simple do_query {}", client.socket_addr());
println!("simple do_query {}", query);
let resource_arn = "arn:aws:rds:xxxx";
let secret_arn =
"arn:aws:secretsmanager:xxxx";
let resp = self
.rds_data_client
.execute_statement()
.resource_arn(resource_arn)
.secret_arn(secret_arn)
.sql(query)
.include_result_metadata(true)
.send()
.await;
if query.to_uppercase().starts_with("SELECT") {
match resp {
Ok(res) => {
println!("{:?}", res.records());
println!("{:?}", res.column_metadata());
let schema = res.column_metadata().into_iter().map(|col| {
FieldInfo::new(
col.name().unwrap_or("unnamed column").to_string(),
None,
None,
match col.type_name() {
Some(data_type) => match data_type {
"int2" => Type::INT2,
"int4" => Type::INT4,
"int8" => Type::INT8,
"_int4" => Type::INT4_ARRAY,
"varchar" => Type::VARCHAR,
"text" => Type::TEXT,
"bool" => Type::BOOL,
"timestamptz" => Type::TIMESTAMPTZ,
"jsonb" => Type::JSONB,
"json" => Type::JSON,
"blob" => Type::BYTEA,
"float4" => Type::FLOAT4,
"float8" => Type::FLOAT8,
"date" => Type::DATE,
"time" => Type::TIME,
"timetz" => Type::TIMETZ,
"numeric" => Type::NUMERIC,
"timestamp" => Type::TIMESTAMP,
"bytea" => Type::BYTEA,
"cidr" => Type::CIDR,
"inet" => Type::INET,
"macaddr" => Type::MACADDR,
"uuid" => Type::UUID,
_ => Type::UNKNOWN,
},
None => Type::UNKNOWN,
},
match col.type_name() {
Some(data_type) => match data_type {
"jsonb" => FieldFormat::Binary,
"blob" => FieldFormat::Binary,
"bytea" => FieldFormat::Binary,
_ => FieldFormat::Text,
},
None => FieldFormat::Text,
},
)
});
let schema: Arc<Vec<FieldInfo>> = Arc::new(schema.collect());
let res = Arc::new(res);
let records = res.records().to_owned();
let data: Vec<Vec<aws_sdk_rdsdata::types::Field>> =
records.into_iter().map(|r| r).collect();
let schema_ref = schema.clone();
let data_row_stream =
stream::iter(data.clone()).enumerate().map(move |(i, r)| {
// let sc = schema[i];
let mut encoder = DataRowEncoder::new(schema_ref.clone());
r.into_iter().for_each(|col| {
if col.is_array_value() {
let arr = col.as_array_value().unwrap();
if arr.is_boolean_values() {
encoder
.encode_field(arr.as_boolean_values().unwrap())
.unwrap();
} else if arr.is_double_values() {
encoder
.encode_field(arr.as_double_values().unwrap())
.unwrap();
} else if arr.is_long_values() {
encoder
.encode_field(arr.as_long_values().unwrap())
.unwrap();
} else if arr.is_string_values() {
encoder
.encode_field(arr.as_string_values().unwrap())
.unwrap();
} else if arr.is_unknown() {
encoder
.encode_field(arr.as_string_values().unwrap())
.unwrap();
}
} else if col.is_string_value() {
encoder
.encode_field(col.as_string_value().unwrap())
.unwrap();
} else if col.is_double_value() {
encoder
.encode_field(col.as_double_value().unwrap())
.unwrap();
} else if col.is_boolean_value() {
encoder
.encode_field(col.as_boolean_value().unwrap())
.unwrap();
} else if col.is_long_value() {
encoder.encode_field(col.as_long_value().unwrap()).unwrap();
} else if col.is_is_null() {
encoder.encode_field(col.as_is_null().unwrap()).unwrap();
}
// else if col.is_blob_value() {
// encoder
// .encode_field(col.as_blob_value().unwrap().into_inner())
// .unwrap();
// }
else {
println!("unknown {:?}", col)
}
});
encoder.finish()
});
Ok(vec![Response::Query(QueryResponse::new(
schema,
data_row_stream,
))])
}
Err(e) => {
println!("{:?}", e);
println!("{:?}", e.message());
let error_info = match e.message() {
Some(message) => {
let pos = self.re_position.find(message).unwrap().as_str();
let code = self.re_code.find(message).unwrap().as_str();
let mut info = ErrorInfo::new(
"ERROR".to_string(),
code.to_string(),
message.to_string(),
);
info.set_position(Some(pos.to_string()));
info
}
None => ErrorInfo::new(
"ERROR".to_string(),
"code".to_string(),
"unknown error".to_string(),
),
};
Err(PgWireError::UserError(Box::new(error_info)))
}
}
} else {
match resp {
Ok(res) => {
println!("{:?}", res.records());
println!("{:?}", res.column_metadata());
println!("{:?}", res.number_of_records_updated());
Ok(vec![Response::Execution(Tag::new_for_execution(
"OK",
Some(res.number_of_records_updated().try_into().unwrap()),
))])
}
Err(e) => {
println!("{:?}", e);
println!("{:?}", e.message());
let error_info = match e.message() {
Some(message) => {
let pos = &self.re_position.captures(message).unwrap()[1];
let code = &self.re_code.captures(message).unwrap()[1];
let mut info = ErrorInfo::new(
"ERROR".to_string(),
code.to_string(),
message.to_string(),
);
info.set_position(Some(pos.to_string()));
info
}
None => ErrorInfo::new(
"ERROR".to_string(),
"code".to_string(),
"unknown error".to_string(),
),
};
Err(PgWireError::UserError(Box::new(error_info)))
}
}
}
}
}
#[async_trait]
impl ExtendedQueryHandler for DummyProcessor {
type Statement = String;
type QueryParser = NoopQueryParser;
fn query_parser(&self) -> Arc<Self::QueryParser> {
self.query_parser.clone()
}
async fn do_query<'a, C>(
&self,
_client: &mut C,
portal: &'a Portal<Self::Statement>,
_max_rows: usize,
) -> PgWireResult<Response<'a>>
where
C: ClientInfo + Unpin + Send + Sync,
{
println!("extended do_query {}", _client.socket_addr());
println!("extended do_query {}", portal.statement().statement());
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_string(),
"code".to_string(),
"Extended query is not supported".to_string(),
))));
let query = portal.statement().statement();
if query.to_uppercase().starts_with("SELECT") {
let f1 = FieldInfo::new("id".into(), None, None, Type::INT4, FieldFormat::Text);
let f2 = FieldInfo::new("name".into(), None, None, Type::VARCHAR, FieldFormat::Text);
let schema = Arc::new(vec![f1, f2]);
let data = vec![
(Some(0), Some("Tom")),
(Some(1), Some("Jerry")),
(Some(2), None),
];
let schema_ref = schema.clone();
let data_row_stream = stream::iter(data.into_iter()).map(move |r| {
let mut encoder = DataRowEncoder::new(schema_ref.clone());
encoder.encode_field(&r.0)?;
encoder.encode_field(&r.1)?;
encoder.finish()
});
Ok(Response::Query(QueryResponse::new(schema, data_row_stream)))
} else {
Ok(Response::Execution(Tag::new_for_execution("OK", Some(1))))
}
}
async fn do_describe<C>(
&self,
_client: &mut C,
target: StatementOrPortal<'_, Self::Statement>,
) -> PgWireResult<DescribeResponse>
where
C: ClientInfo + Unpin + Send + Sync,
{
println!("extended do_describe {}", _client.socket_addr());
println!("extended do_describe {:?}", target);
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_string(),
"code".to_string(),
"Extended query is not supported".to_string(),
))));
match target {
StatementOrPortal::Statement(stmt) => Ok(DescribeResponse::new(None, [].to_vec())),
StatementOrPortal::Portal(portal) => Ok(DescribeResponse::new(None, [].to_vec())),
}
}
}
/// Starts the proxy: builds the Data API client and pgwire handlers, then
/// serves PostgreSQL clients on 127.0.0.1:5432, one spawned task per connection.
///
/// # Panics
///
/// Panics only at startup, if the listen address cannot be bound. Failures
/// while accepting or serving an individual connection are logged, and the
/// server keeps running.
#[tokio::main]
pub async fn main() {
    let config = aws_config::from_env().region("us-east-1").load().await;
    let client = aws_sdk_rdsdata::Client::new(&config);
    let processor = Arc::new(StatelessMakeHandler::new(Arc::new(DummyProcessor {
        query_parser: Arc::new(NoopQueryParser::new()),
        rds_data_client: client,
        re_position: Regex::new(r"Position: (\d+)").expect("position regex is a valid literal"),
        re_code: Regex::new(r"SQLState: ([0-9A-Z]+)").expect("SQLSTATE regex is a valid literal"),
    })));
    let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));

    let server_addr = "127.0.0.1:5432";
    let listener = TcpListener::bind(server_addr)
        .await
        .unwrap_or_else(|e| panic!("failed to bind {}: {:?}", server_addr, e));
    println!("Listening to {}", server_addr);

    loop {
        let (socket, peer) = match listener.accept().await {
            Ok(conn) => conn,
            Err(e) => {
                // A failed accept concerns only that connection attempt; keep serving others.
                eprintln!("failed to accept connection: {:?}", e);
                continue;
            }
        };
        let authenticator_ref = authenticator.make();
        let processor_ref = processor.make();
        tokio::spawn(async move {
            if let Err(e) = process_socket(
                socket,
                None,
                authenticator_ref,
                processor_ref.clone(),
                processor_ref,
            )
            .await
            {
                eprintln!("connection from {} ended with error: {:?}", peer, e);
            }
        });
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment