-
-
Save zilto/18d0b7ac42099f9db9750607cf53d3ec to your computer and use it in GitHub Desktop.
Usage pattern for Hamilton within an Airflow DAG task
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime

from airflow.decorators import dag, task
from airflow.operators.python import get_current_context
# Default Airflow DAG parameters. These appear in the Airflow UI and are read
# at task run time via ``get_current_context()["params"]``.
DEFAULT_DAG_PARAMS = dict(
    label="absenteeism_time_in_hours",  # passed to Hamilton as the `label` input
    feature_set=[  # passed to Hamilton as the `feature_set` input
        "age_zero_mean_unit_variance",
        "has_children",
        "has_pet",
        "is_summer",
        "service_time",
    ],
    # Config for the Hamilton Driver. Left empty here as a placeholder;
    # populate it with whatever config the Hamilton modules require.
    h_train_and_evaluate=dict(),
)
@dag(
    dag_id="hamilton-absenteeism-prediction",
    description="Predict absenteeism using Hamilton and Airflow",
    start_date=datetime(2023, 6, 18),
    params=DEFAULT_DAG_PARAMS,  # pass the default params to the Airflow DAG
)
def absenteeism_prediction_dag():
    """Predict absenteeism using Hamilton and Airflow"""

    # Below we have a single Airflow task that uses 2 Python modules (evaluate_model, train_model).
    # Both are loaded into the Hamilton driver in a single Airflow task, reducing the number of
    # Airflow tasks and avoiding having to move data between the two steps. It remains
    # beneficial to separate the code into 2 modules since training and evaluation are
    # independent and might be reused in separate contexts.
    # NOTE(review): the task is only defined here; invoking it with a `features_path`
    # produced upstream is not shown in this snippet.
    @task
    def train_and_evaluate_model(features_path: str):
        """Train and evaluate a machine learning model with a Hamilton driver.

        Args:
            features_path: Passed to Hamilton as the ``features_path`` input
                (value retrieved from Airflow XCom).
        """
        # Imported inside the task so these modules are only loaded when the task
        # runs, not every time Airflow parses this DAG file.
        import evaluate_model  # user defined function module
        import train_model  # user defined function module
        from hamilton import base, driver

        context = get_current_context()
        params = context["params"]  # get the Airflow runtime config

        # This key must match the one declared in DEFAULT_DAG_PARAMS.
        hamilton_config = params["h_train_and_evaluate"]
        dr = driver.Driver(
            hamilton_config,
            train_model,
            evaluate_model,  # pass function modules to the Hamilton driver
            adapter=base.SimplePythonGraphAdapter(base.DictResult()),
        )
        # NOTE(review): `results` is neither used nor returned here.
        results = dr.execute(
            # `final_vars` specifies Hamilton functions results we want as outputs.
            final_vars=["save_validation_preds", "model_results"],
            inputs={
                "features_path": features_path,  # value retrieved from Airflow XCom
                "label": params["label"],
                "feature_set": params["feature_set"],
            },
        )
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.