Skip to content

Instantly share code, notes, and snippets.

@Coldsp33d
Created June 18, 2019 02:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Coldsp33d/30d0ff66a7164b371816e52b6f6e00bc to your computer and use it in GitHub Desktop.
Selective error handling
import numpy as np
from sklearn.preprocessing import OneHotEncoder
class SelectiveHandlerOHE(OneHotEncoder):
    """One-hot encoder that only raises on unknown categories in chosen columns.

    ``handle_unknown`` defaults to ``'ignore'`` here, so unknown categories
    are tolerated in general. For the columns listed in ``raise_error_cols``,
    however, ``transform`` raises ``ValueError`` when it meets a category that
    was not seen during fit.

    Parameters
    ----------
    *args, **kwargs
        Forwarded to :class:`sklearn.preprocessing.OneHotEncoder`.
        ``handle_unknown`` is set to ``'ignore'`` unless passed explicitly.
    raise_error_cols : iterable of column labels, optional
        Column labels of the fit-time input for which unknown categories
        raise during ``transform``. Using this requires ``X`` to expose a
        ``columns`` attribute (e.g. a pandas DataFrame).

    Notes
    -----
    NOTE(review): because ``__init__`` accepts ``*args``, scikit-learn's
    ``get_params`` (and therefore ``clone``/``repr``) refuses to introspect
    this estimator; confirm whether that matters for your use before relying
    on grid search or cloning.
    """

    def __init__(self, *args, raise_error_cols=None, **kwargs):
        # Only change the *default*; an explicit handle_unknown is respected.
        kwargs.setdefault('handle_unknown', 'ignore')
        # Store a private list copy: avoids sharing a mutable default and
        # accepts any iterable (tuples have no .copy()).
        self.raise_error_cols = (
            [] if raise_error_cols is None else list(raise_error_cols))
        super().__init__(*args, **kwargs)

    def check_cols(self, X):
        """Validate ``raise_error_cols`` against ``X`` and record its columns.

        Sets ``self.columns`` to ``X.columns``, or ``None`` when ``X`` has no
        ``columns`` attribute (e.g. a NumPy array). ``transform`` uses it to
        map labels to positions.

        Raises
        ------
        ValueError
            If ``raise_error_cols`` is non-empty and ``X`` has no column
            labels, or any entry of ``raise_error_cols`` is not among them.
        """
        columns = getattr(X, 'columns', None)
        if self.raise_error_cols:
            if columns is None:
                raise ValueError(
                    "`raise_error_cols` requires X to have column labels "
                    "(e.g. a pandas DataFrame)")
            if any(c not in columns for c in self.raise_error_cols):
                msg = ("One or more column names are incorrect. "
                       "Please check the column names passed "
                       "to the `raise_error_cols` argument")
                raise ValueError(msg)
        self.columns = columns

    def fit(self, X, y=None):
        """Validate column names, then fit the underlying encoder.

        ``y`` is accepted and ignored so the estimator works inside
        scikit-learn pipelines, which call ``fit(X, y)``.
        """
        self.check_cols(X)
        return super().fit(X)

    def transform(self, X):
        """Encode ``X``; raise on unknown categories in ``raise_error_cols``.

        The per-column check only runs for inputs with more than one
        dimension. Each column is located by its position in the columns
        recorded at fit time.

        Raises
        ------
        ValueError
            If a column in ``raise_error_cols`` contains a value that is not
            among the categories learned during fit.
        """
        X_ = np.asarray(X)
        if X_.ndim > 1:
            for c in self.raise_error_cols:
                idx = self.columns.get_loc(c)
                values = X_[:, idx]
                known = self.categories_[idx]
                if not np.isin(values, known).all():
                    # str() each value: categories need not be strings.
                    cats = ','.join(map(str, np.setdiff1d(values, known)))
                    msg = ("Found unknown categories {0} in column {1}"
                           " during transform".format(cats, c))
                    raise ValueError(msg)
        return super().transform(X)

    def fit_transform(self, X, y=None):
        """Validate column names, then fit and transform ``X`` in one step.

        ``y`` is accepted for pipeline compatibility and forwarded to the
        parent implementation.
        """
        self.check_cols(X)
        return super().fit_transform(X, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment