Last active
September 7, 2023 13:44
-
-
Save tdpetrou/6a97304dd4452a53be98e4f4e93196e6 to your computer and use it in GitHub Desktop.
A custom scikit-learn transformer for one-hot encoding categorical values, and standardizing numeric columns
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
from sklearn.base import BaseEstimator | |
class BasicTransformer(BaseEstimator):
    """One-hot encode object columns and standardize all other columns.

    Columns whose dtype kind is ``'O'`` are treated as categorical; every
    other column is treated as numeric.  Numeric columns have missing values
    filled with the per-column mean or median and are then standardized.
    All statistics (fill values, means, standard deviations, category
    levels) are learned in ``fit`` and re-used unchanged in ``transform``,
    so new data is scaled exactly like the training data.

    Parameters
    ----------
    cat_threshold : int or None, default None
        If given, only category values occurring strictly more than this
        many times in the fitted data receive an indicator column.
    num_strategy : {'mean', 'median'}, default 'median'
        Statistic used to fill missing values in numeric columns.
    return_df : bool, default False
        If True, ``transform`` returns a DataFrame labelled with the feature
        names; otherwise it returns a NumPy array.

    Raises
    ------
    ValueError
        If ``num_strategy`` is not ``'mean'`` or ``'median'``.
    """

    _NUM_STRATEGIES = ('mean', 'median')

    def __init__(self, cat_threshold=None, num_strategy='median', return_df=False):
        # store parameters as public attributes
        self.cat_threshold = cat_threshold
        if num_strategy not in self._NUM_STRATEGIES:
            raise ValueError('num_strategy must be either "mean" or "median"')
        self.num_strategy = num_strategy
        self.return_df = return_df

    def fit(self, X, y=None):
        """Learn column roles, category levels and numeric statistics from X.

        Parameters
        ----------
        X : pandas.DataFrame
            Training data.
        y : ignored
            Accepted so the estimator can be used inside a Pipeline.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``num_strategy`` was changed to an unsupported value after
            construction.
        """
        # Re-validate here: parameters can be reassigned after __init__
        # (e.g. via set_params), which would bypass the constructor check.
        if self.num_strategy not in self._NUM_STRATEGIES:
            raise ValueError('num_strategy must be either "mean" or "median"')

        # Assumes X is a DataFrame
        self._columns = X.columns.values

        # Split columns into categorical (object dtype) and numeric (the rest)
        self._dtypes = X.dtypes.values
        self._kinds = np.array([dt.kind for dt in X.dtypes])
        is_cat = self._kinds == 'O'
        self._column_dtypes = {
            'cat': self._columns[is_cat],
            'num': self._columns[~is_cat],
        }

        # Map each categorical column to the values that get an indicator
        # column, and collect the matching output names as "<column>_<value>".
        self._cat_cols = {}
        encoded_names = []
        for col in self._column_dtypes['cat']:
            counts = X[col].value_counts()
            if self.cat_threshold is not None:
                counts = counts[counts > self.cat_threshold]
            vals = counts.index.values
            self._cat_cols[col] = vals
            # f-strings also cope with non-string values / column labels
            encoded_names.extend(f'{col}_{val}' for val in vals)

        # Output order: numeric columns first, then the indicator columns
        self._feature_names = np.append(
            self._column_dtypes['num'], np.array(encoded_names, dtype=object)
        )

        # total number of new categorical columns
        self._total_cat_cols = sum(len(vals) for vals in self._cat_cols.values())

        # Fill value (mean or median) per numeric column
        num_frame = X[self._column_dtypes['num']]
        self._num_fill = num_frame.agg(self.num_strategy)

        # Standardization statistics come from the *filled training data* so
        # that transform() applies the same scaling to any later data.
        filled = num_frame.fillna(self._num_fill)
        self._num_mean = filled.mean()
        self._num_std = filled.std()
        return self

    def transform(self, X):
        """Apply the fitted imputation, scaling and one-hot encoding to X.

        Parameters
        ----------
        X : pandas.DataFrame
            Data with the same column names as the frame passed to ``fit``.

        Returns
        -------
        numpy.ndarray or pandas.DataFrame
            Standardized numeric columns followed by the 0/1 indicator
            columns.  A DataFrame is returned when ``return_df`` is True.

        Raises
        ------
        ValueError
            If the columns of X differ from those seen in ``fit``.
        """
        # check that we have a DataFrame with same column names as the one we fit
        if set(self._columns) != set(X.columns):
            raise ValueError('Passed DataFrame has different columns than fit DataFrame')
        elif len(self._columns) != len(X.columns):
            # same label set but different length: only possible with duplicated labels
            raise ValueError('Passed DataFrame has different number of columns than fit DataFrame')

        # fill missing values with the statistics learned in fit
        X_num = X[self._column_dtypes['num']].fillna(self._num_fill)

        # Standardize with the fit-time mean / std (not this frame's own)
        X_num = (X_num - self._num_mean) / self._num_std

        # A zero standard deviation means the column was constant during fit;
        # the division above is undefined there, so set those columns to 0.
        zero_std = np.where(self._num_std == 0)[0]
        if len(zero_std) > 0:
            X_num.iloc[:, zero_std] = 0
        X_num = X_num.values

        # create separate array for new encoded categoricals
        X_cat = np.empty((len(X), self._total_cat_cols), dtype='int')
        i = 0
        for col in self._column_dtypes['cat']:
            for val in self._cat_cols[col]:
                # values not retained during fit match no indicator -> all zeros
                X_cat[:, i] = (X[col] == val).values
                i += 1

        # concatenate transformed numeric and categorical arrays
        data = np.column_stack((X_num, X_cat))

        # return either a DataFrame or an array
        if self.return_df:
            return pd.DataFrame(data=data, columns=self._feature_names)
        return data

    def fit_transform(self, X, y=None):
        """Fit to X, then return the transformed X."""
        return self.fit(X, y).transform(X)

    def get_feature_names(self):
        """Return the output column names: numeric columns, then '<column>_<value>' indicators."""
        return self._feature_names
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold, cross_val_score
from sklearn.pipeline import Pipeline

# Location of the housing training data used for the demo
DATA_URL = 'https://raw.githubusercontent.com/DunderData/Machine-Learning-Tutorials/master/data/housing/train.csv'

# Load the data, discard the 'Id' column and split off the target column
train = pd.read_csv(DATA_URL)
train = train.drop(columns=['Id'])
y = train.pop('SalePrice')

# Preprocessing step followed by a ridge regression
bt = BasicTransformer(cat_threshold=5, return_df=True)
basic_pipe = Pipeline(steps=[('bt', bt), ('ridge', Ridge())])

# Shuffled 5-fold cross-validation with a fixed seed
kf = KFold(n_splits=5, shuffle=True, random_state=123)
cv_scores = cross_val_score(basic_pipe, train, y, cv=kf)

# Bare expression: the mean score is shown by a notebook/REPL but is not
# printed when the file is run as a script.
cv_scores.mean()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment