Skip to content

Instantly share code, notes, and snippets.

@zilto
Last active July 21, 2023 21:10
Show Gist options
  • Save zilto/520cfefecaa5fec9c81014db9efe46df to your computer and use it in GitHub Desktop.
Save zilto/520cfefecaa5fec9c81014db9efe46df to your computer and use it in GitHub Desktop.
How to use Hamilton within Prefect
from hamilton import base, driver
from prefect import flow, task
from prefect.blocks.system import JSON
# import modules containing your dataflow functions
import train_model
import evaluate_model
# Use @task to define a Prefect task; Prefect adds logging, retries, etc.
# The function parameters carry the config and inputs needed by Hamilton.
@task
def train_and_evaluate_model_task(
    features_path: str,
    hamilton_config: dict,
    label: str,
    feature_set: list[str],
    validation_user_ids: list[str],
) -> None:
    """Train and evaluate a machine learning model with a Hamilton dataflow.

    Builds a Hamilton ``Driver`` over the ``train_model`` and
    ``evaluate_model`` modules and executes it to compute the
    ``save_validation_preds`` and ``model_results`` nodes. The computed
    values are not returned.

    Args:
        features_path: Passed to the DAG as the ``features_path`` input.
        hamilton_config: Driver configuration mapping (``driver.Driver``
            expects a dict here, not a string).
        label: Passed to the DAG as the ``label`` input.
        feature_set: Passed to the DAG as the ``feature_set`` input.
        validation_user_ids: Passed to the DAG as the
            ``validation_user_ids`` input.
    """
    # Build the DAG from the config and the modules holding the dataflow
    # functions; DictResult makes execute() produce a dict of outputs.
    dr = driver.Driver(
        hamilton_config,
        train_model,
        evaluate_model,
        adapter=base.SimplePythonGraphAdapter(base.DictResult()),
    )
    # Execute only the nodes needed for the requested final_vars. The result
    # dict is discarded; presumably save_validation_preds persists its output
    # as a side effect -- confirm in evaluate_model.
    dr.execute(
        final_vars=["save_validation_preds", "model_results"],
        inputs=dict(
            features_path=features_path,
            label=label,
            feature_set=feature_set,
            validation_user_ids=validation_user_ids,
        ),
    )
# Use @flow to define the Prefect flow. The function parameters carry the
# config and inputs needed by all tasks, so no constants are hard-coded in
# the flow or task bodies.
@flow(
    name="hamilton-absenteeism-prediction",
    description="Predict absenteeism using Hamilton and Prefect",
)
def absenteeism_prediction_flow(
    features_path: str = ...,  # placeholder: replace with the real features path
    feature_set: list[str] = [
        "age_zero_mean_unit_variance",
        "has_children",
        "has_pet",
        "is_summer",
        "service_time",
    ],
    label: str = "absenteeism_time_in_hours",
    validation_user_ids: list[str] = [...],  # placeholder: replace with real ids
):
    """Predict absenteeism using Hamilton and Prefect.

    Loads the Hamilton Driver config from the Prefect JSON Block named
    ``hamilton-train-and-evaluate-config`` and runs the train/evaluate task
    with this flow's parameters.

    Args:
        features_path: Path forwarded to the task as ``features_path``.
        feature_set: Feature names forwarded to the task.
        label: Target column name forwarded to the task.
        validation_user_ids: User ids forwarded to the task.
    """
    # ... more tasks

    # Load the Prefect JSON Block that stores the Hamilton Driver config.
    hamilton_config_block = JSON.load("hamilton-train-and-evaluate-config")

    # A JSON Block exposes its stored payload via ``.value``; ``json.load``
    # would require a file-like object with ``.read()``, not a Block.
    # The list parameters are copied so the shared default lists in the
    # signature can never be mutated by downstream code.
    train_and_evaluate_model_task(
        features_path=features_path,
        hamilton_config=hamilton_config_block.value,
        label=label,
        feature_set=list(feature_set),
        validation_user_ids=list(validation_user_ids),
    )

    # ... more tasks
if __name__ == "__main__":
    # Run the flow locally with its default parameters when executed as a script.
    absenteeism_prediction_flow()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment