-
-
Save zilto/520cfefecaa5fec9c81014db9efe46df to your computer and use it in GitHub Desktop.
How to use Hamilton within Prefect
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from hamilton import base, driver | |
from prefect import flow, task | |
from prefect.blocks.system import JSON | |
# import modules containing your dataflow functions | |
import train_model | |
import evaluate_model | |
# use @task to define Prefect tasks, which adds logging, retries, etc.
# the function parameters define the config and inputs needed by Hamilton
@task
def train_and_evaluate_model_task(
    features_path: str,
    hamilton_config: dict,
    label: str,
    feature_set: list[str],
    validation_user_ids: list[str],
) -> None:
    """Train and evaluate the machine learning model through a Hamilton Driver.

    Args:
        features_path: Forwarded to the dataflow as the `features_path` input.
        hamilton_config: Configuration mapping handed to `driver.Driver`.
        label: Forwarded to the dataflow as the `label` input.
        feature_set: Forwarded to the dataflow as the `feature_set` input.
        validation_user_ids: Forwarded to the dataflow as the
            `validation_user_ids` input.

    Returns:
        None. The values computed by `dr.execute` are not returned to the
        caller.
    """
    # define the Driver object with the configuration and dataflow modules
    dr = driver.Driver(
        hamilton_config,
        train_model,  # imported data transformation module
        evaluate_model,  # imported data transformation module
        adapter=base.SimplePythonGraphAdapter(base.DictResult()),
    )

    # execute the DAG to produce the requested `final_vars`
    dr.execute(
        final_vars=["save_validation_preds", "model_results"],
        inputs={
            "features_path": features_path,
            "label": label,
            "feature_set": feature_set,
            "validation_user_ids": validation_user_ids,
        },
    )
# use @flow to define the Prefect flow.
# the function parameters define the config and inputs needed by all tasks
# this way, we avoid hardcoding constants in the flow or task body
@flow(
    name="hamilton-absenteeism-prediction",
    description="Predict absenteeism using Hamilton and Prefect",
)
def absenteeism_prediction_flow(
    features_path: str = ...,  # placeholder default kept from the original example
    feature_set: list[str] = [
        "age_zero_mean_unit_variance",
        "has_children",
        "has_pet",
        "is_summer",
        "service_time",
    ],
    label: str = "absenteeism_time_in_hours",
    validation_user_ids: list[str] = [...],  # placeholder default kept from the original example
):
    """Predict absenteeism using Hamilton and Prefect.

    Loads the Hamilton Driver configuration from the Prefect JSON block named
    "hamilton-train-and-evaluate-config" and passes it, together with the flow
    parameters, to `train_and_evaluate_model_task`.

    Args:
        features_path: Passed through to the training task.
        feature_set: Passed through to the training task. The list default is
            only read here, never mutated.
        label: Passed through to the training task.
        validation_user_ids: Passed through to the training task.
    """
    # ... more tasks

    # load a Prefect Block containing the Hamilton Driver config
    hamilton_config_block = JSON.load("hamilton-train-and-evaluate-config")

    # call the Prefect task from the workflow; the block's stored payload is
    # read from its `value` attribute (the `json` module is not imported in
    # this file, and `json.load` expects a file object, not a block)
    train_and_evaluate_model_task(
        features_path=features_path,
        hamilton_config=hamilton_config_block.value,
        label=label,
        feature_set=feature_set,
        validation_user_ids=validation_user_ids,
    )

    # ... more tasks
# run the flow with its default parameters when executed as a script
if __name__ == "__main__":
    absenteeism_prediction_flow()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment