Last active
June 23, 2017 08:40
-
-
Save drorata/b25547d5f5ed01411658b46f30dfc140 to your computer and use it in GitHub Desktop.
Applying transformations on subset of columns
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd
import sklearn
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler
class GetDummiesCatCols(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin):
    """Replace `cols` with their dummies (One Hot Encoding).

    `cols` should be a list of column names holding categorical data.
    Furthermore, this class streamlines the implementation of one hot encoding
    as available on [pandas.get_dummies](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.get_dummies.html).

    Unlike a bare call to `pd.get_dummies`, the dummy columns are learned
    during `fit` and re-used by `transform`, so transforming a second data
    frame (e.g. a test set) yields the *same* set of columns: categories
    missing from the new frame become all-zero columns, and categories not
    seen at fit time are dropped.
    """
    def __init__(self, cols=None):
        self.cols = cols
        # Dummy column names learned in `fit`; None until `fit` is called.
        self._dummy_cols = None
    def fit(self, df, y=None, **fit_params):
        """Record the dummy columns produced by the training data."""
        self._dummy_cols = pd.get_dummies(df[self.cols]).columns
        return self
    def transform(self, df, **transform_params):
        """Return a copy of `df` with `self.cols` replaced by their dummies."""
        cols_dummy = pd.get_dummies(df[self.cols])
        if self._dummy_cols is not None:
            # Align with the columns seen at fit time so that the output
            # is structurally identical for train and test frames.
            cols_dummy = cols_dummy.reindex(columns=self._dummy_cols, fill_value=0)
        df = df.drop(self.cols, axis=1)
        df = pd.concat([df, cols_dummy], axis=1)
        return df
class StandartizeFloatCols(sklearn.base.BaseEstimator, sklearn.base.TransformerMixin):
    """Standard-scale the columns in the data frame.

    `cols` should be a list of columns in the data. The scaling factors
    (mean and STD) are learned in `fit` and re-used by `transform`, so the
    instance can be fitted on training data and applied to test data using
    the *training* statistics.

    Raises
    ------
    ValueError
        If `cols` is not supplied, or if a frame passed to `fit` or
        `transform` is missing any of the configured columns.
    NotFittedError
        If `transform` is called before `fit`.
    """
    def __init__(self, cols=None):
        self.cols = cols
        if cols is None:
            raise ValueError(
                "The `cols` parameter has to be defined as a list of strs representing columns")
        self.standard_scaler = StandardScaler()
        # Guard so `transform` cannot silently run with an unfitted scaler.
        self._is_fitted = False
    def _are_cols_valid(self, df):
        """Return True if every configured column is in `df`; raise otherwise."""
        if set(self.cols).issubset(set(df.columns)):
            return True
        raise ValueError(
            "Class instantiated with columns that don't appear in the data frame"
        )
    def fit(self, df, y=None, **fit_params):
        """Learn the scaling statistics of `self.cols` from `df`."""
        self._are_cols_valid(df)
        self.standard_scaler.fit(df[self.cols])
        self._is_fitted = True
        return self
    def transform(self, df, **transform_params):
        """Return a copy of `df` with `self.cols` scaled by the fitted stats."""
        if not self._is_fitted:
            # NotFittedError comes from sklearn.exceptions (imported at the
            # top of the file); previously this raised a bare NameError.
            raise NotFittedError("Fitting was not performed")
        self._are_cols_valid(df)
        standartize_cols = pd.DataFrame(
            # StandardScaler returns a NumPy.array, and thus indexing
            # breaks. Explicitly fixed next.
            self.standard_scaler.transform(df[self.cols]),
            columns=self.cols,
            # The index of the resulting DataFrame should be assigned and
            # equal to the one of the original DataFrame. Otherwise, upon
            # concatenation NaNs will be introduced.
            index=df.index
        )
        df = df.drop(self.cols, axis=1)
        df = pd.concat([df, standartize_cols], axis=1)
        return df
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
So, I realized that this version is wrong. The reason is that
StandartizeFloatCols()
fits the scaling factors during the transform. Consider the following example. Let us fit the class using
df
:Now, we can use it to transform
df_test
. We can also expect (and this is the expected behavior) that the fitted class will use the scales that it learned in the previous step and apply them to `df_test`
. As the distribution used are the same, the resulting columns should have zero mean and unit STD. Indeed:yields:
However, if we use
custom_std_scale
on `df_test_swap`
, we expect to have a non-zero mean and non-unit STD since the distributions are swapped. BUT:yields
The reason is that the instance re-fitted the underlying
StandardScaler()
instance. I will fix it and commit a new version.