Skip to content

Instantly share code, notes, and snippets.

@jnothman
Last active July 13, 2021 01:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jnothman/ba46247a36d375136a6662d1b1ef4c6d to your computer and use it in GitHub Desktop.
Save jnothman/ba46247a36d375136a6662d1b1ef4c6d to your computer and use it in GitHub Desktop.
A wrapper for functions so that they can be parametrized with get_params and set_params in scikit-learn: proof of concept
from collections import defaultdict
import pandas as pd
class parametrized_function:
    """Wrap a plain function so scikit-learn can tune its keyword arguments.

    The wrapper implements the ``get_params``/``set_params`` interface, so a
    function used as a hyper-parameter (e.g. ``SelectKBest(score_func=...)``)
    can have its own keyword arguments searched with nested names such as
    ``'selectkbest__score_func__n_neighbors'``.

    Parameters
    ----------
    _func : callable
        The function to wrap.
    **kwargs
        Keyword arguments passed to ``_func`` on every call; these are the
        parameters reported by ``get_params`` and changed by ``set_params``.
    """

    def __init__(self, _func, **kwargs):
        self._set_func(_func)
        # TODO use inspect to automatically find parameters with defaults
        self._params = kwargs

    def _set_func(self, func):
        """Store *func* and mirror its docstring and name on the wrapper."""
        self._func = func
        self.__doc__ = getattr(func, '__doc__', None)
        # functools.partial and other callable objects may lack __name__.
        self.__name__ = getattr(func, '__name__', type(func).__name__)

    def __repr__(self):
        params = ''.join(', %s=%r' % item for item in self._params.items())
        return '%s(%s%s)' % (type(self).__name__, self.__name__, params)

    def __call__(self, *args, **kwargs):
        """Call the wrapped function with the stored keyword arguments.

        Call-time keyword arguments override the stored ones for this call
        only; the stored parameters are not modified.
        """
        # Copy so call-time kwargs never leak into the stored parameters.
        kw = dict(self._params)
        kw.update(kwargs)
        return self._func(*args, **kw)

    def get_params(self, deep=False):
        """Return the wrapper's parameters, including ``'_func'``.

        Parameters
        ----------
        deep : bool, default=False
            If True, also include the parameters of every parameter value
            that itself implements ``get_params``, under keys of the form
            ``'<name>__<subname>'``.

        Returns
        -------
        dict
            A new dict; mutating it does not affect the wrapper.
        """
        out = self._params.copy()
        out['_func'] = self._func
        if deep:
            # Iterate over a snapshot: ``out`` grows inside the loop.
            for key, value in list(out.items()):
                # Skip classes: their get_params would be an unbound method.
                if hasattr(value, 'get_params') and not isinstance(value, type):
                    deep_items = value.get_params().items()
                    out.update((key + '__' + k, val) for k, val in deep_items)
        return out

    def set_params(self, **params):
        """Set parameters, including nested ``'<name>__<subname>'`` ones.

        Setting ``_func`` replaces the wrapped function; any other top-level
        name updates the keyword arguments passed to it.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If a top-level parameter name is not a key of ``get_params()``.
        """
        if not params:
            # Nothing to set; skip building the parameter table.
            return self
        valid_params = self.get_params(deep=True)
        nested_params = defaultdict(dict)  # grouped by prefix
        for key, value in params.items():
            key, delim, sub_key = key.partition('__')
            if key not in valid_params:
                raise ValueError('Invalid parameter %s for estimator %s. '
                                 'Check the list of available parameters '
                                 'with `estimator.get_params().keys()`.' %
                                 (key, self))
            if delim:
                nested_params[key][sub_key] = value
            elif key == '_func':
                # The wrapped function is stored apart from its kwargs;
                # putting it in _params would pass `_func=` on every call.
                self._set_func(value)
                valid_params[key] = value
            else:
                self._params[key] = value
                valid_params[key] = value
        for key, sub_params in nested_params.items():
            valid_params[key].set_params(**sub_params)
        return self
if __name__ == '__main__':
    # Demo: tune a keyword argument of SelectKBest's score function with
    # GridSearchCV by wrapping the function in ``parametrized_function``.
    from sklearn.feature_selection import mutual_info_regression, SelectKBest
    from sklearn.model_selection import GridSearchCV
    from sklearn.pipeline import make_pipeline
    from sklearn.linear_model import LinearRegression
    from sklearn.datasets import make_regression

    # Use a distinct name instead of shadowing the imported function.
    # random_state is fixed for both the data and mutual_info_regression's
    # internal noise so the printed results are reproducible.
    score_func = parametrized_function(mutual_info_regression,
                                       n_neighbors=3, random_state=0)
    X, y = make_regression(random_state=0)
    pipeline = make_pipeline(SelectKBest(score_func, k=1), LinearRegression())
    # make_pipeline names the step 'selectkbest'; the '__score_func__' segment
    # reaches into the wrapper's own parameters.
    param_grid = {'selectkbest__score_func__n_neighbors': [3, 4]}
    gs = GridSearchCV(pipeline, param_grid, cv=5,
                      return_train_score=False).fit(X, y)
    print(pd.DataFrame(gs.cv_results_))
@jnothman
Copy link
Author

Note that if `BaseEstimator`'s `get_params` and `set_params` were refactored into a `BaseParametrized` base class that used `_get_param` instead of `getattr` and `_set_param` instead of `setattr`, we wouldn't need to duplicate so much code here.

@thomasjpfan
Copy link

Reading through the comments, I do not find it too magical. I see this as an extension of how `Pipeline` or `ColumnTransformer` expose their estimators via `get_params` and `set_params`.

As for refactoring into a `BaseParametrized`: if we make it easy for third parties to use, they can extend any object to support the `get_params` and `set_params` interface.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment