@pingsutw
Created July 26, 2023 20:04
Airflow example
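
# Weekly Dataproc training pipeline for the Hermes reserve-price model
# (Airflow 1.x, contrib Dataproc operators): gather data, build a dated
# training set, train a Spark model, export it as PMML, evaluate it, then
# compute evaluation metrics on a separate PySpark cluster and notify chat.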
from datetime import timedelta
from airflow import models
from airflow.contrib.operators import dataproc_operator
from airflow.operators.latest_only_operator import LatestOnlyOperator
from airflow.utils import trigger_rule
from chat_plugin import ChatNotifyOperator
from dataproc import hermes
from porch import settings as porch_settings, label
from porch import utils as porch_utils
app_name = "reserve_price_training"
dag_id = "dataproc_" + app_name
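
# Job jars resolved from the internal Maven repository; the hermes-reserve
# artifact version is pinned via hermes.RESERVE_PRICE_TRAIN_BUILD_VERSION.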
deployedLibraries = [
    porch_utils.maven_to_url(porch_settings.MAVEN_REPOSITORY, "com.porch.science", "pmml-spark-cli", "1.0.289169"),
    porch_utils.maven_to_url(porch_settings.MAVEN_REPOSITORY, "com.porch.science", "hermes-reserve", hermes.RESERVE_PRICE_TRAIN_BUILD_VERSION),
]
jvmNamespace = "com.porch.science.hermes.reserve"
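
# GCS locations; the training-set path is dated per run via Airflow's Jinja
# templating of dag_run.start_date.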
aggregateDataPath = "gs://porch-science/data/model/hermes/reserve_price/aggregate_v6"
trainingDataPath = "gs://porch-science/data/model/hermes/reserve_price/{{dag_run.start_date.strftime('%Y%m%d')}}"
modelDeployPath = "gs://porch-science/deployable/hermes/reserve-price"
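
# Cron fields are "minute hour day-of-month month day-of-week":
# run at 09:04 every Monday.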
cron_expression = "4 9 * * 1"
start_date = porch_utils.generate_start_time(cron_expression, dag_id)
default_dag_args = porch_settings.create_default_dag_args()
default_dag_args.update({
"sla": timedelta(hours=24),
"start_date": start_date,
"cluster_name": "dataproc-" + app_name.replace("_", "-")[-30:] + "-{{ds_nodash}}",
"email": porch_settings.DEFAULT_NOTIFY_EMAIL
})
with models.DAG(
        dag_id=dag_id,
        schedule_interval=cron_expression,
        catchup=False,
        max_active_runs=1,
        default_args=default_dag_args) as dag:
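
    # Skip all downstream work for anything but the most recent run,
    # so stale runs never retrain the model.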
    latest_only = LatestOnlyOperator(
        task_id="latest_only"
    )
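
    # Heavyweight cluster for the Scala/Spark training jobs; the templated
    # cluster_name in default_dag_args names it per run.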
    create_dataproc_cluster = dataproc_operator.DataprocClusterCreateOperator(
        task_id="create_dataproc_cluster",
        network_uri="https://www.googleapis.com/compute/v1/projects/porch-gcp/global/networks/porch",
        image_version="2.0.27-debian10",
        storage_bucket="porch-science-dataproc",
        service_account="ds-dataproc@porch-gcp.iam.gserviceaccount.com",
        service_account_scopes=["https://www.googleapis.com/auth/cloud-platform"],
        master_machine_type="n1-highmem-32",
        master_disk_size=1024,
        num_workers=4,
        worker_machine_type="n1-highmem-64",
        worker_disk_size=1024,
        labels={label.TEAM_OWNER_LABEL: label.TEAM_OWNER_VALUE,
                label.COST_CENTER_LABEL: label.MARKETPLACE_COST_CENTER_VALUE,
                label.COST_DETAIL_LABEL: "hermes_reserve_model"},
    )
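
    # Smaller cluster for the PySpark evaluation job; init actions install the
    # pinned pip packages, and the Spark BigQuery connector is pulled in via
    # spark.jars.packages.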
    create_dataproc_pyspark_cluster = dataproc_operator.DataprocClusterCreateOperator(
        task_id="create_dataproc_pyspark_cluster",
        network_uri="https://www.googleapis.com/compute/v1/projects/porch-gcp/global/networks/porch",
        image_version="2.0.27-debian10",
        storage_bucket="porch-science-dataproc",
        service_account="ds-dataproc@porch-gcp.iam.gserviceaccount.com",
        service_account_scopes=["https://www.googleapis.com/auth/cloud-platform"],
        master_machine_type="n1-highmem-16",
        master_disk_size=1024,
        num_workers=2,
        worker_machine_type="n1-highmem-16",
        worker_disk_size=1024,
        init_actions_uris=["gs://dataproc-initialization-actions/python/pip-install.sh",
                           "gs://porch-science-dataproc/dataproc-initializations/porch-pypi-install-hermes-eval.sh"],
        metadata={"PIP_PACKAGES": "six==1.13.0"},
        properties={"spark:spark.jars.packages": "com.google.cloud.spark:spark-bigquery-with-dependencies_2.12:0.17.3"},
        labels={"team": "datascience", "trigger": "airflow"},
    )
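
    # The Spark jobs below all run on the training cluster: the operators pick
    # up cluster_name from default_dag_args.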

    # Gather source data into the shared aggregate dataset.
    gather_data = dataproc_operator.DataProcSparkOperator(
        task_id="gather_data",
        dataproc_spark_jars=deployedLibraries,
        main_class=jvmNamespace + ".GatherData",
        arguments=[aggregateDataPath],
    )

    # Derive the dated training set from the aggregate data.
    build_train_data = dataproc_operator.DataProcSparkOperator(
        task_id="build_train_data",
        dataproc_spark_jars=deployedLibraries,
        main_class=jvmNamespace + ".TrainDataBuild",
        arguments=[aggregateDataPath, trainingDataPath],
    )

    # Fit the Spark model on the training set.
    train_spark_model = dataproc_operator.DataProcSparkOperator(
        task_id="train_spark_model",
        dataproc_spark_jars=deployedLibraries,
        main_class=jvmNamespace + ".TrainModel",
        arguments=[trainingDataPath],
    )

    # Export the trained model as PMML to the deploy path.
    convert_to_pmml = dataproc_operator.DataProcSparkOperator(
        task_id="convert_to_pmml",
        dataproc_spark_jars=deployedLibraries,
        main_class=jvmNamespace + ".ConvertToPmml",
        arguments=[trainingDataPath, "{{dag_run.start_date.strftime('%Y%m%d')}}", modelDeployPath],
    )

    # Evaluate the trained model and write the evaluation output.
    evaluate_spark_model = dataproc_operator.DataProcSparkOperator(
        task_id="evaluate_spark_model",
        dataproc_spark_jars=deployedLibraries,
        main_class=jvmNamespace + ".EvaluateModel",
        arguments=[trainingDataPath],
    )

    # PySpark job that turns the evaluation output into metrics; runs on the
    # second cluster.
    create_eval_metrics = dataproc_operator.DataProcPySparkOperator(
        task_id="create_eval_metrics",
        main="gs://porch-science-dataproc/hermes/source/reserve_eval.py",
        arguments=["--gcs-bucket", "porch-science",
                   "--gcs-path", "data/model/hermes/reserve_price/{{dag_run.start_date.strftime('%Y%m%d')}}"],
    )
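
    # Teardown factory: the ALL_DONE trigger rule deletes the cluster even if
    # an upstream job failed, so clusters are never left running.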
    def create_dataproc_delete_operator(task_id: str):
        return dataproc_operator.DataprocClusterDeleteOperator(
            task_id=task_id,
            retries=3,
            retry_delay=timedelta(minutes=5),
            email_on_failure=True,
            email=porch_settings.DEFAULT_NOTIFY_EMAIL,
            trigger_rule=trigger_rule.TriggerRule.ALL_DONE,
        )

    delete_dataproc_cluster = create_dataproc_delete_operator("delete_dataproc_cluster")
    delete_dataproc_pyspark_cluster = create_dataproc_delete_operator("delete_dataproc_pyspark_cluster")
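
    # Announce where the evaluation data landed once the pipeline finishes.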
    chat_notify = ChatNotifyOperator(
        task_id="chat_notify",
        verbosity="INFO",
        message='Evaluation data: {}'.format(trainingDataPath),
    )
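
    # Wiring: train on the big cluster and tear it down, then bring up the
    # PySpark cluster for metrics, tear that down, and finally notify chat.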
    latest_only >> create_dataproc_cluster
    create_dataproc_cluster >> gather_data >> build_train_data >> train_spark_model
    train_spark_model >> convert_to_pmml >> delete_dataproc_cluster
    train_spark_model >> evaluate_spark_model >> delete_dataproc_cluster
    delete_dataproc_cluster >> create_dataproc_pyspark_cluster >> create_eval_metrics
    create_eval_metrics >> delete_dataproc_pyspark_cluster >> chat_notify