# Command to reproduce, after the file below is saved as `test_airflow_retry.py`
# and copied to $AIRFLOW_HOME/dags (assuming Airflow is installed and running
# from $HOME):
# `rm -fr ./airflow/logs && airflow db reset -y && airflow dags unpause test_airflow_retry && rm -fr $HOME/tfx && rm -f /tmp/test_airflow_retry* && airflow dags trigger test_airflow_retry -e '2019-02-01T01:01:01'`
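#
# What this demonstrates: `fail_every_other` below fails on its first attempt
# (creating a marker file under /tmp) and succeeds on the next (removing it),
# so with `retries: 2` in the DAG's default_args a single triggered run should
# fail once, pass on the automatic retry, and then let `downstream` run.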

import datetime
import os
import threading

from absl import logging
from tfx.orchestration import metadata
from tfx.orchestration.airflow.airflow_dag_runner import AirflowDagRunner
from tfx.orchestration.airflow.airflow_dag_runner import AirflowPipelineConfig
from tfx.v1.dsl import Pipeline
from tfx.v1.dsl.components import InputArtifact
from tfx.v1.dsl.components import OutputArtifact
from tfx.v1.dsl.components import component
from tfx.v1.types.standard_artifacts import Model

_pipeline_name = 'test_airflow_retry'
_tfx_root = os.path.join(os.environ['HOME'], 'tfx')
_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name)
# Sqlite ML-metadata db path.
_metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name,
                              'metadata.db')

counter = 0
counter_lock = threading.Lock()

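# Note: `counter` / `counter_lock` above are unused; the component keeps its
# "fail once, then succeed" state in a marker file instead, since in-memory
# state would not survive the separate process Airflow typically launches for
# each task attempt.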
@component
def fail_every_other(model: OutputArtifact[Model]):
  file = os.path.join('/tmp/', _pipeline_name + '_lock')
  if os.path.exists(file):
    logging.info('fail_every_other: marker file exists so removing it')
    os.remove(file)
  else:
    with open(file, 'w+') as fp:
      fp.write('this is a marker file')
    assert 0 == 1, 'fail_every_other: failing and creating marker'
  logging.info('fail_every_other: populating model')
  with open(os.path.join(model.uri, 'model.txt'), 'w+') as f:
    f.write('fake model content from upstream')
  print('fail_every_other: model populated')

@component
def downstream(model: InputArtifact[Model]):
  logging.info('downstream: called')
  with open(os.path.join(model.uri, 'model.txt'), 'r') as f:
    print('downstream: %s' % f.readlines())

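# Wire the two components into a pipeline: `downstream` consumes the Model
# channel produced by `fail_every_other`, which is what makes it wait for the
# upstream task (including its retries) to finish successfully.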
c1 = fail_every_other()
c2 = downstream(model=c1.outputs['model'])

pipeline = Pipeline(
    components=[c1, c2],
    pipeline_name=_pipeline_name,
    pipeline_root=_pipeline_root,
    metadata_connection_config=metadata.sqlite_metadata_connection_config(
        _metadata_path),
)

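# Compile the TFX pipeline into an Airflow DAG. The retry behavior under test
# comes from plain Airflow `default_args`: each task gets up to 2 retries,
# 5 seconds apart. `schedule_interval=None` means runs happen only when
# triggered manually (as in the repro command above).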
DAG = AirflowDagRunner(
    config=AirflowPipelineConfig(
        airflow_dag_config={
            # 'dag_id': _pipeline_name,
            'default_args': {
                'retries': 2,
                'retry_delay': datetime.timedelta(seconds=5),
            },
            'schedule_interval': None,
            'start_date': datetime.datetime(2019, 1, 1),
        })).run(pipeline)