
@drorata
Last active June 23, 2017 08:40
Applying transformations to a subset of columns
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import StandardScaler


class GetDummiesCatCols(BaseEstimator, TransformerMixin):
    """Replace `cols` with their dummies (one-hot encoding).

    `cols` should be a list of column names holding categorical data.
    This class wraps the one-hot encoding available in
    [pandas.get_dummies](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.get_dummies.html)
    as a scikit-learn transformer.
    """

    def __init__(self, cols=None):
        self.cols = cols

    def transform(self, df, **transform_params):
        # One-hot encode the selected columns, drop the originals and
        # append the dummy columns to the remaining ones.
        cols_dummy = pd.get_dummies(df[self.cols])
        df = df.drop(self.cols, axis=1)
        df = pd.concat([df, cols_dummy], axis=1)
        return df

    def fit(self, df, y=None, **fit_params):
        # Nothing to learn: the dummies are computed in `transform`.
        return self


class StandartizeFloatCols(BaseEstimator, TransformerMixin):
    """Standard-scale the columns `cols` of the data frame.

    `cols` should be a list of column names present in the data; all other
    columns are passed through unchanged.
    """

    def __init__(self, cols=None):
        self.cols = cols
        if cols is None:
            raise ValueError(
                "The `cols` parameter has to be defined as a list of strings representing columns")
        self.standard_scaler = StandardScaler()
        self._is_fitted = False

    def _are_cols_valid(self, df):
        if set(self.cols).issubset(set(df.columns)):
            return True
        else:
            raise ValueError(
                "Class instantiated with columns that don't appear in the data frame"
            )

    def transform(self, df, **transform_params):
        if not self._is_fitted:
            raise NotFittedError("Fitting was not performed")
        self._are_cols_valid(df)
        standartize_cols = pd.DataFrame(
            # StandardScaler.transform returns a NumPy array, which loses the
            # column names and the index; both are restored explicitly here.
            # Only `transform` is called, so the scaling factors learned in
            # `fit` are reused rather than re-fitted on this data.
            self.standard_scaler.transform(df[self.cols]),
            columns=self.cols,
            # The index of the resulting DataFrame must be set equal to the
            # one of the original DataFrame. Otherwise, upon concatenation,
            # NaNs will be introduced.
            index=df.index
        )
        df = df.drop(self.cols, axis=1)
        df = pd.concat([df, standartize_cols], axis=1)
        return df

    def fit(self, df, y=None, **fit_params):
        self._are_cols_valid(df)
        self.standard_scaler.fit(df[self.cols])
        self._is_fitted = True
        return self
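Both transformers follow the same pattern: transform the selected columns, drop the originals, and concatenate the result with the remaining columns. The comment in `StandartizeFloatCols.transform` warns that the rebuilt DataFrame must carry the original index. The following sketch uses plain pandas and NumPy with made-up data (a hand-written standardization stands in for `StandardScaler`) to show what happens when the index is not a default `0..n-1` range, as is common after a train/test split.

```python
import numpy as np
import pandas as pd

# Made-up frame whose index is not 0..n-1 (e.g. rows left over after a split)
df = pd.DataFrame({"v1": [1.0, 2.0, 3.0], "other": ["a", "b", "c"]}, index=[10, 11, 12])
cols = ["v1"]

# Stand-in for StandardScaler output: a bare NumPy array, no index attached
values = df[cols].to_numpy()
scaled = (values - values.mean(axis=0)) / values.std(axis=0)

# Without index=df.index the new frame gets a default index (0, 1, 2), so
# concat aligns on labels and yields 6 rows padded with NaNs.
misaligned = pd.concat([df.drop(cols, axis=1), pd.DataFrame(scaled, columns=cols)], axis=1)

# Reusing the original index keeps the rows aligned.
aligned = pd.concat(
    [df.drop(cols, axis=1), pd.DataFrame(scaled, columns=cols, index=df.index)], axis=1
)

print(misaligned.shape)  # (6, 2)
print(aligned.shape)     # (3, 2)
```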

drorata commented Jun 22, 2017

So, I realized that this version is wrong. The reason is that StandartizeFloatCols() fits the scaling factors during the transform. Consider the following example.

import numpy as np
import pandas as pd

custom_std_scale = StandartizeFloatCols(cols=['v1', 'v2'])

N=1000
np.random.seed(42)
df = pd.DataFrame(
    {
        "v1": np.random.random(size=N),
        "v2": np.random.geometric(0.2, size=N)
    }
)

df_test = pd.DataFrame(
    {
        "v1": np.random.random(size=N),
        "v2": np.random.geometric(0.2, size=N)
    }
)

df_test_swap = pd.DataFrame(
    {
        "v1": np.random.geometric(0.2, size=N),
        "v2": np.random.random(size=N)
    }
)

Let us fit the transformer on df:

custom_std_scale.fit(df)

Now we can use it to transform df_test. The expected behavior is that the fitted instance applies the scaling factors it learned in the previous step to df_test. Since df_test is drawn from the same distributions as df, the resulting columns should have approximately zero mean and unit STD. Indeed:

custom_std_scale.transform(df_test).describe()

yields:

                 v1             v2
mean  -7.460699e-17   9.148238e-17
std    1.000500e+00   1.000500e+00
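The STD of 1.000500e+00 rather than exactly 1 comes from a difference in conventions: StandardScaler divides by the population standard deviation (ddof=0), while DataFrame.describe() reports the sample standard deviation (ddof=1). A column standardized over N = 1000 rows therefore shows a sample STD of sqrt(N/(N-1)), about 1.0005. A quick check, on arbitrary random data:

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

N = 1000
x = np.random.RandomState(0).random_sample((N, 1))  # arbitrary data
z = StandardScaler().fit_transform(x)[:, 0]

print(np.std(z))             # population std (ddof=0), what StandardScaler targets: 1 up to rounding
print(pd.Series(z).std())    # sample std (ddof=1), as reported by describe()
print(np.sqrt(N / (N - 1)))  # the same value, about 1.0005
```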

However, if we apply custom_std_scale to df_test_swap, we expect a non-zero mean and a non-unit STD, since the distributions of v1 and v2 are swapped. BUT:

custom_std_scale.transform(df_test_swap).describe()

yields:

                 v1             v2
mean  -6.217249e-17  -7.016610e-17
std    1.000500e+00   1.000500e+00

The reason is that transform() re-fitted the underlying StandardScaler() instance on df_test_swap, discarding the scaling factors learned from df.
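The difference between the intended and the buggy behavior can be reproduced with a plain StandardScaler. This sketch regenerates data of the same kind as above (the seed and the use of `np.random.RandomState` are arbitrary choices, so the numbers will not match the tables): fitting once and then only calling `transform` reuses the learned `mean_` and `scale_`, while re-fitting on the new data always produces zero mean and unit STD.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

# Data of the same kind as in the example above (seed chosen arbitrarily)
N = 1000
rng = np.random.RandomState(42)
df = pd.DataFrame({"v1": rng.random_sample(N), "v2": rng.geometric(0.2, size=N)})
df_test_swap = pd.DataFrame({"v1": rng.geometric(0.2, size=N), "v2": rng.random_sample(N)})

cols = ["v1", "v2"]
scaler = StandardScaler().fit(df[cols])  # learns mean_ and scale_ from df only

# Intended behavior: transform reuses the statistics learned from df,
# i.e. (x - scaler.mean_) / scaler.scale_
correct = scaler.transform(df_test_swap[cols])

# Buggy behavior: re-fitting on the new data hides the distribution shift
buggy = StandardScaler().fit_transform(df_test_swap[cols])

print(correct.mean(axis=0))  # clearly non-zero, since the distributions are swapped
print(buggy.mean(axis=0))    # ~0 for both columns
```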

I will fix it and commit a new version.
