Skip to content

Instantly share code, notes, and snippets.

@irogers
Last active May 3, 2023 11:27
Show Gist options
  • Save irogers/f529bba5658b1aa31c9cbef3d8a5d563 to your computer and use it in GitHub Desktop.
Pandas Helpers
from functools import wraps
from pathlib import Path
def cache_query_results(fp: Path, rerun_default: bool = False):
    """Decorator that caches a DataFrame-returning function's result as a pickle.

    On the first call (or whenever the cache file is missing) the wrapped
    function runs and its result is pickled to ``fp``; subsequent calls load
    the pickle instead of re-running the query. Pass ``rerun=True`` at call
    time to force a refresh.

    Example usage:
        @cache_query_results(fp=Path("path_to_cache/data.pkl"))
        def query() -> pd.DataFrame:
            return pd.read_sql(sql, engine)

        query(rerun=True)

    Args:
        fp (Path): Location of the pickle cache file (a plain str also works).
        rerun_default (bool, optional): Default for the per-call ``rerun``
            keyword. Defaults to False.
    """
    fp = Path(fp)  # generalization: accept str as well as Path

    def decorator(function):
        @wraps(function)
        def wrapper(*args, rerun: bool = rerun_default, **kwargs) -> pd.DataFrame:
            # Recompute when explicitly requested or when no cache exists yet.
            # (Original tested `rerun is True`, which ignored truthy values.)
            if rerun or not fp.exists():
                df = function(*args, **kwargs)
                df.to_pickle(fp)
            else:
                df = pd.read_pickle(fp)
            return df

        return wrapper

    return decorator
from typing import List
import pandas as pd
def missing_values(df_input: pd.DataFrame) -> pd.DataFrame:
    """Summarize the columns of ``df_input`` that contain missing values.

    Args:
        df_input (pd.DataFrame): Frame to inspect.

    Returns:
        pd.DataFrame: Indexed by column name, with ``num_missing_values``
        (count of nulls) and ``perc`` (percentage of rows, rounded to one
        decimal), sorted descending; columns with no missing values are
        dropped from the result.
    """
    summary = (
        df_input.isnull()
        .sum()
        .sort_values(ascending=False)
        .to_frame(name="num_missing_values")
        # BUG FIX: original divided by len(df) — an undefined name; the
        # percentage must be relative to the input frame's row count.
        .assign(perc=lambda x: x["num_missing_values"] / len(df_input) * 100)
        .round(1)
    )
    return summary[summary["num_missing_values"] > 0]
def contingency_table(xs: List[str], target: str, data: pd.DataFrame) -> pd.DataFrame:
    """Row-normalized cross-tabulation of ``xs`` against ``target``.

    Each row shows the proportion of every ``target`` value (rounded to two
    decimals) plus a ``Support`` column with the raw row count; an ``All``
    margin row is included.

    Args:
        xs (List[str]): Column names forming the (possibly multi-level) row index.
        target (str): Column whose values become the output columns.
        data (pd.DataFrame): Source frame containing all referenced columns.

    Returns:
        pd.DataFrame: Per-row proportions with a trailing ``Support`` column.
        (BUG FIX: the original annotated the return as bare ``DataFrame``,
        a NameError when the def statement is evaluated.)
    """
    index = [data[col] for col in xs]
    y = data[target]
    crosstab = pd.crosstab(index, y, margins=True)
    # The last margin column holds per-row totals; keep it as "Support".
    totals = crosstab.iloc[:, -1].copy().rename("Support")
    values = crosstab.iloc[:, :-1]
    table = (
        values.div(values.sum(axis=1), axis=0)
        .round(2)
        .merge(totals, left_index=True, right_index=True)
        # Blank out the index level names that pd.crosstab leaves behind.
        .rename_axis(len(xs) * [None], axis=0)
    )
    return table
def compose_docstrings(df: pd.DataFrame) -> None:
    """Print a docstring 'Columns:' section stub for every column of ``df``.

    Each line lists a column name and dtype with a placeholder description,
    ready to paste into a function docstring.
    """
    entries = [
        f" {col} ({dtype}): _description_\n" for col, dtype in df.dtypes.items()
    ]
    print(" Columns:\n" + "".join(entries))
def validate_columns(f):
    """Decorator checking that the frame returned by ``f`` has exactly the
    columns declared in the 'Columns:' section of ``f``'s docstring.

    Raises:
        ValueError: if the returned frame is missing declared columns, or
            carries columns the docstring does not declare.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        result: pd.DataFrame = f(*args, **kwargs)
        expected = get_cols(f)
        check_missing_columns(result, expected)
        check_extra_columns(result, expected)
        return result

    return wrapper
def check_missing_columns(df: pd.DataFrame, column_names: Iterable[str]):
    """Raise ValueError if any name in ``column_names`` is absent from ``df``."""
    missing = set(column_names).difference(df.columns.values)
    if missing:
        raise ValueError(f"Missing columns: {missing}")
def check_extra_columns(df: pd.DataFrame, column_names: Iterable[str]):
    """Raise ValueError if ``df`` carries any column not listed in ``column_names``."""
    extras = set(df.columns.values).difference(column_names)
    if extras:
        raise ValueError(f"Extra columns: {extras}")
def get_cols(func):
    """Extract column names from the 'Columns:' section of ``func``'s docstring.

    Takes every non-blank line after the literal line 'Columns:' and keeps
    the text before the first ':'. NOTE(review): a line formatted as
    'name (dtype): desc' would yield 'name (dtype)' — assumes docstrings
    list bare column names; confirm against the decorated functions.
    """
    stripped = [line.strip() for line in func.__doc__.splitlines() if line.strip()]
    column_lines = stripped[stripped.index("Columns:") + 1 :]
    return [line.partition(":")[0] for line in column_lines]
def parse_df(df_input: pd.DataFrame, col_defs: dict) -> pd.DataFrame:
    """Select, rename, and type-cast columns according to ``col_defs``.

    ``col_defs`` maps a source column name to ``(new_name, dtype)``, or to
    ``(new_name, np.datetime64, strptime_format)`` for datetime columns.
    (Requires numpy imported as ``np``; the original used ``np`` without any
    import in the file.)

    Args:
        df_input (pd.DataFrame): Raw frame containing the keys of ``col_defs``.
        col_defs (dict): Column specification as described above.

    Returns:
        pd.DataFrame: Columns renamed, ordered as in ``col_defs``, cast to
        the requested dtypes.
    """
    col_names = {k: v[0] for k, v in col_defs.items()}
    # Datetime columns are parsed separately via pd.to_datetime so a
    # strptime format string can be applied.
    col_dtypes = {v[0]: v[1] for v in col_defs.values() if v[1] != np.datetime64}
    col_datetimes = {v[0]: v[2] for v in col_defs.values() if v[1] == np.datetime64}
    df = (
        df_input[col_names.keys()]
        .rename(columns=col_names)[col_names.values()]
        .astype(col_dtypes)
    )
    for col, date_format in col_datetimes.items():
        if not is_datetime64_any_dtype(df[col]):
            df[col] = pd.to_datetime(df[col], format=date_format)
    return df
from typing import Dict, Iterable

import numpy as np
import pandas as pd
from pandas.api.types import is_datetime64_any_dtype
def parse_df(df_input: pd.DataFrame, col_defs: dict) -> pd.DataFrame:
    """Select, rename, and type-cast columns according to ``col_defs``.

    ``col_defs`` maps a source column name to ``(new_name, dtype)``, or to
    ``(new_name, np.datetime64, strptime_format)`` for datetime columns.
    (Requires numpy imported as ``np``; the original used ``np`` with no
    import anywhere in the file.)

    Example usage:
        COLS = {
            "Date": ("date", np.datetime64, "%Y-%m-%d"),
            "IntCol": ("int_col", int),
        }
        df_raw.pipe(parse_df, COLS)

    Args:
        df_input (pd.DataFrame): Raw frame containing the keys of ``col_defs``.
        col_defs (dict): Column specification as described above.

    Returns:
        pd.DataFrame: Columns renamed, ordered as in ``col_defs``, cast to
        the requested dtypes.
    """
    col_names = {k: v[0] for k, v in col_defs.items()}
    # Datetime columns are parsed separately via pd.to_datetime so a
    # strptime format string can be applied.
    col_dtypes = {v[0]: v[1] for v in col_defs.values() if v[1] != np.datetime64}
    col_datetimes = {v[0]: v[2] for v in col_defs.values() if v[1] == np.datetime64}
    df = (
        df_input[col_names.keys()]
        .rename(columns=col_names)[col_names.values()]
        .astype(col_dtypes)
    )
    for col, date_format in col_datetimes.items():
        if not is_datetime64_any_dtype(df[col]):
            df[col] = pd.to_datetime(df[col], format=date_format)
    return df
def filter_df(df_input: pd.DataFrame, filters: Iterable[Dict]) -> pd.DataFrame:
    """Drop rows of ``df_input`` matching any of the given filters (anti-join).

    Each filter is a dict of column -> value; a row is removed when it
    matches every key/value pair of at least one filter.

    Example usage:
        df_origin_filtered = df_origin.pipe(
            filter_df, filters=[{"medium": "blog", "source": "Qualiabio"}]
        )

    Args:
        df_input (pd.DataFrame): Frame to filter.
        filters (Iterable[Dict]): Iterable of {column: value} dicts.

    Returns:
        pd.DataFrame: The non-matching rows, with a fresh RangeIndex.
    """
    df_filter = pd.DataFrame(filters).drop_duplicates()
    # Left-merge with indicator=True marks rows present only in df_input as
    # "left_only". BUG FIX: the original tagged matches through a temporary
    # "index" column, which collided with any real column named "index".
    merged = df_input.merge(
        df_filter, on=list(df_filter.columns), how="left", indicator=True
    )
    kept = merged.loc[merged["_merge"] == "left_only"].drop(columns="_merge")
    return kept.reset_index(drop=True)
def remove_outliers(df_input: pd.DataFrame, column: str, k: float = 1.5):
    """Drop rows whose ``column`` value lies outside the Tukey fences.

    Keeps rows where Q1 - k*IQR <= value <= Q3 + k*IQR (inclusive).

    Args:
        df_input (pd.DataFrame): Source frame; not modified.
        column (str): Numeric column to test.
        k (float, optional): Fence multiplier. Defaults to 1.5 (the
            standard Tukey rule).

    Returns:
        pd.DataFrame: Surviving rows with a fresh RangeIndex.
    """
    q1 = df_input[column].quantile(0.25)
    q3 = df_input[column].quantile(0.75)
    iqr = q3 - q1
    # Boolean mask instead of the original df.query(f"... {column} ...") so
    # column names that are not valid Python identifiers (spaces, dots)
    # still work.
    mask = df_input[column].between(q1 - k * iqr, q3 + k * iqr)
    # BUG FIX: original reset_index() leaked the old index as a stray
    # "index" column; drop it, consistent with filter_df in this file.
    return df_input.loc[mask].reset_index(drop=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment