import requests

from airflow import DAG
from airflow.contrib.sensors.gcs_sensor import GoogleCloudStoragePrefixSensor
from airflow.contrib.operators.gcs_to_bq import GoogleCloudStorageToBigQueryOperator
from airflow.exceptions import AirflowException
from airflow.hooks.http_hook import HttpHook
from airflow.operators.http_operator import SimpleHttpOperator
from airflow.operators.python_operator import PythonOperator
from airflow.utils.dates import days_ago

from datetime import timedelta, datetime
from google.cloud import automl_v1beta1 as automl

cloud_functions_url = 'https://asia-northeast1-inference-pipeline.cloudfunctions.net'
# Metadata-server endpoint that issues an identity token for the given audience
metadata_url = 'http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience='
project_id = 'inference-pipeline'
automl_tables_region = 'us-central1'
model_id = 'TBL0000000000000000000'  # placeholder AutoML Tables model ID

dag = DAG(
    'inference_pipeline',
    default_args={
        'start_date': days_ago(1),
        'retries': 1,
        'retry_delay': timedelta(minutes=5)
    },
    schedule_interval='@daily',
    dagrun_timeout=timedelta(minutes=60),
    catchup=False
)
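
# catchup=False keeps Airflow from backfilling runs between start_date and
# now; only the most recent schedule interval is executed.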

class RunCloudFunctionsOperator(SimpleHttpOperator):
    def execute(self, context):
        http = HttpHook(self.method, http_conn_id=self.http_conn_id)

        self.log.info("Calling HTTP method")

        # Fetch an identity token for the target Cloud Function from the
        # GCE metadata server, then call the function with it as a Bearer token.
        target_audience = cloud_functions_url + self.endpoint
        fetch_instance_id_token_url = metadata_url + target_audience
        r = requests.get(fetch_instance_id_token_url, headers={"Metadata-Flavor": "Google"}, verify=False)
        idt = r.text
        self.headers = {'Authorization': "Bearer " + idt}

        response = http.run(self.endpoint,
                            self.data,
                            self.headers,
                            self.extra_options)
        if self.response_check:
            if not self.response_check(response):
                raise AirflowException("Response check returned False.")
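
# A minimal standalone sketch of the identity-token flow above, assuming the
# code runs on GCE/Cloud Composer where the metadata server is reachable and
# the audience is the target Cloud Function URL:
#
#   audience = cloud_functions_url + '/preprocessing'
#   token = requests.get(metadata_url + audience,
#                        headers={'Metadata-Flavor': 'Google'}).text
#   requests.get(audience, headers={'Authorization': 'Bearer ' + token})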

class RunCloudFunctionsWithOptOperator(SimpleHttpOperator):
    def execute(self, context):
        http = HttpHook(self.method, http_conn_id=self.http_conn_id)

        self.log.info("Calling HTTP method")

        target_audience = cloud_functions_url + self.endpoint
        fetch_instance_id_token_url = metadata_url + target_audience
        r = requests.get(fetch_instance_id_token_url, headers={"Metadata-Flavor": "Google"}, verify=False)
        idt = r.text

        # Pull the directory holding the prediction result CSVs from XCom
        gcs_output_dir = context['ti'].xcom_pull(key='predicted output directory', task_ids='predict')
        endpoint = self.endpoint + f'?dir={gcs_output_dir}'

        self.headers = {'Authorization': "Bearer " + idt}
        response = http.run(endpoint,
                            self.data,
                            self.headers,
                            self.extra_options)
        if self.response_check:
            if not self.response_check(response):
                raise AirflowException("Response check returned False.")

csv_sensor = GoogleCloudStoragePrefixSensor(
    task_id='csv_sensor',
    bucket='test',
    prefix='data/{}-'.format(datetime.now().strftime('%Y%m%d')),
    timeout=60 * 60 * 24 * 2,
    pool='csv_sensor',
    dag=dag
)

preprocessing = RunCloudFunctionsOperator(
    task_id='preprocessing',
    method='GET',
    http_conn_id='http_default',
    endpoint='/preprocessing',
    headers={},
    xcom_push=False,
    response_check=lambda response: response.status_code == 200,
    dag=dag,
)

import_bq = GoogleCloudStorageToBigQueryOperator(
    task_id='import_bq',
    bucket='test',
    source_objects=['preprocess_data/*.csv'],
    source_format='CSV',
    allow_quoted_newlines=True,
    skip_leading_rows=1,
    destination_project_dataset_table='test.data',
    schema_fields=[
        {'name': 'id', 'type': 'INTEGER'},
    ],
    write_disposition='WRITE_TRUNCATE',
    dag=dag
)

postprocessing = RunCloudFunctionsWithOptOperator(
    task_id='postprocessing',
    method='GET',
    http_conn_id='http_default',
    endpoint='/postprocessing',
    headers={},
    xcom_push=False,
    response_check=lambda response: response.status_code == 200,
    dag=dag,
)

def do_deploy_model():
    client = automl.AutoMlClient()
    model_full_id = client.model_path(project_id, automl_tables_region, model_id)
    # deploy_model returns a long-running operation; result() blocks until done
    response = client.deploy_model(model_full_id)
    print(u'Model deployment finished. {}'.format(response.result()))

def do_predict(**kwargs):
    # Create the AutoML client
    client = automl.AutoMlClient()
    model_full_id = client.model_path(project_id, automl_tables_region, model_id)

    # Create the prediction client
    predict_client = automl.PredictionServiceClient()

    # Specify the BigQuery table used as the prediction input
    input_uri = 'bq://inference-pipeline.test.data'
    input_config = {"bigquery_source": {"input_uri": input_uri}}

    # Specify the Cloud Storage prefix for the prediction output
    output_uri = 'gs://test/predicted_data'
    output_config = {"gcs_destination": {"output_uri_prefix": output_uri}}

    # Run the batch prediction and wait for it to finish
    response = predict_client.batch_predict(model_full_id, input_config, output_config)
    response.result()
    result = response.metadata

    # Read the Cloud Storage directory the results were written to and push it to XCom
    gcs_output_dir = result.batch_predict_details.output_info.gcs_output_directory
    kwargs['ti'].xcom_push(key='predicted output directory', value=gcs_output_dir)
    print(u'Predict finished. {}'.format(response.result()))
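
# Note: batch_predict writes its results under a timestamped subdirectory of
# output_uri_prefix, which is why the actual gcs_output_directory is read from
# the operation metadata instead of being hard-coded.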

def do_delete_model():
    client = automl.AutoMlClient()
    model_full_id = client.model_path(project_id, automl_tables_region, model_id)
    # Note: this undeploys the model; it is not deleted
    response = client.undeploy_model(model_full_id)
    print(u'Model undeploy finished. {}'.format(response.result()))

deploy_model = PythonOperator(
    task_id='deploy_model',
    dag=dag,
    python_callable=do_deploy_model,
)

predict = PythonOperator(
    task_id='predict',
    dag=dag,
    provide_context=True,
    python_callable=do_predict,
)

delete_model = PythonOperator(
    task_id='delete_model',
    trigger_rule='all_done',
    dag=dag,
    python_callable=do_delete_model,
)

# Set the task dependencies
csv_sensor >> preprocessing >> import_bq >> deploy_model >> predict >> delete_model >> postprocessing
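
# To smoke-test an individual task locally with the Airflow 1.x CLI, e.g.:
#   airflow test inference_pipeline predict 2020-04-26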