Skip to content

Instantly share code, notes, and snippets.

@WolfEYc
Created March 12, 2024 17:36
Show Gist options
  • Save WolfEYc/f6ce1e947372742bf1108290447e5b0a to your computer and use it in GitHub Desktop.
Save WolfEYc/f6ce1e947372742bf1108290447e5b0a to your computer and use it in GitHub Desktop.
Polars Oracle Helpers
import asyncio
import re
from contextlib import asynccontextmanager
from functools import reduce
from itertools import chain
from operator import itemgetter
from typing import Any, Coroutine, Iterable, NamedTuple, Optional
import oracledb
import polars as pl
from mirouteapi.connections.oracle import ENV, MRTE_ORACLE_PWD
from mirouteapi.logger import LOGGER
class ListReplacements(NamedTuple):
    """Pair returned by replace_lists_and_query: rewritten SQL plus its binds."""

    # SQL text after every list-valued bind was rewritten by replace_query_lists.
    query: str
    # Bind values; list values have been converted with
    # python_list_to_oracle_array, all other values are passed through as-is.
    new_kwargs: dict[str, Any]
def oracle_arraytype(lst: list):
    """Choose the entry of ORACLE_TYPES used to bind *lst* as a collection.

    Only the first element is inspected: ``int`` selects "number", ``bytes``
    selects "bytes". Empty lists, strings, and every other element type use
    the "string" entry.
    """
    type_key = "string"
    if len(lst) != 0 and not isinstance(lst[0], str):
        head = lst[0]
        if isinstance(head, int):
            type_key = "number"
        elif isinstance(head, bytes):
            type_key = "bytes"
    return ORACLE_TYPES[type_key]
def python_list_to_oracle_array(lst: list):
    """Wrap *lst* in a new object of the type chosen by oracle_arraytype."""
    return oracle_arraytype(lst).newobject(lst)
def replace_query_list(query: str, list_key: str) -> str:
    """Rewrite each ``:list_key`` bind in *query* as a TABLE() sub-select.

    Every occurrence of the bind placeholder ``:list_key`` becomes
    ``(SELECT * FROM TABLE(:list_key))``.

    Only whole bind names are rewritten: a placeholder that merely starts
    with *list_key* (for example ``:ids_total`` when *list_key* is ``ids``)
    is left untouched. A plain substring replace would corrupt such binds.

    Args:
        query: SQL text containing ``:name`` style bind placeholders.
        list_key: Name of the bind (without the leading colon) to rewrite.

    Returns:
        The rewritten SQL text; unchanged if the bind does not occur.
    """
    # The negative lookahead rejects any character that could continue an
    # identifier (word characters plus ``$`` and ``#``).
    bind_pattern = re.compile(rf":{re.escape(list_key)}(?![\w$#])")
    replacement = f"(SELECT * FROM TABLE(:{list_key}))"
    # A callable replacement keeps ``re`` from interpreting backslashes or
    # group references that might appear in the key.
    return bind_pattern.sub(lambda _match: replacement, query)
def replace_query_lists(query: str, list_keys: Iterable[str]) -> str:
    """Apply replace_query_list to *query* once per key, in iteration order."""
    rewritten = query
    for key in list_keys:
        rewritten = replace_query_list(rewritten, key)
    return rewritten
def replace_lists_and_query(query: str, kwargs: dict[str, Any]) -> ListReplacements:
    """Rewrite *query* and its binds so list values can be bound.

    Each list-valued entry of *kwargs* has its placeholder rewritten by
    replace_query_lists and its value converted by
    python_list_to_oracle_array. *kwargs* itself is not modified.

    Returns:
        A ListReplacements holding the rewritten query and a new dict with
        the non-list binds first, followed by the converted list binds.
    """
    list_items = [(name, value) for name, value in kwargs.items() if isinstance(value, list)]

    # Rewrite the SQL first, then convert the values.
    rewritten_query = replace_query_lists(query, [name for name, _ in list_items])

    new_kwargs = {
        name: value for name, value in kwargs.items() if not isinstance(value, list)
    }
    for name, value in list_items:
        new_kwargs[name] = python_list_to_oracle_array(value)

    return ListReplacements(rewritten_query, new_kwargs)
def replace_lists(kwargs: dict[str, Any]):
    """Convert every list value of *kwargs* in place and return the same dict.

    List values are replaced by the result of python_list_to_oracle_array;
    all other entries are left as they are.
    """
    # Only values of existing keys are reassigned, so the dict keeps its
    # size while it is being iterated.
    for name, value in kwargs.items():
        if isinstance(value, list):
            kwargs[name] = python_list_to_oracle_array(value)
    return kwargs
# Matches a ``LIMIT <n>`` clause in any letter case. The leading ``\b`` stops
# the pattern from firing inside longer words such as ``rate_limit 5`` or
# ``nolimit 5``.
LIMIT_REGEX = re.compile(r"\bLIMIT\s+(\d+)", re.IGNORECASE)


def limit_replacement(match: re.Match):
    """Return the ``FETCH NEXT n ROWS ONLY`` text for one LIMIT_REGEX match."""
    return f"FETCH NEXT {match.group(1)} ROWS ONLY"


def replace_limit_sql(query: str) -> str:
    """Rewrite every ``LIMIT n`` clause in *query* as ``FETCH NEXT n ROWS ONLY``.

    Args:
        query: SQL text that may contain ``LIMIT <integer>`` clauses.

    Returns:
        The SQL text with each such clause replaced; unchanged if none occur.
    """
    return LIMIT_REGEX.sub(limit_replacement, query)
async def cursor_to_df(cursor: oracledb.AsyncCursor, to_lower: bool) -> pl.DataFrame:
    """Drain *cursor* into a polars DataFrame and close the cursor.

    Args:
        cursor: An open cursor whose result set is ready to be fetched.
        to_lower: When true, column names are lower-cased.

    Returns:
        A DataFrame built row-wise from all fetched rows, with columns named
        after the first item of each ``cursor.description`` entry.
    """
    try:
        columns = [col[0] for col in cursor.description]
        data = await cursor.fetchall()
    finally:
        # Close on failure as well, so a failed fetch does not leave the
        # cursor open.
        cursor.close()
    if to_lower:
        columns = [name.lower() for name in columns]
    return pl.DataFrame(data, schema=columns, orient="row")
async def fetch(
    conn: oracledb.AsyncConnection,
    query: str,
    *,
    schema_overrides: Optional[dict] = None,
    to_lower: bool = True,
    **kwargs,
) -> pl.DataFrame:
    """Run *query* on *conn* and return the full result as a polars DataFrame.

    Before execution the query is passed through replace_lists_and_query
    (list-valued binds) and replace_limit_sql (``LIMIT n`` clauses).

    Args:
        conn: Connection used to open the cursor.
        query: SQL text with ``:name`` bind placeholders.
        schema_overrides: Forwarded to ``pl.DataFrame``.
        to_lower: When true, column names are lower-cased.
        **kwargs: Bind values keyed by placeholder name.
    """
    replaced = replace_lists_and_query(query, kwargs)
    query = replace_limit_sql(replaced.query)
    kwargs = replaced.new_kwargs

    with conn.cursor() as cursor:
        LOGGER.debug(f"Querying Oracle with:\n{query}\nwith kwargs:\n{kwargs}")
        await cursor.execute(query, **kwargs)
        data = await cursor.fetchall()

        columns = [col[0] for col in cursor.description]
        if to_lower:
            columns = [name.lower() for name in columns]

        # Scan every fetched row when inferring column dtypes.
        return pl.DataFrame(
            data,
            schema=columns,
            orient="row",
            schema_overrides=schema_overrides,
            infer_schema_length=len(data),
        )
async def fetch_proc(
    conn: oracledb.AsyncConnection,
    proc: str,
    *,
    out_keys: dict[str, Any],
    to_lower: bool = True,
    **kwargs,
) -> dict[str, Any]:
    """Call stored procedure *proc* and collect its out parameters.

    Args:
        conn: Connection used to open the cursor.
        proc: Name of the procedure passed to ``cursor.callproc``.
        out_keys: Maps each out-parameter name to the argument given to
            ``cursor.var`` when creating its bind variable.
        to_lower: Forwarded to cursor_to_df for cursor-valued outputs.
        **kwargs: In-parameters; list values are converted by replace_lists.

    Returns:
        A dict keyed like *out_keys*. Cursor-valued outputs are replaced by
        the DataFrame produced by cursor_to_df; other values are returned
        as obtained from ``getvalue``.
    """
    kwargs = replace_lists(kwargs)
    with conn.cursor() as cursor:
        out_vars = [cursor.var(out_type) for out_type in out_keys.values()]
        kwargs.update(zip(out_keys.keys(), out_vars))

        LOGGER.debug(f"Calling Oracle stored proc:\n{proc}\nwith kwargs:\n{kwargs}")
        await cursor.callproc(proc, keyword_parameters=kwargs)

        out_results: dict[str, Any] = {}
        for name, var in zip(out_keys.keys(), out_vars):
            value = var.getvalue()
            if isinstance(value, oracledb.AsyncCursor):
                # Creates the coroutine only; it is awaited below.
                value = cursor_to_df(value, to_lower)
            out_results[name] = value

        pending = {
            name: value
            for name, value in out_results.items()
            if isinstance(value, Coroutine)
        }
        frames = await asyncio.gather(*pending.values())  # type: ignore
        out_results.update(zip(pending.keys(), frames))
        return out_results
def get_oracle_types(nice_names_to_oracle_type_names: dict[str, str]):
    """Resolve Oracle type names to type objects over a short-lived connection.

    Args:
        nice_names_to_oracle_type_names: Maps a local alias to the Oracle
            type name handed to ``con.gettype``.

    Returns:
        A dict with the same keys, each mapped to the ``con.gettype`` result.
    """
    with oracledb.connect(**ENV, password=MRTE_ORACLE_PWD) as con:
        # Lookups must happen while the connection is still open.
        return {
            nice_name: con.gettype(type_name)
            for nice_name, type_name in nice_names_to_oracle_type_names.items()
        }
# Collection types used by oracle_arraytype when binding Python lists.
# NOTE: this runs at import time, and get_oracle_types opens a database
# connection via oracledb.connect, so importing this module needs a
# reachable database.
ORACLE_TYPES = get_oracle_types(
    {
        "number": "SYS.ODCINUMBERLIST",
        "string": "SYS.ODCIVARCHAR2LIST",
        "bytes": "SYS.ODCIRAWLIST",
    }
)
def gen_set_sql(col: str, include_nulls: bool = False):
    """Build the ``col = ...`` fragment binding column *col* to ``:col``.

    With *include_nulls* the bind is assigned directly. Otherwise it is
    wrapped in ``COALESCE`` so a NULL bind falls back to the column itself.
    """
    if include_nulls:
        return f"{col} = :{col}"
    return f"{col} = COALESCE(:{col}, {col})"
async def update_many(
    conn: oracledb.AsyncConnection,
    df: pl.DataFrame,
    table: str,
    pkey_cols: set[str],
    include_nulls: bool = False,
):
    """Issue one batched UPDATE against *table*, one execution per row of *df*.

    Columns named in *pkey_cols* form the WHERE clause; every other column
    of *df* goes into the SET clause via gen_set_sql.

    Args:
        conn: Connection on which ``executemany`` is awaited.
        df: Rows to apply; column names double as bind names.
        table: Target table name, interpolated into the SQL text as-is.
        pkey_cols: Columns identifying the row to update.
        include_nulls: Forwarded to gen_set_sql for the SET columns.

    Returns:
        *conn*, unchanged. This function does not commit.
    """
    value_cols = set(df.columns).difference(pkey_cols)
    set_sql = ", ".join(gen_set_sql(col, include_nulls) for col in value_cols)
    # Key columns are always compared directly against their bind.
    where_sql = " AND ".join(gen_set_sql(col, True) for col in pkey_cols)

    update_sql = f"--sql\nUPDATE {table}\nSET {set_sql}\nWHERE {where_sql}\n"

    rows = df.to_dicts()
    LOGGER.debug(f"Updating Oracle with:\n{update_sql}\nwith df:\n{df}")
    await conn.executemany(update_sql, rows)
    return conn
async def insert_many(
    conn: oracledb.AsyncConnection,
    df: pl.DataFrame,
    table: str,
):
    """Issue one batched INSERT into *table*, one execution per row of *df*.

    Args:
        conn: Connection on which ``executemany`` is awaited.
        df: Rows to insert; column names are used both as the column list
            and as bind names.
        table: Target table name, interpolated into the SQL text as-is.

    Returns:
        *conn*, unchanged. This function does not commit.
    """
    column_names = df.columns
    columns_sql = ", ".join(column_names)
    values_sql = ", ".join(f":{col}" for col in column_names)

    insert_sql = f"--sql\nINSERT INTO {table} ({columns_sql})\nVALUES ({values_sql})\n"

    LOGGER.debug(f"Inserting into Oracle with:\n{insert_sql}\nwith df:\n{df}")
    rows = df.to_dicts()
    await conn.executemany(insert_sql, rows)
    return conn
class PoolWrapper:
    """Convenience wrapper around an async Oracle connection pool.

    ``pool`` is only assigned by :meth:`init`, so ``init`` must be awaited
    before any other method is used. Each helper acquires a connection for
    the duration of one operation and delegates to the module-level function
    of the same name.
    """

    pool: oracledb.AsyncConnectionPool

    async def init(self):
        """Create the pool from ``ENV`` and ``MRTE_ORACLE_PWD``."""
        self.pool = oracledb.create_pool_async(password=MRTE_ORACLE_PWD, **ENV)

    @asynccontextmanager
    async def acquire(self):
        """Yield a pooled connection, released when the context exits."""
        async with self.pool.acquire() as conn:
            yield conn

    async def close(self):
        """Close the pool, passing ``force=True``."""
        await self.pool.close(force=True)

    async def fetch(
        self,
        query: str,
        *,
        schema_overrides: Optional[dict] = None,
        to_lower: bool = True,
        **kwargs,
    ) -> pl.DataFrame:
        """Run the module-level ``fetch`` on a freshly acquired connection."""
        async with self.acquire() as conn:
            return await fetch(
                conn,
                query,
                schema_overrides=schema_overrides,
                to_lower=to_lower,
                **kwargs,
            )

    async def fetch_proc(
        self,
        proc: str,
        *,
        out_keys: dict[str, Any],
        to_lower: bool = True,
        **kwargs,
    ) -> dict[str, Any]:
        """Run the module-level ``fetch_proc`` on a freshly acquired connection."""
        async with self.acquire() as conn:
            return await fetch_proc(
                conn, proc, out_keys=out_keys, to_lower=to_lower, **kwargs
            )

    @asynccontextmanager
    async def update_many(
        self,
        df: pl.DataFrame,
        table: str,
        pkey_cols: set[str],
        include_nulls: bool = False,
    ):
        """Run ``update_many`` and yield a ConnWrapper on the same connection.

        The caller can do further work in the same transaction inside the
        ``async with`` body. On a clean exit the transaction is committed if
        one is still in progress. If the body raises, this method does not
        commit.
        """
        async with self.acquire() as conn:
            await update_many(conn, df, table, pkey_cols, include_nulls)
            yield ConnWrapper(conn)
            if conn.transaction_in_progress:
                await conn.commit()

    async def update_many_autocommit(
        self,
        df: pl.DataFrame,
        table: str,
        pkey_cols: set[str],
        include_nulls: bool = False,
    ):
        """Run ``update_many`` with autocommit enabled on the connection."""
        async with self.acquire() as conn:
            conn.autocommit = True
            try:
                await update_many(conn, df, table, pkey_cols, include_nulls)
            finally:
                # Reset on failure too, so the connection does not go back
                # to the pool with autocommit still on.
                conn.autocommit = False

    @asynccontextmanager
    async def insert_many(self, df: pl.DataFrame, table: str):
        """Run ``insert_many`` and yield a ConnWrapper on the same connection.

        On a clean exit the transaction is committed if one is still in
        progress. If the body raises, this method does not commit.
        """
        async with self.acquire() as conn:
            await insert_many(conn, df, table)
            yield ConnWrapper(conn)
            if conn.transaction_in_progress:
                await conn.commit()

    async def insert_many_autocommit(self, df: pl.DataFrame, table: str):
        """Run ``insert_many`` with autocommit enabled on the connection."""
        async with self.acquire() as conn:
            conn.autocommit = True
            try:
                await insert_many(conn, df, table)
            finally:
                # Reset on failure too, so the connection does not go back
                # to the pool with autocommit still on.
                conn.autocommit = False
class ConnWrapper:
    """Binds one connection to the module-level query helpers.

    The write and transaction methods return ``self`` so that calls can be
    chained on the same connection.
    """

    conn: oracledb.AsyncConnection

    def __init__(self, conn: oracledb.AsyncConnection):
        self.conn = conn

    async def fetch(
        self,
        query: str,
        *,
        schema_overrides: Optional[dict] = None,
        to_lower: bool = True,
        **kwargs,
    ) -> pl.DataFrame:
        """Delegate to the module-level ``fetch`` using the wrapped connection."""
        frame = await fetch(
            self.conn,
            query,
            schema_overrides=schema_overrides,
            to_lower=to_lower,
            **kwargs,
        )
        return frame

    async def fetch_proc(
        self,
        proc: str,
        *,
        out_keys: dict[str, Any],
        to_lower: bool = True,
        **kwargs,
    ) -> dict[str, Any]:
        """Delegate to the module-level ``fetch_proc`` using the wrapped connection."""
        outputs = await fetch_proc(
            self.conn,
            proc,
            to_lower=to_lower,
            out_keys=out_keys,
            **kwargs,
        )
        return outputs

    async def update_many(
        self,
        df: pl.DataFrame,
        table: str,
        pkey_cols: set[str],
        include_nulls: bool = False,
    ):
        """Run the module-level ``update_many`` on the wrapped connection."""
        await update_many(self.conn, df, table, pkey_cols, include_nulls)
        return self

    async def insert_many(self, df: pl.DataFrame, table: str):
        """Run the module-level ``insert_many`` on the wrapped connection."""
        await insert_many(self.conn, df, table)
        return self

    async def commit(self):
        """Commit on the wrapped connection and return ``self``."""
        await self.conn.commit()
        return self

    async def rollback(self):
        """Roll back on the wrapped connection and return ``self``."""
        await self.conn.rollback()
        return self
# Module-level shared PoolWrapper. Its `pool` attribute is only assigned by
# `PoolWrapper.init`, so `await ORACLE.init()` must run before first use.
ORACLE = PoolWrapper()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment