# Command to reproduce, after this file is copied to $AIRFLOW_HOME/dags as `test_airflow_retry.py`
# (assuming Airflow is installed and running from $HOME):
# `rm -fr ./airflow/logs && airflow db reset -y && airflow dags unpause test_airflow_retry && rm -fr $HOME/tfx && rm -f /tmp/test_airflow_retry* && airflow dags trigger test_airflow_retry -e '2019-02-01T01:01:01'`
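#
# How the retry is exercised (summary of the code below): `fail_every_other`
# writes a marker file under /tmp and fails on its first attempt; on the
# Airflow retry (retries=2 in default_args) the marker exists, so the task
# removes it, succeeds, and populates the Model artifact that `downstream`
# then reads and prints.
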
import datetime
import os
import threading
from absl import logging
import tfx
from tfx.v1.dsl.components import component
from tfx.v1.dsl import Pipeline
from tfx.v1.dsl.components import InputArtifact
from tfx.v1.dsl.components import OutputArtifact
from tfx.v1.types.standard_artifacts import Model
from tfx.orchestration.airflow.airflow_dag_runner import AirflowDagRunner
from tfx.orchestration.airflow.airflow_dag_runner import AirflowPipelineConfig
_pipeline_name = 'test_airflow_retry'
_tfx_root = os.path.join(os.environ['HOME'], 'tfx')
_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name)
# Sqlite ML-metadata db path.
_metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name,
                              'metadata.db')

counter = 0
counter_lock = threading.Lock()

@component
def fail_every_other(model: OutputArtifact[Model]):
  file = os.path.join('/tmp/', _pipeline_name + '_lock')
  if os.path.exists(file):
    logging.info("fail_every_other: marker file exists so removing it")
    os.remove(file)
  else:
    with open(file, 'w+') as fp:
      fp.write('this is a marker file')
    assert 0 == 1, "fail_every_other: failing and creating marker"
  logging.info("fail_every_other: populating model")
  with open(os.path.join(model.uri, 'model.txt'), 'w+') as f:
    f.write('fake model content from upstream')
  print("fail_every_other: model populated")

@component
def downstream(model: InputArtifact[Model]):
  logging.info("downstream: called")
  with open(os.path.join(model.uri, 'model.txt'), 'r') as f:
    print('downstream: %s' % f.readlines())

c1 = fail_every_other()
c2 = downstream(model=c1.outputs['model'])
pipeline = Pipeline(
    components=[c1, c2],
    pipeline_name=_pipeline_name,
    pipeline_root=_pipeline_root,
    metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
        _metadata_path),
)

DAG = AirflowDagRunner(
    config=AirflowPipelineConfig(
        airflow_dag_config={
            # 'dag_id': _pipeline_name,
            'default_args': {
                'retries': 2,
                'retry_delay': datetime.timedelta(seconds=5),
            },
            'schedule_interval': None,
            'start_date': datetime.datetime(2019, 1, 1),
        })
).run(pipeline)
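
# Expected behavior (inferred from the code above): the first attempt of the
# fail_every_other task creates the marker file and fails; Airflow retries it
# after ~5 seconds (retries=2, retry_delay in default_args), the retry finds
# the marker and succeeds, and downstream then prints the fake model content.
# Task logs end up under ./airflow/logs per the reproduce command at the top.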