Created
September 11, 2023 12:09
-
-
Save jonashaag/4466b92e338e0c49c59d827580f19e10 to your computer and use it in GitHub Desktop.
Pandas shrink dtypes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
from pandas.api.types import is_numeric_dtype | |
from pandas.core.dtypes.base import ExtensionDtype | |
def shrink_dtype(series: pd.Series) -> pd.Series:
    """Return ``series`` cast to the dtype chosen by ``get_smallest_dtype``.

    Args:
        series: The series to down-cast.

    Returns:
        The very same object when the chosen dtype equals the current one,
        otherwise the result of ``series.astype`` with the chosen dtype.
    """
    target = get_smallest_dtype(series)
    # Hand back the input untouched when there is nothing to change, so no
    # copy is made for series that are already as small as they can be.
    return series if target == series.dtype else series.astype(target)
def shrink_dtypes(df: pd.DataFrame) -> pd.DataFrame:
    """Down-cast every column of ``df`` for which a different dtype is chosen.

    Args:
        df: The frame whose columns are inspected one by one.

    Returns:
        ``df`` itself when no column needs a new dtype, otherwise the result
        of a single ``df.astype`` call covering all columns that changed.
    """
    changes = {}
    for name in df.columns:
        column = df[name]
        target = get_smallest_dtype(column)
        if target != column.dtype:
            changes[name] = target

    # One astype call for all affected columns; skipped when nothing changed.
    return df.astype(changes) if changes else df
def get_smallest_dtype(series: pd.Series) -> type[ExtensionDtype] | type[np.generic]: | |
for dtype in _get_possible_dtypes(series): | |
try: | |
series.astype(dtype) | |
return dtype | |
except (TypeError, ValueError): | |
pass | |
return series.dtype | |
def _get_possible_dtypes(series): | |
if not is_numeric_dtype(series.dtype): | |
return [] | |
if series.hasnans: | |
if series.min() < 0: | |
return [pd.Int8Dtype, pd.Int16Dtype, pd.Int32Dtype] | |
else: | |
return [pd.UInt8Dtype, pd.UInt16Dtype, pd.UInt32Dtype] | |
else: | |
if series.min() < 0: | |
return [np.int8, np.int16, np.int32] | |
else: | |
return [np.uint8, np.uint16, np.uint32] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment