Skip to content

Instantly share code, notes, and snippets.

@jonashaag
Created September 11, 2023 12:09
Show Gist options
  • Save jonashaag/4466b92e338e0c49c59d827580f19e10 to your computer and use it in GitHub Desktop.
Save jonashaag/4466b92e338e0c49c59d827580f19e10 to your computer and use it in GitHub Desktop.
Pandas shrink dtypes
import numpy as np
import pandas as pd
from pandas.api.types import is_numeric_dtype
from pandas.core.dtypes.base import ExtensionDtype
def shrink_dtype(series: pd.Series) -> pd.Series:
    """Cast ``series`` to the dtype chosen by ``get_smallest_dtype``.

    The input object itself is returned (no copy) when the chosen dtype
    compares equal to the dtype the series already has.
    """
    target = get_smallest_dtype(series)
    already_minimal = target == series.dtype
    return series if already_minimal else series.astype(target)
def shrink_dtypes(df: pd.DataFrame) -> pd.DataFrame:
    """Cast every column of ``df`` to the dtype chosen by ``get_smallest_dtype``.

    Only columns whose chosen dtype differs from their current dtype are
    cast.  When no column needs a change, ``df`` itself is returned.
    """
    changes = {}
    for name in df.columns:
        column = df[name]
        best = get_smallest_dtype(column)
        if best != column.dtype:
            changes[name] = best

    if not changes:
        return df
    return df.astype(changes)
def get_smallest_dtype(series: pd.Series) -> type[ExtensionDtype] | type[np.generic]:
for dtype in _get_possible_dtypes(series):
try:
series.astype(dtype)
return dtype
except (TypeError, ValueError):
pass
return series.dtype
def _get_possible_dtypes(series):
if not is_numeric_dtype(series.dtype):
return []
if series.hasnans:
if series.min() < 0:
return [pd.Int8Dtype, pd.Int16Dtype, pd.Int32Dtype]
else:
return [pd.UInt8Dtype, pd.UInt16Dtype, pd.UInt32Dtype]
else:
if series.min() < 0:
return [np.int8, np.int16, np.int32]
else:
return [np.uint8, np.uint16, np.uint32]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment