Skip to content

Instantly share code, notes, and snippets.

@pushpendre
Created July 22, 2020 19:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pushpendre/a17ce02e8bc2ee293f93360495070954 to your computer and use it in GitHub Desktop.
Save pushpendre/a17ce02e8bc2ee293f93360495070954 to your computer and use it in GitHub Desktop.
diff --git a/utils/tabular/features/abstract_feature_generator.py b/utils/tabular/features/abstract_feature_generator.py
index 155cc1f..2fbcdf8 100644
--- a/utils/tabular/features/abstract_feature_generator.py
+++ b/utils/tabular/features/abstract_feature_generator.py
@@ -89,20 +89,23 @@ class AbstractFeatureGenerator:
self.features_to_remove += self.banned_features
X_index = copy.deepcopy(X.index)
X.columns = X.columns.astype(str) # Ensure all column names are strings
+
+ # populate self.features_init, self.feature_type_family, self.features_to_remove
self.get_feature_types(X)
X = X.drop(self.features_to_remove, axis=1, errors='ignore')
self.features_init_to_keep = copy.deepcopy(list(X.columns))
- self.features_init_types = X.dtypes.to_dict()
+ self.features_init_types = {featname: typ for typ, featname_list in self.feature_type_family.items() for featname in featname_list}
self.feature_type_family_init_raw = get_type_groups_df(X)
X.reset_index(drop=True, inplace=True)
X_features = self.generate_features(X)
+ object_column_set = set(self.feature_type_family.get('object', []))
for column in X_features:
unique_value_count = len(X_features[column].unique())
if unique_value_count == 1:
self.features_to_remove_post.append(column)
# TODO: Consider making 0.99 a parameter to FeatureGenerator
- elif 'object' in self.feature_type_family and column in self.feature_type_family['object'] and (unique_value_count / X_len > 0.99):
+ elif column in object_column_set and (unique_value_count / X_len > 0.99):
self.features_to_remove_post.append(column)
self.features_binned = list(set(self.features_binned) - set(self.features_to_remove_post))
@@ -439,7 +442,7 @@ class AbstractFeatureGenerator:
# TODO: add option for user to specify dtypes on load
@staticmethod
def get_type_family(dtype):
- return get_type_family(dtype=dtype)
+ return get_type_family(dtype)
@staticmethod
def word_count(string):
diff --git a/utils/tabular/features/auto_ml_feature_generator.py b/utils/tabular/features/auto_ml_feature_generator.py
index 78ebf5c..ec6200f 100644
--- a/utils/tabular/features/auto_ml_feature_generator.py
+++ b/utils/tabular/features/auto_ml_feature_generator.py
@@ -62,7 +62,7 @@ class AutoMLFeatureGenerator(AbstractFeatureGenerator):
self._compute_feature_transformations()
X_features = pd.DataFrame(index=X.index)
for column in X.columns:
- if X[column].dtype.name == 'object':
+ if self.features_init_types[column] == 'object':
X[column].fillna('', inplace=True)
else:
X[column].fillna(np.nan, inplace=True)
diff --git a/utils/tabular/features/utils.py b/utils/tabular/features/utils.py
index 86c9320..765e966 100644
--- a/utils/tabular/features/utils.py
+++ b/utils/tabular/features/utils.py
@@ -6,27 +6,41 @@ import numpy as np
 logger = logging.getLogger(__name__)
 
 
-def get_type_family(dtype):
-    """From dtype, gets the dtype family."""
+def get_type_family(dtype_toplevel):
+    """Return the type-family name for a numpy or pandas dtype.
+
+    A pandas ``Sparse[...]`` dtype is unwrapped to its subtype first, so a
+    sparse column gets the same family as its dense equivalent.
+    """
+    # Detect the pandas Sparse extension dtype by its name and classify the subtype instead.
+    is_sparse = dtype_toplevel.name.startswith('Sparse[')
+    dtype = dtype_toplevel.subtype if is_sparse else dtype_toplevel
+    ret = None
     try:
-        if dtype.name is 'category':
-            return 'category'
-        if 'datetime' in dtype.name:
-            return 'datetime'
+        # Compare by value: `is` against a str literal tests identity, not equality.
+        if dtype.name == 'category':
+            ret = 'category'
+        # Must be `elif`: with the early return gone, a matched category dtype would
+        # otherwise reach np.issubdtype, and its failure there is logged as an error.
+        elif 'datetime' in dtype.name:
+            ret = 'datetime'
         elif np.issubdtype(dtype, np.integer):
-            return 'int'
+            ret = 'int'
         elif np.issubdtype(dtype, np.floating):
-            return 'float'
+            ret = 'float'
     except Exception as err:
         logger.exception(f'Warning: dtype {dtype} is not recognized as a valid dtype by numpy! AutoGluon may incorrectly handle this feature...')
         logger.exception(err)
-
-    if dtype.name in ['bool', 'bool_']:
-        return 'bool'
-    elif dtype.name in ['str', 'string', 'object']:
-        return 'object'
-    else:
-        return dtype.name
+    if ret is None:
+        if dtype.name in ['bool', 'bool_']:
+            ret = 'bool'
+        elif dtype.name in ['str', 'string', 'object']:
+            ret = 'object'
+        else:
+            ret = dtype.name
+    # The Sparse wrapper is intentionally dropped (not f'Sparse[{ret}]'):
+    # storage format does not affect the feature's semantics.
+    return ret
 
 
 def get_type_groups_df(df):
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment