Last active
September 12, 2023 17:21
-
-
Save benjameep/cbf6ffda9cf5c7729fb2643734edc164 to your computer and use it in GitHub Desktop.
pandas functions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import UserDict | |
import re | |
import warnings | |
def to_snake_case(s):
    """Convert a CamelCase/mixedCase name to snake_case."""
    # Insert "_" before every uppercase letter that is not at position 0,
    # then lowercase the whole string.
    boundary = re.compile(r'(?<!^)(?=[A-Z])')
    return boundary.sub('_', s).lower()
class NormalizedItem(UserDict):
    """Dict that normalizes keys to snake_case and flattens name-dict values.

    Keys starting with 'WEF_' are silently dropped. Dict values are collapsed
    to their 'uniqueName' (preferred) or 'name' entry.

    NOTE(review): field_names is a CLASS-level dict, so the field->alias
    mapping (and name-conflict warnings) are shared across all instances —
    confirm this global accumulation is intentional.
    """
    field_names = {}

    def normalize_value(self, val):
        """Collapse a dict value to its 'uniqueName'/'name'; pass others through."""
        if type(val) is dict:
            if 'uniqueName' in val:
                return val['uniqueName']
            if 'name' in val:
                return val['name']
            else:
                print(val)
                raise Exception('Unknown value type')
        return val

    def normalize_field_name(self, field: str):
        """Return the snake_case alias for a field, or None for WEF_* fields.

        Warns when two different raw fields collapse to the same alias.
        """
        if field.startswith('WEF_'):
            return
        if field not in self.field_names:
            alias = to_snake_case(field.split('.')[-1])
            for other_field, other_alias in self.field_names.items():
                if other_alias == alias:
                    warnings.warn(f'Name conflict between {tuple(sorted([other_field,field]))}')
            self.field_names[field] = alias
        return self.field_names[field]

    def __setitem__(self, field, val):
        field = self.normalize_field_name(field)
        val = self.normalize_value(val)
        if field is None:
            return
        # BUG FIX: original read `item[field]`, but `item` is undefined here
        # (NameError on the first stored field); the intent is clearly to
        # detect a conflicting value already stored in THIS mapping.
        if field in self and self[field] != val:
            print(self, field, val)
            raise Exception('Value conflict')
        super().__setitem__(field, val)
def normalize_item_row(item):
    """Build a NormalizedItem from a raw row: id, rev (as 'version'), then fields."""
    row = {'id': item['id'], 'version': item['rev']}
    # Fields are merged last, so a raw 'id'/'version' field wins — same
    # precedence as the original **-unpacking.
    row.update(item['fields'])
    return NormalizedItem(row)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plurality(df, keys=None):
    # Summarize how many values each column takes per group: for every column,
    # count the groups where it is absent ('none'), single-valued ('one'),
    # or multi-valued ('many'); also report the overall distinct count and
    # the 3 most frequent values.
    # NOTE(review): default groups by df.index — confirm callers expect that.
    if keys is None:
        keys = df.index
    t = df.groupby(keys).nunique()
    # Clamp per-group nunique at 2 (anything >=2 counts as 'many'), tally
    # groups per bucket, transpose so rows are df's columns, and blank out
    # zero counts for readability.
    t = t.where(t < 2, 2).apply(lambda col: col.value_counts()).fillna(0).astype(int).replace(0,'').T.rename(columns={0:'none',1:'one',2:'many'})
    t['distinct'] = df.nunique()
    # value_counts orders by frequency, so this is the top-3 by occurrence.
    t['top3'] = df.apply(lambda r: r.value_counts().index.tolist()[:3])
    return t
def get_schema(df):
    """Infer a rough SQL-ish type, nullability, and uniqueness per column.

    Returns a DataFrame indexed by column name with 'db_type', 'nullable',
    and 'unique' columns; False values are blanked for readability.
    """
    def looks_like_datetime(s):
        # BUG FIX: pd.to_datetime(..., errors='ignore') is deprecated
        # (removed in pandas 3.0); emulate it — a failed parse means
        # "not a datetime column" rather than an exception.
        try:
            parsed = pd.to_datetime(s)
        except (ValueError, TypeError, OverflowError):
            return False
        return pd.api.types.is_datetime64_any_dtype(parsed)

    def get_db_type(s):
        # s holds the column's non-null values rendered as strings.
        if s.empty:
            return None
        if s.str.fullmatch(r'-?\d+(\.\d+)?').all():
            # Drop sign and leading zeros so lengths reflect significant digits.
            no_neg = s.str.lstrip('-0')
            if s.str.contains('.', regex=False).any():
                t = no_neg.str.split('.', expand=True)
                return f'float({t[0].str.len().max()}, {t[1].str.rstrip("0").str.len().max()})'
            elif s.str.contains('-', regex=False).any():
                return f'int({no_neg.str.len().max()})'
            else:
                return f'unsigned_int({no_neg.str.len().max()})'
        elif looks_like_datetime(s):
            return 'datetime'
        else:
            return f'char({s.str.len().max()})'

    return pd.DataFrame([{
        'name': df[col].name,
        'db_type': get_db_type(df[col][~df[col].isna()].astype(str)),
        'nullable': df[col].isna().any(),
        'unique': df[col].nunique() == len(df[col]),
    } for col in df]).set_index('name').replace(False, '')
def drop_unhashable(df):
    """Return df without columns whose first non-null value is unhashable.

    Prints each dropped column so the caller can see what was discarded.
    """
    dropped = []
    for name in df.columns:
        non_null = df[name].dropna()
        if non_null.empty:
            continue
        sample = non_null.iloc[0]
        if pd.api.types.is_hashable(sample):
            continue
        if not dropped:
            print('Dropping unhashable columns:')
        print(f' {name} <{type(sample).__name__}>')
        dropped.append(name)
    return df.drop(columns=dropped)
def get_counts(df):
    """Per-column summary: NaN count, distinct count, number of values that
    occur exactly once, and the 3 most frequent values.

    Side effect: also writes the summary to 'cnts.csv' in the working directory.
    """
    df = drop_unhashable(df)
    cnts = pd.concat([
        df.isna().sum(),
        df.nunique(),
        df.apply(lambda col: (col.value_counts() == 1).sum()),
        df.apply(lambda col: col.value_counts().index.tolist()[:3]),
        # BUG FIX: keys previously also listed 'min' and 'max' with no
        # matching frames; pd.concat zips keys with objects, so the extras
        # were silently discarded. Trimmed to match the four frames above.
    ], axis=1, keys=['na','distinct','unique','top3'])
    cnts.replace(np.nan,'').to_csv('cnts.csv')
    return cnts
import re | |
def to_snake_case(df: pd.DataFrame):
    """Rename every column of df to snake_case IN PLACE and return df."""
    def snake(name):
        # "_" before each uppercase letter that is not at the start, then lower.
        return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()
    df.rename(columns=snake, inplace=True)
    return df
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
def can_be_parsed_as_datetime(series: pd.Series):
    """Check if a pandas Series can be parsed as datetime.

    Only the first 5 non-null values are probed for efficiency, so a mixed
    column can still be misclassified. An all-null series returns True
    (vacuously), matching the original behavior.
    """
    sample = series.dropna().head(5)  # Take a small sample for efficiency
    for value in sample:
        try:
            pd.to_datetime(value)
        # BUG FIX: pd.to_datetime also raises TypeError on non-parseable
        # object values and OverflowError on out-of-range dates; treat all
        # of these as "not a datetime" instead of propagating.
        except (ValueError, TypeError, OverflowError):
            return False
    return True
def determine_sql_dtype(data: pd.Series):
    """Map a pandas Series to a reasonable SQL Server column type.

    Handles bool, object/string (with datetime sniffing), datetime64,
    integer, and float dtypes; raises Exception on anything else.
    NaNs are dropped before inference.
    """
    data = data.dropna()
    # Boolean columns
    if data.dtype == 'bool':
        return 'bit'
    # String columns
    if data.dtype == 'object':
        # Force data to be a string
        data = data.astype(str)
        # Check if the string column can be inferred as datetime
        if can_be_parsed_as_datetime(data):
            return 'datetime'
        str_lengths = data.str.len()
        max_length = str_lengths.max()
        # BUG FIX: str.match anchors at the start of the string, so non-ASCII
        # characters anywhere after position 0 were missed; contains scans the
        # whole value.
        has_unicode = data.str.contains(r'[^\u0000-\u007F]').any()
        n = 'n' if has_unicode else ''
        if max_length > 2000:
            return f'{n}varchar(MAX)'
        # If only one unique length, use fixed-width char
        if str_lengths.nunique() == 1:
            return f'{n}char({max_length})'
        # Round the width up to the next power of two; clamp at 1 so an
        # all-empty-string column cannot feed 0 into log2.
        rounded_length = 2**np.ceil(np.log2(max(max_length, 1)))
        return f'{n}varchar({int(rounded_length)})'
    # Datetime columns
    if data.dtype == 'datetime64[ns]':
        return 'datetime'
    # Integer columns
    if np.issubdtype(data.dtype, np.integer):
        max_value = data.max()
        min_value = data.min()
        # If all values are non-negative
        if min_value >= 0:
            if max_value <= 1:
                return 'bit'
            if max_value <= 255:
                return 'tinyint'
            if max_value <= 32767:
                return 'smallint'
            if max_value <= 2147483647:
                return 'int'
            return 'bigint'
        # BUG FIX: the lower bound was previously compared against max_value,
        # so e.g. [-100000, 5] was wrongly classified as smallint; the minimum
        # must fit the type as well.
        if min_value >= -32768 and max_value <= 32767:
            return 'smallint'
        if min_value >= -2147483648 and max_value <= 2147483647:
            return 'int'
        return 'bigint'
    # Decimal columns
    if np.issubdtype(data.dtype, np.floating):
        # NOTE(review): 'precision' here is only the digit count BEFORE the
        # decimal point (SQL numeric precision is total digits), and astype(str)
        # can emit scientific notation for extreme floats — behavior preserved.
        max_precision, max_scale = data.abs().astype(str).str.split('.', expand=True).apply(lambda part: part.str.len().max()).to_list()
        return f'numeric({max_precision},{max_scale})'
    raise Exception(f'Unhandled data type: {data.name} {data.dtype}')
# Test on the Titanic dataset
# NOTE(review): downloads over the network at module import time — run this
# only interactively; the result of apply() is discarded (display-only in a
# notebook/REPL context).
titanic = pd.read_csv('https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv')
titanic.apply(determine_sql_dtype)
# T-SQL MERGE upsert template. Placeholders, in order: target table, source
# table, target key column, source key column, UPDATE SET assignments,
# INSERT column list, INSERT value list.
# BUG FIX: the OUTPUT clause previously appeared AFTER the semicolon that
# terminated the MERGE statement, making the template invalid T-SQL; OUTPUT
# must precede the statement terminator.
'''
MERGE INTO {} as target
USING {} as source
ON target.{} = source.{}
WHEN MATCHED THEN
UPDATE SET {}
WHEN NOT MATCHED BY TARGET THEN
INSERT ({})
VALUES ({})
OUTPUT
$action,
inserted.*,
deleted.*;
'''
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment