#! /usr/bin/env python3
"""CLV acquisition predictions, triggered daily."""
from datetime import datetime, timedelta
from typing import Optional
from airflow import DAG
from helpers import k8s, mlops, mlops_factories, settings
from helpers.custom_operators import ModifiedKubernetesPodOperator, SnapshotStartDateOperator, UpdateConfOperator
from helpers.sensors.bigquery import BigQueryTableUpdatedSensor
from helpers.slack_ids import SlackID
##############################################################################
# DAG Setup
##############################################################################
DAG_ID = "clv_acquisition_predict"
PROJECT_NAME = "ds-cc-clv-acquisition"
# Capturing run_id based on start date instead of datetime.now().
# Note that the run_id in the UI will still be created using execution_date, and cannot be changed.
RUN_ID = mlops.run_id_template(settings)
RELEASE_TYPE = "prod" if settings.IS_PROD_ENVIRONMENT else "dev"
DOCKER_TAG = "prod" if settings.IS_PROD_ENVIRONMENT else "dev"
GA_DEBUG_MODE = False  # True for local development only
GA_TRACKING_ID = "UA-11111111-1" if settings.IS_PROD_ENVIRONMENT else "UA-11111111-2"
GCP_PROJECT = "prodproject" if settings.IS_PROD_ENVIRONMENT else "devproject"
OUTPUT_BQ_PROJECT = "prodproject" if settings.IS_PROD_ENVIRONMENT else "devproject"
GOOGLE_BUCKET = "gcsbucket"
DEV_SLACK_ID = SlackID.DEVELOPER
CHANNEL_ID = SlackID.AUTOBIDDER if settings.IS_PROD_ENVIRONMENT else DEV_SLACK_ID
SENSOR_INTERVAL = 60 * 5  # Seconds between sensor pokes
# Max expected total DAG run time excluding sensor waiting
MAX_TOTAL_TASK_RUN_TIME = 60 * 10
CONF = {
    "IMAGE": f"gcr.io/zorodataplatform/{PROJECT_NAME}:{DOCKER_TAG}",
    "NOTIFICATIONS": {
        "SUCCESS": {"SLACK_IDS": [DEV_SLACK_ID, CHANNEL_ID]},
        "FAILURE": {
            "MAINTAINER_SLACK_IDS": [DEV_SLACK_ID],
            "SLACK_IDS": [DEV_SLACK_ID],
        },
    },
    "RELEASE_TYPE": RELEASE_TYPE,
    "PREDICT_BATCH": {
        "CPU": "1",
        "ENV_VARS": {
            "GOOGLE_BUCKET": GOOGLE_BUCKET,
            "GOOGLE_PROJECT_BQ": GCP_PROJECT,
            "OUTPUT_BQ_PROJECT": OUTPUT_BQ_PROJECT,
            "RUN_ID": RUN_ID,
        },
        "IMAGE": f"gcr.io/zorodataplatform/{PROJECT_NAME}:{DOCKER_TAG}",
        "MEMORY": "2G",
        "RELEASE_TYPE": RELEASE_TYPE,
        "SCRIPT": "legacy_predict.py",  # "predict.py",
        # Typically takes 1-3 minutes
        "TIMEOUT": int(60 * 10),
    },
}
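# NOTE: CONF is handed to the UpdateConfOperator task below (replace=False,
# i.e. merged into the run conf rather than overwriting it); each task's
# ModifiedKubernetesPodOperator then reads its own namespace (e.g.
# PREDICT_BATCH) at render time via the `get_task_conf` macro.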
##############################################################################
# Helper Functions
##############################################################################
def add_task(
    dag: DAG,
    task_name: str,
    task_suffix: Optional[str] = None,
    task_suffix_sep: str = mlops.TASK_SUFFIX_SEP,
    task_retries: int = mlops.DEFAULT_TASK_RETRIES,
    default_task_timeout: int = mlops.DEFAULT_TASK_TIMEOUT,
    do_xcom_push: bool = False,
    trigger_rule: str = "all_success",
) -> None:
"""Factory function to add add a task to `dag`.
Args:
dag: DAG to add the task to.
task_name: Base name of the task in the DAG.
Will be the whole task name if no `task_suffix` is provided.
Used to determine the namespace from which to grab `conf`.
task_suffix: Suffix to be appended to `task_name` when naming the task, if desired.
Useful when repeating a task multiple times in the same dag.
Used to determine the sub-namespace from which to grab `conf`, if provided.
task_suffix_sep: Characters to separate `task_name` from `task_suffix` in the task name,
when `task_suffix` is provided.
task_retries: Number of task tries. Supercedes the DAG's default.
default_task_timeout: Default task timeout in seconds, if not provided via conf.
do_xcom_push: Whether or not to push xcom.
Only enable this if you're writing xcom to /airflow/xcom/return.json,
or a handshake error (see MLOPS-94) may result.
"""
    if task_suffix is not None and not isinstance(task_suffix, str):
        raise ValueError("`task_suffix` must be a string")
    user_defined_filters = {
        "get_secret_name": mlops.get_secret_name,
        "get_mlops_tolerations": k8s.get_mlops_tolerations,
    }
    user_defined_macros = {
        "get_task_conf": mlops.get_task_conf,
    }
    mlops.add_macros_and_filters(
        dag=dag,
        user_defined_filters=user_defined_filters,
        user_defined_macros=user_defined_macros,
    )
    full_task_name = (
        f"{task_name}{task_suffix_sep}{task_suffix}".lower()
        if task_suffix is not None
        else task_name.lower()
    )
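    # NOTE: `conf_str` below is embedded in "{{ ... }}" Jinja templates, so
    # `get_task_conf` runs at render time and picks up the per-run conf merged
    # by UpdateConfOperator; "{{{{" in an f-string is a literal "{{".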
    conf_str = f"get_task_conf(dag_run, {task_name!r}, {task_suffix!r})"
    task = ModifiedKubernetesPodOperator(  # noqa: F841
        task_id=full_task_name,
        trigger_rule=trigger_rule,
        name=f"mlops_{full_task_name}",
        image=f"{{{{ {conf_str}['IMAGE'] }}}}",
        namespace="default",
        cluster_name=mlops.CLUSTER_NAME,
        cluster_zone=mlops.CLUSTER_ZONE,
        image_pull_policy="Always",
        labels={
            "timeout_seconds": f"{{{{ {conf_str}['TIMEOUT'] | default('{default_task_timeout}', true) }}}}",
        },
        startup_timeout_seconds=600,
        env_vars=f"{{{{ {conf_str} }}}}",
        secrets=f'{{{{ {conf_str}.get("RELEASE_TYPE", "exp") | get_secret_name() }}}}',
        resources={
            "request_memory": f"{{{{ {conf_str}.get('MEMORY', '12G') }}}}",
            "request_cpu": f"{{{{ {conf_str}.get('CPU', '3') }}}}",
            "limit_memory": f"{{{{ {conf_str}.get('MEMORY', '12G') }}}}",
            "limit_cpu": f"{{{{ {conf_str}.get('CPU', '3') }}}}",
            "limit_gpu": f"{{{{ {conf_str}.get('GPU_COUNT') }}}}",
        },
        # Pushes the content of /airflow/xcom/return.json from the container to an XCom when the container ends.
        do_xcom_push=do_xcom_push,
        node_selectors={"zdp/purpose": "mlops"},
        tolerations=f"{{{{ {conf_str}.get('GPU_TYPE') | get_mlops_tolerations }}}}",
        # `params` let the operator overwrite execution_timeout with conf["TIMEOUT"], if provided
        execution_timeout=timedelta(seconds=default_task_timeout),
        params={
            "TIMEOUT_CONF": "TIMEOUT",
            "CONF_STR": conf_str,
        },
        retries=task_retries,
        retry_delay=timedelta(seconds=300),
        retry_exponential_backoff=True,
        on_failure_callback=mlops.send_failure_message,
        dag=dag,
    )
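

# Illustrative call (hypothetical; this DAG uses the
# mlops_factories.add_predict_batch_task wrapper below instead):
#   add_task(dag, task_name="PREDICT_BATCH")
# would register a "predict_batch" task configured from the PREDICT_BATCH
# namespace of the merged run conf.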
##############################################################################
# DAG
##############################################################################
with DAG(
    dag_id=DAG_ID,
    schedule_interval="0 17 * * *",  # Daily at 17:00 UTC; the exact time is relatively arbitrary
    max_active_runs=1,
    catchup=False,
    default_args=mlops.get_dag_default_args(),
) as dag:
    t_snapshot_start_date = SnapshotStartDateOperator(task_id="snapshot_start_date")
    t_update_conf = UpdateConfOperator(
        task_id="update_conf", given_conf=CONF, replace=False
    )
    # TODO: Replace with a Deferrable Operator once available in Composer
    t_bq_sensor_s_customer_acq = BigQueryTableUpdatedSensor(
        task_id="bq_sensor_s_customer_acq",
        project_id=GCP_PROJECT,
        dataset_id="some_dataset",
        table_id="s_customer_acquisition",
        comparison_time='{{ dag_run.conf["_zoro_mlops"]["first_attempt_start_date"] }}',
        poke_interval=SENSOR_INTERVAL,
        # Must time out before the next schedule interval (daily) to avoid a potential deadlock.
        # The timeout applies to a single task try (i.e. timeout * retries = total time).
        timeout=60 * 60 * 4 - (SENSOR_INTERVAL + MAX_TOTAL_TASK_RUN_TIME),
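        # = 14400 - (300 + 600) = 13500 seconds (3 h 45 m) with the values above.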
        # Reschedule mode frees up a worker slot between checks.
        mode="reschedule",
        retries=0,
    )
    partition_date = (
        datetime.now(mlops.CHICAGO_TIMEZONE) - timedelta(days=1)
    ).strftime("%Y%m%d")  # e.g. "20220620"
    # This task will always run, even if there's no valid input data.
    mlops_factories.add_predict_batch_task(dag)
    (
        t_snapshot_start_date
        >> t_update_conf
        >> t_bq_sensor_s_customer_acq
        >> dag.get_task("predict_batch")
    )