pandas functions
from collections import UserDict
import re
import warnings


def to_snake_case(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()


class NormalizedItem(UserDict):
    # note: class-level dict, shared across all NormalizedItem instances
    field_names = {}

    def normalize_value(self, val):
        if type(val) is dict:
            if 'uniqueName' in val:
                return val['uniqueName']
            if 'name' in val:
                return val['name']
            else:
                print(val)
                raise Exception('Unknown value type')
        return val

    def normalize_field_name(self, field: str):
        if field.startswith('WEF_'):
            return
        if field not in self.field_names:
            alias = to_snake_case(field.split('.')[-1])
            for other_field, other_alias in self.field_names.items():
                if other_alias == alias:
                    warnings.warn(f'Name conflict between {tuple(sorted([other_field, field]))}')
            self.field_names[field] = alias
        return self.field_names[field]

    def __setitem__(self, field, val):
        field = self.normalize_field_name(field)
        val = self.normalize_value(val)
        if field is None:
            return
        if field in self and self[field] != val:
            print(self, field, val)
            raise Exception('Value conflict')
        super().__setitem__(field, val)


def normalize_item_row(item):
    return NormalizedItem({
        'id': item['id'],
        'version': item['rev'],
        **item['fields']
    })
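

# A minimal usage sketch for normalize_item_row, assuming raw items expose
# 'id', 'rev', and a 'fields' dict (this payload shape is hypothetical).
# Nested dicts collapse to 'uniqueName'/'name', WEF_* fields are dropped,
# and field names are snake_cased.
raw_item = {
    'id': 42,
    'rev': 3,
    'fields': {
        'System.Title': 'Fix login bug',
        'System.AssignedTo': {'uniqueName': 'jdoe@example.com'},
        'WEF_ABC123_Kanban.Column': 'Doing',
    },
}
row = normalize_item_row(raw_item)
# row -> {'id': 42, 'version': 3, 'title': 'Fix login bug', 'assigned_to': 'jdoe@example.com'}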


import pandas as pd
import numpy as np


def plurality(df, keys=None):
    if keys is None:
        keys = df.index
    t = df.groupby(keys).nunique()
    t = (t.where(t < 2, 2)
         .apply(lambda col: col.value_counts())
         .fillna(0).astype(int).replace(0, '')
         .T.rename(columns={0: 'none', 1: 'one', 2: 'many'}))
    t['distinct'] = df.nunique()
    t['top3'] = df.apply(lambda r: r.value_counts().index.tolist()[:3])
    return t
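

# A minimal usage sketch (made-up data): for each column, how many groups
# contain zero, one, or many distinct values when grouping by 'order_id'.
orders = pd.DataFrame({
    'order_id': [1, 1, 2, 2, 3],
    'customer': ['a', 'a', 'b', 'b', 'c'],
    'sku': ['x', 'x', 'z', 'z', 'x'],
})
print(plurality(orders, keys='order_id'))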


def get_schema(df):
    def get_db_type(s):
        if s.empty:
            return None
        if s.str.fullmatch(r'-?\d+(\.\d+)?').all():
            no_neg = s.str.lstrip('-0')
            if s.str.contains('.', regex=False).any():
                t = no_neg.str.split('.', expand=True)
                return f'float({t[0].str.len().max()}, {t[1].str.rstrip("0").str.len().max()})'
            elif s.str.contains('-', regex=False).any():
                return f'int({no_neg.str.len().max()})'
            else:
                return f'unsigned_int({no_neg.str.len().max()})'
        elif pd.api.types.is_datetime64_any_dtype(pd.to_datetime(s, errors='ignore')):
            return 'datetime'
        else:
            return f'char({s.str.len().max()})'

    return pd.DataFrame([{
        'name': df[col].name,
        'db_type': get_db_type(df[col][~df[col].isna()].astype(str)),
        'nullable': df[col].isna().any(),
        'unique': df[col].nunique() == len(df[col]),
    } for col in df]).set_index('name').replace(False, '')
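

# A minimal usage sketch (made-up data): every value is stringified before
# inspection, so the frame can hold mixed types.
sample = pd.DataFrame({
    'user_id': [1, 2, 3],
    'balance': [10.5, -3.25, 0.0],
    'signup': ['2023-01-01', '2023-02-15', None],
    'note': ['ok', 'late payment', 'ok'],
})
print(get_schema(sample))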


def drop_unhashable(df):
    unhashable = []
    for col in df:
        s = df[col].dropna()
        if not s.empty:
            first_val = s.iloc[0]
            if not pd.api.types.is_hashable(first_val):
                if len(unhashable) == 0:
                    print('Dropping unhashable columns:')
                print(f'  {col} <{type(first_val).__name__}>')
                unhashable.append(col)
    return df.drop(columns=unhashable)
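

# A minimal usage sketch (made-up data): columns holding lists or dicts are
# unhashable and break nunique()/value_counts(), so drop them before profiling.
messy = pd.DataFrame({
    'id': [1, 2],
    'tags': [['a', 'b'], ['c']],  # lists are unhashable
})
clean = drop_unhashable(messy)  # prints the dropped column, returns frame without 'tags'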


def get_counts(df):
    df = drop_unhashable(df)
    cnts = pd.concat([
        df.isna().sum(),
        df.nunique(),
        df.apply(lambda r: (r.value_counts() == 1).sum()),
        df.apply(lambda r: r.value_counts().index.tolist()[:3]),
        df.min(),
        df.max(),
    ], axis=1, keys=['na', 'distinct', 'unique', 'top3', 'min', 'max'])
    cnts.replace(np.nan, '').to_csv('cnts.csv')
    return cnts
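

# A minimal usage sketch (made-up data): a quick per-column profile, also
# written to cnts.csv as a side effect.
events = pd.DataFrame({
    'event': ['click', 'click', 'view', 'view', 'purchase'],
    'user': ['u1', 'u2', 'u1', None, 'u2'],
})
print(get_counts(events))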


import re


def to_snake_case(df: pd.DataFrame):
    df.rename(columns=lambda s: re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower(), inplace=True)
    return df
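

# A minimal usage sketch: rename CamelCase columns in place.
df = pd.DataFrame(columns=['UserId', 'CreatedAt', 'IsActive'])
to_snake_case(df)
# df.columns -> ['user_id', 'created_at', 'is_active']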


import pandas as pd
import numpy as np


def can_be_parsed_as_datetime(series: pd.Series):
    """Check if a pandas Series can be parsed as datetime."""
    sample = series.dropna().head(5)  # Take a small sample for efficiency
    for value in sample:
        try:
            pd.to_datetime(value)
        except ValueError:
            return False
    return True


def determine_sql_dtype(data: pd.Series):
    data = data.dropna()

    # Boolean columns
    if data.dtype == 'bool':
        return 'bit'

    # String columns
    if data.dtype == 'object':
        # Force data to be a string
        data = data.astype(str)
        # Check if the string column can be inferred as datetime
        if can_be_parsed_as_datetime(data):
            return 'datetime'
        str_lengths = data.str.len()
        max_length = str_lengths.max()
        has_unicode = data.str.contains(r'[^\u0000-\u007F]').any()
        n = 'n' if has_unicode else ''
        if max_length > 2000:
            return f'{n}varchar(MAX)'
        # If only one unique length, use char
        if str_lengths.nunique() == 1:
            return f'{n}char({max_length})'
        # Round to next largest power of two
        rounded_length = 2 ** np.ceil(np.log2(max_length))
        return f'{n}varchar({int(rounded_length)})'

    # Datetime columns
    if data.dtype == 'datetime64[ns]':
        return 'datetime'

    # Integer columns
    if np.issubdtype(data.dtype, np.integer):
        min_value = data.min()
        max_value = data.max()
        # If positive
        if min_value >= 0:
            if max_value <= 1:
                return 'bit'
            if max_value <= 255:
                return 'tinyint'
            if max_value <= 32767:
                return 'smallint'
            if max_value <= 2147483647:
                return 'int'
            return 'bigint'
        if min_value >= -32768 and max_value <= 32767:
            return 'smallint'
        if min_value >= -2147483648 and max_value <= 2147483647:
            return 'int'
        return 'bigint'

    # Decimal columns
    if np.issubdtype(data.dtype, np.floating):
        # numeric(precision, scale): precision counts all digits, so combine
        # the integer-part and fractional-part digit counts
        int_digits, max_scale = (
            data.abs().astype(str).str.split('.', expand=True)
            .apply(lambda part: part.str.len().max())
            .to_list()
        )
        return f'numeric({int_digits + max_scale},{max_scale})'

    raise Exception(f'Unhandled data type: {data.name} {data.dtype}')


# Test on the Titanic dataset
titanic = pd.read_csv('https://raw.githubusercontent.com/datasciencedojo/datasets/master/titanic.csv')
print(titanic.apply(determine_sql_dtype))


# T-SQL MERGE (upsert) template; the OUTPUT clause goes before the
# statement's terminating semicolon.
merge_template = '''
MERGE INTO {} as target
USING {} as source
ON target.{} = source.{}
WHEN MATCHED THEN
    UPDATE SET {}
WHEN NOT MATCHED BY TARGET THEN
    INSERT ({})
    VALUES ({})
OUTPUT
    $action,
    inserted.*,
    deleted.*;
'''
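
# A minimal usage sketch for merge_template: all table, key, and column names
# below are made up; in real code, build the SET / column / VALUES fragments
# from your DataFrame's columns.
sql = merge_template.format(
    'dbo.users',                                  # target table
    '#users_staging',                             # source / staging table
    'user_id', 'user_id',                         # join keys
    'name = source.name, email = source.email',   # UPDATE SET assignments
    'user_id, name, email',                       # INSERT column list
    'source.user_id, source.name, source.email',  # INSERT VALUES list
)
print(sql)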