Created March 11, 2022 15:09
-
-
Save krsnewwave/a91e3249bcd96d6e00c6d737e6595958 to your computer and use it in GitHub Desktop.
Kedro pipeline
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# in <root>/src/<project>/pipelines/data_science/pipeline.py
from kedro.pipeline import node, pipeline

from .nodes import fit_xgboost, plot_roc, split_data
def create_plot_roc_node():
    """Build the Kedro node that draws an ROC curve for the trained classifier.

    The node passes the ``clf``, ``X_test`` and ``y_test`` datasets to
    ``plot_roc`` and registers its return value as the ``roc_graph`` dataset.

    ``plot_roc`` must be imported at module level (from ``.nodes``); without
    that import this factory raises ``NameError`` when called.

    Returns:
        A Kedro node named ``plot_roc``.
    """
    return node(
        func=plot_roc,
        inputs=["clf", "X_test", "y_test"],
        outputs="roc_graph",
        name="plot_roc",
    )
def create_split_node():
    """Build the Kedro node that splits the model input table into train/test sets.

    ``split_data`` is called with the ``model_input_table`` dataset and the
    ``split_options`` parameter entry. The list of outputs maps its four
    return values, in order, to ``X_train``, ``X_test``, ``y_train`` and
    ``y_test``.

    Returns:
        A Kedro node named ``split_data_node``.
    """
    split_inputs = ["model_input_table", "params:split_options"]
    split_outputs = ["X_train", "X_test", "y_train", "y_test"]
    split_node = node(
        split_data,
        inputs=split_inputs,
        outputs=split_outputs,
        name="split_data_node",
    )
    return split_node
def create_xgb_pipeline(**kwargs):
    """Assemble the split -> train -> ROC-plot flow under the ``xgboost_pipe`` namespace.

    The nodes run in this order:

    1. ``split_data_node`` splits ``model_input_table`` into train/test sets.
    2. ``train_xgboost`` fits the model. ``fit_xgboost`` returns a dict whose
       ``clf`` and ``model_metrics`` keys become datasets of the same names.
    3. ``plot_roc`` renders the ROC graph from the fitted classifier.

    The unnamespaced node list is then wrapped with ``namespace="xgboost_pipe"``.
    ``model_input_table`` is declared as a pipeline input, so it keeps its
    global name. ``params:xgboost_params_full_feats`` is listed in
    ``parameters``, so it is also left unprefixed. Other free dataset names,
    such as ``clf`` and ``roc_graph``, are prefixed with ``xgboost_pipe.``.

    NOTE(review): ``params:split_options`` (used by the split node) is not
    listed in ``parameters``. On Kedro versions that namespace parameters
    (0.18+), it resolves to ``params:xgboost_pipe.split_options``. Confirm this
    is intended, or add it to the ``parameters`` set.

    Args:
        **kwargs: Accepted but ignored.

    Returns:
        The namespaced Kedro pipeline.
    """
    xgb_train_inputs = [
        "X_train",
        "y_train",
        "X_test",
        "y_test",
        "params:xgboost_params_full_feats",
    ]
    train_node = node(
        func=fit_xgboost,
        inputs=xgb_train_inputs,
        outputs={"clf": "clf", "model_metrics": "model_metrics"},
        name="train_xgboost",
    )

    # Build the plain (unnamespaced) pipeline first, then wrap it.
    stages = [create_split_node(), train_node, create_plot_roc_node()]
    base_pipeline = pipeline(stages)

    return pipeline(
        pipe=base_pipeline,
        inputs="model_input_table",
        namespace="xgboost_pipe",
        parameters={"params:xgboost_params_full_feats"},
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.