Last active
October 11, 2020 23:52
-
-
Save nbertagnolli/3a1cb25129ddd3ef8e986b2dd31f1e43 to your computer and use it in GitHub Desktop.
Gets the feature names, in order, from an sklearn pipeline built from nested Pipelines and FeatureUnions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List

from sklearn.pipeline import FeatureUnion, Pipeline
def get_feature_names(model, names: List[str], name: str) -> List[str]:
    """Extract feature names, in order, from a composed sklearn Pipeline.

    Only ``Pipeline`` and ``FeatureUnion`` containers are traversed. The
    model is walked depth-first, and every step whose name appears in
    ``names`` contributes the names returned by ``extract_feature_names``
    (a helper defined outside this snippet). Steps that are neither a
    requested featurizer nor a supported container contribute nothing.

    Args:
        model: The estimator/transformer currently being inspected.
        names: Names of the final featurization steps to collect from.
        name: The name under which ``model`` appears in its parent
            container (the step name of the current node).

    Returns:
        The feature names gathered from the requested steps, in
        depth-first traversal order.
    """
    # Base case: the current step is one of the requested featurizers.
    if name in names:
        # If the step exposes ``named_steps`` it is itself a pipeline, and the
        # features are read from its sub-step registered under the same name.
        # NOTE(review): this assumes the nested featurizer shares the outer
        # step's name; otherwise the lookup raises KeyError — confirm intent.
        if hasattr(model, "named_steps"):
            return extract_feature_names(model.named_steps[name], name)
        return extract_feature_names(model, name)

    # isinstance (not an exact type check) so that subclasses of Pipeline /
    # FeatureUnion are traversed as well instead of silently yielding [].
    if isinstance(model, Pipeline):
        children = model.named_steps.items()
    elif isinstance(model, FeatureUnion):
        children = model.transformer_list
    else:
        # Neither a requested step nor a container: contributes no features.
        return []

    feature_names: List[str] = []
    # Distinct loop names avoid shadowing the ``name`` parameter.
    for child_name, child in children:
        feature_names.extend(get_feature_names(child, names, child_name))
    return feature_names
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment