Skip to content

Instantly share code, notes, and snippets.

@skrawcz
Created June 29, 2023 18:02
Show Gist options
  • Save skrawcz/21855b8284a32ee8ed506515af27265e to your computer and use it in GitHub Desktop.
Save skrawcz/21855b8284a32ee8ed506515af27265e to your computer and use it in GitHub Desktop.
Usage pattern for Hamilton within an Airflow DAG task
from datetime import datetime

from airflow.decorators import dag, task
from airflow.operators.python import get_current_context
# Set the Airflow DAG parameters. This will appear in the Airflow UI.
DEFAULT_DAG_PARAMS = dict(
    label="absenteeism_time_in_hours",
    feature_set=[
        "age_zero_mean_unit_variance",
        "has_children",
        "has_pet",
        "is_summer",
        "service_time",
    ],
    # Config for the Hamilton Driver. The earlier placeholder `dict(...)` handed
    # Ellipsis to dict(), which raises TypeError as soon as the module is
    # imported; an empty dict keeps the module loadable.
    # TODO: fill in the real Hamilton config values.
    h_train_and_evaluate=dict(),
)
@dag(
    dag_id="hamilton-absenteeism-prediction",
    description="Predict absenteeism using Hamilton and Airflow",
    start_date=datetime(2023, 6, 18),
    params=DEFAULT_DAG_PARAMS,  # pass the default params to the Airflow DAG
)
def absenteeism_prediction_dag():
    """Predict absenteeism using Hamilton and Airflow.

    Declares one Airflow task, ``train_and_evaluate_model``, which loads the
    ``train_model`` and ``evaluate_model`` modules into a single Hamilton
    driver and executes it.

    NOTE(review): in the visible code the task is only defined, never invoked
    inside this function, and this DAG function itself is never called; the
    wiring presumably lives outside this excerpt -- confirm.
    """

    # Below we have a single Airflow task that uses 2 Python modules
    # (evaluate_model, train_model). Both are loaded into the Hamilton driver
    # in a single Airflow task, reducing the number of Airflow tasks and
    # avoiding having to move data between the two steps. However, it remains
    # beneficial to separate the code into 2 modules since training and
    # evaluation are independent and might be reused in separate contexts.
    @task
    def train_and_evaluate_model(features_path: str):
        """Train and evaluate a machine learning model.

        Args:
            features_path: Forwarded unchanged to Hamilton as the
                ``features_path`` input.

        Raises:
            KeyError: If the runtime params lack ``h_train_and_evaluate``,
                ``label`` or ``feature_set``.
        """
        # Imported inside the task body rather than at module level.
        import evaluate_model  # user defined function module
        import train_model  # user defined function module
        from hamilton import base, driver

        context = get_current_context()
        params = context["params"]  # get the Airflow runtime config

        # The key has to match the one declared in DEFAULT_DAG_PARAMS
        # ("h_train_and_evaluate"); the previous lookup used
        # "h_train_and_evaluate_model", which is not defined there and would
        # raise KeyError with the default params.
        hamilton_config = params["h_train_and_evaluate"]

        dr = driver.Driver(
            hamilton_config,
            train_model,
            evaluate_model,  # pass function modules to the Hamilton driver
            adapter=base.SimplePythonGraphAdapter(base.DictResult()),
        )
        # NOTE(review): `results` is not used further in the visible code and
        # the task returns None -- confirm whether that is intended.
        results = dr.execute(
            # `final_vars` specifies Hamilton functions results we want as outputs.
            final_vars=["save_validation_preds", "model_results"],
            inputs={
                "features_path": features_path,  # value retrieved from Airflow XCom
                "label": params["label"],
                "feature_set": params["feature_set"],
            },
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment