Skip to content

Instantly share code, notes, and snippets.

@jnothman
Created November 17, 2014 09:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jnothman/dfd120711d336c88b14d to your computer and use it in GitHub Desktop.
Save jnothman/dfd120711d336c88b14d to your computer and use it in GitHub Desktop.
Allow nested scikit-learn parameters to be renamed, or multiple parameters to be tied together so that they always hold the same value.
from abc import ABCMeta, abstractmethod
from .base import BaseEstimator
from .externals.six import iteritems, with_metaclass
class BaseParameterTranslator(with_metaclass(ABCMeta, BaseEstimator)):
    """Base class for a wrapper that rewrites parameter names and values.

    The wrapper holds a sub-estimator in the attribute named by
    ``_subest_param``.  Estimator methods (``fit``, ``predict``, ...) are
    exposed as properties that hand back the sub-estimator's own attribute
    of the same name, while ``set_params`` / ``get_params`` pass parameter
    names and values through subclass-supplied translation rules.

    A "rule" is a pair ``(translate_key, translate_value)`` of callables:

    * ``translate_key(name)`` returns a sequence of translated names, or a
      falsy value (``None``, ``[]``) when the rule does not apply to ``name``;
    * ``translate_value(value)`` returns the value to store under every
      translated name.
    """

    @property
    def fit(self):
        """Return the ``fit`` attribute of the wrapped sub-estimator."""
        return getattr(self, self._subest_param).fit

    @property
    def fit_transform(self):
        """Return the ``fit_transform`` attribute of the sub-estimator."""
        return getattr(self, self._subest_param).fit_transform

    @property
    def fit_predict(self):
        """Return the ``fit_predict`` attribute of the sub-estimator."""
        return getattr(self, self._subest_param).fit_predict

    @property
    def predict(self):
        """Return the ``predict`` attribute of the sub-estimator."""
        return getattr(self, self._subest_param).predict

    @property
    def predict_proba(self):
        """Return the ``predict_proba`` attribute of the sub-estimator."""
        return getattr(self, self._subest_param).predict_proba

    @property
    def decision_function(self):
        """Return the ``decision_function`` attribute of the sub-estimator."""
        return getattr(self, self._subest_param).decision_function

    @property
    def score(self):
        """Return the ``score`` attribute of the sub-estimator."""
        return getattr(self, self._subest_param).score

    @property
    def transform(self):
        """Return the ``transform`` attribute of the sub-estimator."""
        return getattr(self, self._subest_param).transform

    @property
    def inverse_transform(self):
        """Return the ``inverse_transform`` attribute of the sub-estimator."""
        return getattr(self, self._subest_param).inverse_transform

    @abstractmethod
    def _get_translation_rules(self):
        """Return the rules applied to incoming names in ``set_params``.

        Must return an iterable of ``(translate_key, translate_value)``
        pairs as described in the class docstring.
        """
        pass

    @abstractmethod
    def _get_inverse_translation_rules(self):
        """Return the rules applied to sub-estimator names in ``get_params``.

        Must return an iterable of ``(translate_key, translate_value)``
        pairs; keys passed to ``translate_key`` are already prefixed with
        ``<_subest_param>__``.
        """
        pass

    @property
    @abstractmethod
    def _subest_param(self):
        """Name of the attribute that holds the wrapped sub-estimator."""
        pass

    def set_params(self, **params):
        """Set parameters, translating non-local names through the rules.

        Names that are parameters of this wrapper itself are passed through
        untouched.  Every other name is offered to each translation rule;
        each rule that matches contributes its translated names, all
        receiving that rule's translated value.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If a name is neither a local parameter nor matched by any rule.
        """
        local_params = self.get_params(deep=False)
        translated = {}
        # Materialise once: the rules are re-iterated for every parameter,
        # so a one-shot iterator would silently stop matching after the first.
        rules = list(self._get_translation_rules())
        for k, v in iteritems(params):
            if k in local_params:
                translated[k] = v
                continue
            matched = False
            for translate_key, translate_value in rules:
                translations = translate_key(k)
                if translations:
                    matched = True
                    trans_v = translate_value(v)
                    for trans_k in translations:
                        translated[trans_k] = trans_v
            if not matched:
                raise ValueError('Invalid parameter %s for estimator %s' %
                                 (k, self))
        super(BaseParameterTranslator, self).set_params(**translated)
        # Estimators are expected to return themselves from set_params so
        # that calls can be chained.
        return self

    def get_params(self, deep=True):
        """Return this wrapper's parameters, optionally with translated
        sub-estimator parameters.

        With ``deep=False`` only the wrapper's own parameters are returned.
        With ``deep=True`` each parameter of the sub-estimator is prefixed
        with ``<_subest_param>__``, run through the inverse translation
        rules, and added under every name a matching rule yields.
        """
        subest_param = self._subest_param
        result = super(BaseParameterTranslator, self).get_params(deep=False)
        if not deep:
            return result
        # Materialised for the same reason as in set_params.
        rules = list(self._get_inverse_translation_rules())
        for k, v in iteritems(getattr(self, subest_param).get_params(deep)):
            k = subest_param + '__' + k
            for translate_key, translate_value in rules:
                translations = translate_key(k)
                if translations:
                    trans_v = translate_value(v)
                    for trans_k in translations:
                        result[trans_k] = trans_v
        return result
def IDENTITY(x):
    """Return ``x`` unchanged.

    Serves as the value half of a translation rule when only the parameter
    name, not its value, needs rewriting.  The upper-case name is kept so
    existing references continue to resolve.
    """
    return x
class ParameterTie(BaseParameterTranslator):
    """Expose alias parameters that each drive several nested parameters.

    Parameters
    ----------
    _estimator : estimator
        The wrapped estimator whose parameters are being tied.
    _ties : dict
        Maps an alias name to the list of ``_estimator`` parameter names
        that should all receive the alias's value, e.g.
        ``{'percentile': ['a__percentile', 'b__percentile']}``.
    """

    def __init__(self, _estimator, _ties):
        self._estimator = _estimator
        self._ties = _ties

    # TODO: handle wildcard

    _subest_param = '_estimator'

    def _get_translation_rules(self):
        """Build the rules used when setting parameters.

        Rule 1 expands an alias into every tied target (prefixed so it is
        routed to the wrapped estimator).  Rule 2 forwards any non-alias
        name to the wrapped estimator under the same prefix.
        """
        prefix = self._subest_param + '__'

        def expand_alias(k):
            # Empty list (falsy) when k is not an alias -> rule is skipped.
            return [prefix + t for t in self._ties.get(k, [])]

        def forward_plain(k):
            # Aliases are handled by expand_alias only.
            return None if k in self._ties else [prefix + k]

        return [
            (expand_alias, IDENTITY),
            (forward_plain, IDENTITY),
        ]

    def _get_inverse_translation_rules(self):
        """Build the rules used when reading parameters back.

        Rule 1 maps a tied target to the alias(es) that drive it.  Rule 2
        strips the wrapped-estimator prefix from every other name.
        """
        prefix = self._subest_param + '__'
        rule_map = {}
        for alias, targets in iteritems(self._ties):
            for target in targets:
                # Accumulate rather than overwrite, so a target shared by
                # several aliases is reported under each of them.
                aliases = rule_map.setdefault(prefix + target, [])
                if alias not in aliases:
                    aliases.append(alias)

        def strip_prefix(k):
            return None if k in rule_map else [k[len(prefix):]]

        return [
            (rule_map.get, IDENTITY),
            (strip_prefix, IDENTITY),
        ]
if __name__ == '__main__':
    # Demo: tie the `percentile` parameter of two SelectPercentile branches
    # of a FeatureUnion to one alias on the wrapper.
    from .pipeline import FeatureUnion
    from .feature_selection import SelectPercentile

    branches = [
        ('a', SelectPercentile()),
        ('b', SelectPercentile()),
    ]
    fu = FeatureUnion(branches)
    ties = {'percentile': ['a__percentile', 'b__percentile']}
    tie = ParameterTie(fu, ties)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment