Skip to content

Instantly share code, notes, and snippets.

@Termina1
Last active November 26, 2021 14:11
Show Gist options
  • Save Termina1/6ea4b00c63743f5cebb79f9e6b42b72c to your computer and use it in GitHub Desktop.
Save Termina1/6ea4b00c63743f5cebb79f9e6b42b72c to your computer and use it in GitHub Desktop.
AoC FDW on pgx
use serde::Deserialize;
use std::{collections::HashMap, result::*};
/// Connection settings for fetching an Advent of Code leaderboard.
///
/// Both values are sent verbatim by `get_aoc_leader_board`: `api` is the request
/// URL and `cookie` becomes the `Cookie` request header.
pub struct AocOptions {
    /// URL of the leaderboard JSON endpoint.
    pub api: &'static str,
    /// Raw value of the `Cookie` header sent with the request.
    pub cookie: &'static str,
}
pub fn deserialize_ts<'de, D>(deserializer: D) -> Result<i32, D::Error>
where
D: serde::Deserializer<'de>,
{
i32::deserialize(deserializer).or_else(|_| Ok(0))
}
/// One leaderboard member, as deserialized from the AoC private-leaderboard JSON.
///
/// Every field is exposed by the FDW as a column of the same name.
#[derive(Deserialize, Debug)]
pub struct AocMember {
    pub id: String,
    // NOTE(review): anonymous AoC users may be reported with a null name, which
    // would fail to deserialize into `String`. Confirm against the live API.
    pub name: String,
    /// Kept as untyped JSON; surfaced as the jsonb `completion_day_level` column.
    pub completion_day_level: serde_json::Value,
    /// When this value is 0 the API sends it as a string rather than a number,
    /// so a custom deserializer (`deserialize_ts`) is used to tolerate that.
    #[serde(deserialize_with = "deserialize_ts")]
    pub last_star_ts: i32,
    pub stars: i32,
    pub global_score: i32,
    pub local_score: i32,
}
/// Top-level leaderboard payload. Only `members` is consumed by the FDW.
#[derive(Deserialize, Debug)]
pub struct AocResponse {
    // Deserialized to mirror the payload shape but never read by this crate.
    owner_id: String,
    event: String,
    /// Leaderboard members; presumably keyed by member id (verify against the API).
    pub members: HashMap<String, AocMember>,
}
/// Fetches and parses the private-leaderboard JSON from `server.api`.
///
/// Performs a blocking GET request and sends `server.cookie` verbatim as the
/// `Cookie` header.
///
/// # Errors
///
/// Returns a human-readable message when any of these happen:
/// - the request fails;
/// - the server answers with anything other than `200 OK` (the status is included);
/// - the response body cannot be read;
/// - the JSON does not match [`AocResponse`].
pub fn get_aoc_leader_board(server: &AocOptions) -> Result<AocResponse, String> {
    let client = reqwest::blocking::Client::new();
    let resp = client
        .get(server.api)
        .header("Cookie", server.cookie)
        .send()
        .map_err(|err| err.to_string())?;
    let status = resp.status();
    if status != reqwest::StatusCode::OK {
        return Err(format!("Invalid status code: {}", status));
    }
    // Reading the body can fail (e.g. connection reset). Report it instead of panicking.
    let text = resp.text().map_err(|err| err.to_string())?;
    serde_json::from_str(&text).map_err(|err| err.to_string())
}
use aoc::AocMember;
use pgx::*;
use std::collections::HashMap;
use std::panic;
use std::*;
mod aoc;
// Option names accepted on the foreign server; enforced by `aoc_fdw_validator`.
const SERVER_OPTIONS: [&str; 2] = ["api", "cookie"];
// Catalog OIDs identify which kind of object the validator is checking.
// pg_foreign_table (ForeignTableRelationId) would be 3118; it is not needed yet.
// const FOREIGN_TABLE_OID: u32 = 3118;
// pg_foreign_server (ForeignServerRelationId): the only object allowed to carry options here.
const FOREIGN_SERVER_OID: u32 = 1417;
// Emits the module magic block PostgreSQL requires of every loadable extension library.
pg_module_magic!();
/// Return type of `aoc_fdw_handler`: owns the palloc'd `FdwRoutine` node until it
/// is handed to PostgreSQL as a datum.
///
/// The snake_case name is deliberate. It must read `fdw_handler` in the generated
/// SQL to match PostgreSQL's pseudo-type, and pgx presumably derives the SQL
/// name from the Rust type name.
#[allow(non_camel_case_types)]
struct fdw_handler {
    handler: PgBox<pg_sys::FdwRoutine, AllocatedByRust>,
}
// pgx has no built-in support for PostgreSQL's `fdw_handler` pseudo-type, so this
// impl teaches it how to return one. The datum is simply a pointer to the
// palloc'd `FdwRoutine` node.
impl IntoDatum for fdw_handler {
    /// Transfers ownership of the routine node to PostgreSQL and returns its
    /// address as a datum.
    fn into_datum(self) -> Option<pg_sys::Datum> {
        Some(self.handler.into_pg() as pg_sys::Datum)
    }

    /// OID of the `fdw_handler` pseudo-type.
    ///
    /// This used to return `NodeTag_T_FdwRoutine`, which is a node tag and not a
    /// `pg_type` OID.
    fn type_oid() -> u32 {
        pg_sys::FDW_HANDLEROID
    }
}
// Empty SQL on purpose. It only declares the `fdw_handler` type to pgx's SQL
// generator so that `aoc_fdw_handler` can be emitted with that return type.
// PostgreSQL already ships the pseudo-type, so nothing has to be created.
extension_sql!("", name = "fdw_handler", creates = [Type(fdw_handler)]);
/// `GetForeignRelSize` callback.
///
/// Intentionally a no-op: no size estimate is computed, so `baserel->rows` keeps
/// whatever value the planner already assigned. `get_foreign_paths` then reads it.
#[pg_guard]
pub unsafe extern "C" fn get_rel_size(
    _root: *mut pg_sys::PlannerInfo,
    _baserel: *mut pg_sys::RelOptInfo,
    _tableoid: pg_sys::Oid,
) {
    // Nothing to estimate: the whole leaderboard is fetched on every scan.
}
/// `GetForeignPlan` callback: builds a plain `ForeignScan` with no pushdown.
///
/// All restriction clauses are kept as local quals, so the executor evaluates
/// them on every returned row. No private data, expressions or custom target
/// list are passed along.
#[pg_guard]
pub unsafe extern "C" fn get_foreign_plan(
    _root: *mut pg_sys::PlannerInfo,
    baserel: *mut pg_sys::RelOptInfo,
    _tableoid: pg_sys::Oid,
    _best_path: *mut pg_sys::ForeignPath,
    tlist: *mut pg_sys::List,
    scan_clauses: *mut pg_sys::List,
    outer_plan: *mut pg_sys::Plan,
) -> *mut pg_sys::ForeignScan {
    // Strip RestrictInfo wrappers, keeping only non-pseudoconstant clauses.
    let local_quals = pg_sys::extract_actual_clauses(scan_clauses, false);
    let scan_relid = (*baserel).relid;
    let no_list: *mut pg_sys::List = std::ptr::null_mut();
    pg_sys::make_foreignscan(
        tlist,
        local_quals,
        scan_relid,
        no_list, // fdw_exprs
        no_list, // fdw_private
        no_list, // fdw_scan_tlist
        no_list, // fdw_recheck_quals
        outer_plan,
    )
}
#[pg_guard]
pub unsafe extern "C" fn import_schema(
stmt: *mut pg_sys::ImportForeignSchemaStmt,
_oid: pg_sys::Oid,
) -> *mut pg_sys::List {
let mut list = PgList::new();
let server_name = ffi::CStr::from_ptr((*stmt).server_name)
.to_str()
.unwrap_or("");
let local_schema = ffi::CStr::from_ptr((*stmt).local_schema)
.to_str()
.unwrap_or("");
let table_creation = format!(
"
CREATE FOREIGN TABLE {}.aoc_members (\n\
id text not null,\n\
name text not null,\n\
completion_day_level jsonb,\n\
last_star_ts int,\n\
stars int,\n\
global_score int,\n\
local_score int)\n\
SERVER {}",
local_schema, server_name,
);
list.push(ffi::CString::new(table_creation).unwrap().into_raw());
list.into_pg()
}
/// `GetForeignPaths` callback: offers exactly one foreign-scan path.
///
/// The row estimate is taken from `baserel->rows`, and the startup and total
/// costs are fixed guesses (10 and 100). There are no pathkeys and no outer plan.
/// The path is parameterized only by the relation's lateral references.
#[pg_guard]
pub unsafe extern "C" fn get_foreign_paths(
    root: *mut pg_sys::PlannerInfo,
    baserel: *mut pg_sys::RelOptInfo,
    _table_id: pg_sys::Oid,
) {
    let startup_cost = 10.0;
    let total_cost = 100.0;
    let estimated_rows = (*baserel).rows;
    let required_outer = (*baserel).lateral_relids;
    let foreign_path = pg_sys::create_foreignscan_path(
        root,
        baserel,
        ptr::null_mut(), // use the relation's default pathtarget
        estimated_rows,
        startup_cost,
        total_cost,
        ptr::null_mut(), // unordered output: no pathkeys
        required_outer,
        ptr::null_mut(), // no fdw_outerpath
        ptr::null_mut(), // no fdw_private
    );
    pg_sys::add_path(baserel, foreign_path.cast());
}
/// Handler function named by `CREATE FOREIGN DATA WRAPPER ... HANDLER`.
///
/// Allocates an `FdwRoutine` node and wires up the planning, execution and
/// `IMPORT FOREIGN SCHEMA` callbacks implemented in this module.
#[pg_extern]
fn aoc_fdw_handler() -> fdw_handler {
    let mut routine = PgBox::<pg_sys::FdwRoutine>::alloc_node(pg_sys::NodeTag_T_FdwRoutine);

    // Planning.
    routine.GetForeignRelSize = Some(get_rel_size);
    routine.GetForeignPaths = Some(get_foreign_paths);
    routine.GetForeignPlan = Some(get_foreign_plan);

    // Execution.
    routine.BeginForeignScan = Some(begin_foreign_scan);
    routine.IterateForeignScan = Some(iterate_foreign_scan);
    routine.ReScanForeignScan = Some(rescan_foreign_scan);
    routine.EndForeignScan = Some(end_foreign_scan);

    // IMPORT FOREIGN SCHEMA.
    routine.ImportForeignSchema = Some(import_schema);

    fdw_handler { handler: routine }
}
unsafe fn get_server_options(table_id: u32) -> aoc::AocOptions {
let table = pg_sys::GetForeignTable(table_id);
let server = pg_sys::GetForeignServer((*table).serverid);
let server_options: PgList<pg_sys::DefElem> = PgList::from_pg((*server).options);
let options: HashMap<String, *mut pg_sys::DefElem> =
server_options
.iter_ptr()
.fold(HashMap::new(), |mut acc, opt| {
let name = ffi::CStr::from_ptr((*opt).defname)
.to_str()
.unwrap_or_else(|_| error!("Corrupted name of server option name"));
acc.insert(name.to_string(), opt);
acc
});
options
.get("api")
.and_then(|api| {
options.get("cookie").and_then(|cookie| {
let api_str = ffi::CStr::from_ptr(pg_sys::defGetString(*api))
.to_str()
.unwrap_or_else(|_| error!("Couldn't get api option"));
let cookie_str = ffi::CStr::from_ptr(pg_sys::defGetString(*cookie))
.to_str()
.unwrap_or_else(|_| error!("Couldn't get server option"));
Some(aoc::AocOptions {
api: api_str,
cookie: cookie_str,
})
})
})
.unwrap_or_else(|| error!("Couldn't retrieve server options"))
}
/// Per-scan state stored in `ForeignScanState.fdw_state`.
struct FdwState {
    /// Rows not yet returned: a PostgreSQL `List` whose elements point at
    /// palloc'd `AocMember`s. `iterate_foreign_scan` pops one row per call.
    iter: *mut pg_sys::List,
}
#[pg_guard]
pub unsafe extern "C" fn begin_foreign_scan(node: *mut pg_sys::ForeignScanState, eflags: i32) {
debug5!("start scan");
// Do nothin in expain and analyze case
if eflags & (pg_sys::EXEC_FLAG_EXPLAIN_ONLY as i32) != 0 {
return;
}
let scan_state = (*node).ss;
let table_id = (*scan_state.ss_currentRelation).rd_id;
let aoc_options = get_server_options(table_id);
let aoc_result = aoc::get_aoc_leader_board(&aoc_options);
let aoc = match aoc_result {
Ok(result) => result,
Err(err) => error!("Error during http request: {}", err),
};
let mut state = PgBox::<FdwState>::alloc0();
let mut list = PgList::new();
for value in aoc.members.values() {
let mut t = PgBox::<AocMember>::alloc0();
t.id = value.id.clone();
t.completion_day_level = value.completion_day_level.clone();
t.name = value.name.clone();
t.last_star_ts = value.last_star_ts.clone();
t.stars = value.stars;
t.global_score = value.global_score;
t.local_score = value.local_score;
list.push(t.into_pg())
}
state.iter = list.into_pg();
save_state(node, state);
}
/// Rust counterpart of PostgreSQL's inline `ExecClearTuple`: calls the slot's
/// `clear` callback when one is set, leaving the slot empty.
unsafe fn exec_clear_tuple(slot: *mut pg_sys::TupleTableSlot) {
    if let Some(clear_slot) = (*(*slot).tts_ops).clear {
        clear_slot(slot);
    }
}
/// Wraps the node's `fdw_state` pointer in a non-owning `PgBox<FdwState>`.
unsafe fn get_state(node: *mut pg_sys::ForeignScanState) -> PgBox<FdwState> {
    let raw_state = (*node).fdw_state.cast::<FdwState>();
    PgBox::from_pg(raw_state)
}
/// Hands `state` over to the scan node's `fdw_state`; `end_foreign_scan` frees it.
unsafe fn save_state(node: *mut pg_sys::ForeignScanState, state: PgBox<FdwState, AllocatedByRust>) {
    let raw_state: *mut FdwState = state.into_pg();
    (*node).fdw_state = raw_state.cast();
}
#[pg_guard]
pub unsafe extern "C" fn iterate_foreign_scan(
node: *mut pg_sys::ForeignScanState,
) -> *mut pg_sys::TupleTableSlot {
let slot = (*node).ss.ss_ScanTupleSlot;
exec_clear_tuple(slot);
let mut state: PgBox<FdwState> = get_state(node);
let mut list: PgList<AocMember> = PgList::from_pg(state.iter);
if list.len() <= 0 {
return slot;
}
let row = list.pop();
let aoc_row: PgBox<AocMember> = match row {
None => return slot,
Some(row) => PgBox::from_pg(row),
};
let rel = PgRelation::from_pg((*node).ss.ss_currentRelation);
let tuple_desc = rel.tuple_desc();
let mut is_null = vec![false; tuple_desc.len()];
let mut data = vec![0; tuple_desc.len()];
let mut i = 0;
for attr in tuple_desc.iter() {
let val = (match name_data_to_str(&attr.attname) {
"id" => aoc_row.id.clone().into_datum(),
"name" => aoc_row.name.clone().into_datum(),
"completion_day_level" => JsonB(aoc_row.completion_day_level.clone()).into_datum(),
"last_star_ts" => aoc_row.last_star_ts.clone().into_datum(),
"stars" => aoc_row.stars.into_datum(),
"global_score" => aoc_row.global_score.into_datum(),
"local_score" => aoc_row.local_score.into_datum(),
n => error!("Unexpected column {}", n),
})
.unwrap_or_else(|| error!("Unable convert into datum"));
data[i] = val;
i = i + 1;
}
let htuple =
pg_sys::heap_form_tuple(tuple_desc.as_ptr(), data.as_mut_ptr(), is_null.as_mut_ptr());
pg_sys::ExecStoreHeapTuple(htuple, slot, false);
state.iter = list.into_pg();
return slot;
}
/// `EndForeignScan` callback: frees the per-scan `FdwState`, if any.
///
/// `begin_foreign_scan` returns before saving any state for EXPLAIN-only runs,
/// but the executor still calls this callback. `fdw_state` can therefore be
/// NULL here, and `pfree(NULL)` is not allowed in PostgreSQL.
#[pg_guard]
pub unsafe extern "C" fn end_foreign_scan(node: *mut pg_sys::ForeignScanState) {
    debug5!("end scan");
    if (*node).fdw_state.is_null() {
        return;
    }
    pg_sys::pfree((*node).fdw_state);
    (*node).fdw_state = ptr::null_mut();
}
#[pg_guard]
pub unsafe extern "C" fn rescan_foreign_scan(node: *mut pg_sys::ForeignScanState) {
debug5!("re scan");
begin_foreign_scan(node, 0);
}
#[pg_extern]
unsafe fn aoc_fdw_validator(args: Vec<Option<&str>>, oid: pg_sys::Oid) {
if oid != FOREIGN_SERVER_OID && args.len() > 0 {
panic!("Unexpected arguments")
}
let named_params = args
.clone()
.into_iter()
.map(|v| match v {
Some(x) => {
let mut t = x.split("=");
(
t.next().expect("Missing name"),
t.next().expect("Missing value"),
)
}
None => ("", ""),
})
.collect::<Vec<(&str, &str)>>();
for (name, _) in named_params {
if !SERVER_OPTIONS.contains(&name) {
error!("Unexpected argument {} for AoC table", name);
}
}
}
// Creates the wrapper itself. `requires` makes pgx emit this SQL only after
// `aoc_fdw_handler` (which returns the FdwRoutine callbacks) and
// `aoc_fdw_validator` (which checks OPTIONS lists) have been created.
extension_sql!(
"\n\
CREATE FOREIGN DATA WRAPPER aoc_fdw\n\
HANDLER aoc_fdw_handler
VALIDATOR aoc_fdw_validator; \n\
",
name = "create_fdw",
requires = [aoc_fdw_handler, aoc_fdw_validator]
);
#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
mod tests {
    use pgx::*;
    use rand::distributions::Alphanumeric;
    use rand::{thread_rng, Rng};

    /// End-to-end check: create a server, import the schema and count rows.
    ///
    /// Requires network access to adventofcode.com and a session cookie that
    /// grants access to the private leaderboard. The cookie is read from the
    /// `AOC_COOKIE` environment variable of the backend process, for example
    /// `session=...`. The expected member count is tied to that leaderboard.
    #[pg_test]
    fn test_select() {
        // A random alphabetic server name keeps repeated runs from colliding.
        let server: String = thread_rng()
            .sample_iter(&Alphanumeric)
            .map(char::from)
            .filter(|t| t.is_alphabetic())
            .take(7)
            .collect();
        // get_server_options() needs both `api` and `cookie`. Without the cookie,
        // every scan fails with "Couldn't retrieve server options".
        // Single quotes are doubled so the value is a valid SQL string literal.
        let cookie = std::env::var("AOC_COOKIE")
            .unwrap_or_default()
            .replace('\'', "''");
        pgx::Spi::run(format!(
            r#"create server {} foreign data wrapper aoc_fdw OPTIONS (api 'https://adventofcode.com/2020/leaderboard/private/view/657296.json', cookie '{}')"#,
            server, cookie
        ).as_str());
        pgx::Spi::run(r#"CREATE SCHEMA mytest;"#);
        pgx::Spi::run(
            format!(
                r#"IMPORT FOREIGN SCHEMA aoc FROM SERVER {} INTO mytest"#,
                server
            )
            .as_str(),
        );
        // count(*) yields bigint, so read it as i64 rather than i32.
        let count: Option<i64> = pgx::Spi::get_one(r#"SELECT count(*) FROM mytest.aoc_members"#);
        match count {
            None => error!("No result from query"),
            Some(count) => assert_eq!(count, 30),
        };
    }
}
#[cfg(test)]
pub mod pg_test {
    /// One-time setup hook for the test harness; nothing needs preparing.
    pub fn setup(_options: Vec<&str>) {
        // Intentionally empty.
    }

    /// Extra `postgresql.conf` settings the test cluster needs; currently none.
    pub fn postgresql_conf_options() -> Vec<&'static str> {
        Vec::new()
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment