Created October 11, 2020 23:42
-
-
Save nbertagnolli/ba49aa8833d3d749b76f1ae8b19a2cca to your computer and use it in GitHub Desktop.
Extracts feature names from an sklearn base model, transformer, etc.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def extract_feature_names(model, name) -> List[str]:
    """Extract the feature names produced by an arbitrary sklearn model.

    The model is inspected in this order, and the first match wins:

    1. ``get_feature_names()`` (pre-1.2 sklearn API).
    2. An integral ``n_clusters`` attribute (clustering estimators).
    3. An integral ``n_components`` attribute (decompositions/projections).
    4. A fitted ``components_`` array; its number of rows is used.
    5. ``get_feature_names_out()`` (sklearn >= 1.0 replacement for 1.).
    6. A fitted ``classes_`` attribute (classifiers).

    Args:
        model: The sklearn model, transformer, clustering algorithm, etc.
            which we want to get named features for.
        name: The name of the current step in the pipeline we are at. Used
            as a prefix when the model does not expose named features.

    Returns:
        The list of feature names. If the model does not have named features,
        names are constructed by appending an index to ``name``
        (``f"{name}_{i}"``); if nothing applies, ``[name]`` is returned.
    """
    import numbers

    def _indexed(count) -> List[str]:
        # Fallback naming for models without named outputs: name_0, name_1, ...
        return [f"{name}_{i}" for i in range(count)]

    if hasattr(model, "get_feature_names"):
        return model.get_feature_names()

    # n_clusters / n_components are constructor hyper-parameters and need not
    # be integers (e.g. PCA accepts None, a float variance ratio or "mle";
    # AgglomerativeClustering accepts n_clusters=None). range() would raise a
    # TypeError on those, so only trust integral values and fall through
    # otherwise (typically to the fitted components_ below).
    n_clusters = getattr(model, "n_clusters", None)
    if isinstance(n_clusters, numbers.Integral):
        return _indexed(n_clusters)

    n_components = getattr(model, "n_components", None)
    if isinstance(n_components, numbers.Integral):
        return _indexed(n_components)

    if hasattr(model, "components_"):
        # Fitted decomposition: one output feature per component row.
        return _indexed(model.components_.shape[0])

    if hasattr(model, "get_feature_names_out"):
        # Modern replacement for get_feature_names. It is checked after the
        # index-based heuristics so estimators matched above keep their
        # f"{name}_{i}" names regardless of the installed sklearn version.
        return list(model.get_feature_names_out())

    if hasattr(model, "classes_"):
        # Class labels may be ints/numpy scalars; stringify them to honour
        # the List[str] contract.
        return [str(label) for label in model.classes_]

    return [name]
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment