Created
April 26, 2020 06:59
-
-
Save syossan27/9549d3efffda84b67faaac09be313f05 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests | |
from airflow import DAG | |
from airflow.contrib.sensors.gcs_sensor import GoogleCloudStoragePrefixSensor | |
from airflow.exceptions import AirflowException | |
from airflow.hooks.http_hook import HttpHook | |
from airflow.operators.http_operator import SimpleHttpOperator | |
from airflow.utils.dates import days_ago | |
from datetime import timedelta, datetime | |
# Base URL under which the Cloud Functions of this pipeline are served.
cloud_functions_url = 'https://asia-northeast1-inference-pipeline.cloudfunctions.net'
# Metadata-server endpoint that issues an identity token; the target audience
# is appended directly after the trailing "audience=" query parameter.
metadata_url = 'http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience='

# Daily inference pipeline: one retry after five minutes per task, each DAG run
# capped at one hour, and no backfilling of missed schedule intervals.
dag = DAG(
    dag_id='inference_pipeline',
    default_args=dict(
        start_date=days_ago(1),
        retries=1,
        retry_delay=timedelta(minutes=5),
    ),
    schedule_interval='@daily',
    dagrun_timeout=timedelta(hours=1),
    catchup=False,
)
class RunCloudFunctionsOperator(SimpleHttpOperator):
    """HTTP operator that calls a Cloud Function with an identity token.

    Overrides ``SimpleHttpOperator.execute`` so that, before the request is
    sent, an identity token for the target function is fetched from the
    metadata server and attached as a ``Bearer`` ``Authorization`` header.

    See:
    https://cloud.google.com/functions/docs/securing/authenticating#function-to-function
    https://cloud.google.com/compute/docs/instances/verifying-instance-identity
    """

    # Seconds to wait for the metadata server before failing the task instead
    # of blocking the worker indefinitely.
    metadata_request_timeout = 10

    def execute(self, context):
        """Fetch an identity token, call the endpoint, and check the response.

        :param context: Airflow task context (unused here).
        :raises AirflowException: if the identity token cannot be obtained or
            if ``response_check`` is set and returns a falsy value.
        """
        http = HttpHook(self.method, http_conn_id=self.http_conn_id)
        self.log.info("Calling HTTP method")

        # The token audience is the full URL of the function being invoked.
        target_audience = cloud_functions_url + self.endpoint
        token_response = requests.get(
            metadata_url + target_audience,
            headers={"Metadata-Flavor": "Google"},
            timeout=self.metadata_request_timeout,
        )
        # Fail loudly on a metadata-server error rather than forwarding an
        # error body as if it were a token.
        try:
            token_response.raise_for_status()
        except requests.exceptions.HTTPError as err:
            raise AirflowException(
                "Failed to fetch identity token from metadata server: {}".format(err)
            )
        id_token = token_response.text.strip()
        if not id_token:
            raise AirflowException("Metadata server returned an empty identity token.")

        # Keep any headers supplied to the operator and add the bearer token,
        # without mutating the operator's own ``headers`` attribute.
        request_headers = dict(self.headers or {})
        request_headers['Authorization'] = "Bearer " + id_token

        # Send the request to the Cloud Function.
        response = http.run(self.endpoint,
                            self.data,
                            request_headers,
                            self.extra_options)

        # Validate the response with the user-supplied callback, if any.
        if self.response_check and not self.response_check(response):
            raise AirflowException("Response check returned False.")
# Wait (up to two days) for an object whose name starts with
# "data/<YYYYMMDD>-" to appear in the bucket.
# The date is rendered from the run's templated ``next_ds_nodash`` macro rather
# than ``datetime.now()``: a value computed while the DAG file is parsed
# changes between parses and retries, so the sensor could end up watching a
# different day's prefix than the run it belongs to.
# NOTE(review): for a scheduled @daily run this is the date on which the run
# starts (UTC); confirm that matches the naming convention of the uploader.
csv_sensor = GoogleCloudStoragePrefixSensor(
    task_id='csv_sensor',
    bucket='test',
    prefix='data/{{ next_ds_nodash }}-',
    timeout=60 * 60 * 24 * 2,
    pool='csv_sensor',
    dag=dag
)

# Invoke the "/preprocessing" Cloud Function; any status code other than 200
# makes the response check fail, which fails the task.
preprocessing = RunCloudFunctionsOperator(
    task_id='preprocessing',
    method='GET',
    http_conn_id='http_default',
    endpoint='/preprocessing',
    response_check=lambda response: response.status_code == 200,
    dag=dag,
)

# Task dependencies: preprocessing runs only after the CSV has been detected.
csv_sensor >> preprocessing
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment