@syossan27
Created April 26, 2020 09:04
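# Airflow DAG for a daily batch-inference pipeline on GCP:
# wait for a CSV in Cloud Storage, preprocess it via a Cloud Function, load the result
# into BigQuery, deploy an AutoML Tables model, run a batch prediction, undeploy the
# model, then postprocess the prediction output.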
import requests
from airflow import DAG
from airflow.contrib.sensors.gcs_sensor import GoogleCloudStoragePrefixSensor
from airflow.contrib.operators.gcs_to_bq import GoogleCloudStorageToBigQueryOperator
from airflow.exceptions import AirflowException
from airflow.hooks.http_hook import HttpHook
from airflow.operators.http_operator import SimpleHttpOperator
from airflow.operators.python_operator import PythonOperator
from airflow.utils.dates import days_ago
from datetime import timedelta, datetime
from google.cloud import automl_v1beta1 as automl
cloud_functions_url = 'https://asia-northeast1-inference-pipeline.cloudfunctions.net'
metadata_url = 'http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience='
project_id = 'inference-pipeline'
automl_tables_region = 'us-central1'
model_id = 'TBL0000000000000000000'
dag = DAG(
    'inference_pipeline',
    default_args={
        'start_date': days_ago(1),
        'retries': 1,
        'retry_delay': timedelta(minutes=5)
    },
    schedule_interval='@daily',
    dagrun_timeout=timedelta(minutes=60),
    catchup=False
)
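# SimpleHttpOperator subclasses that authenticate to Cloud Functions by fetching an
# identity token for the function's URL from the GCE metadata server.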
class RunCloudFunctionsOperator(SimpleHttpOperator):
    def execute(self, context):
        http = HttpHook(self.method, http_conn_id=self.http_conn_id)
        self.log.info("Calling HTTP method")
        # Fetch an identity token for the target Cloud Function from the metadata server
        target_audience = cloud_functions_url + self.endpoint
        fetch_instance_id_token_url = metadata_url + target_audience
        r = requests.get(fetch_instance_id_token_url, headers={"Metadata-Flavor": "Google"}, verify=False)
        idt = r.text
        self.headers = {'Authorization': "Bearer " + idt}
        response = http.run(self.endpoint,
                            self.data,
                            self.headers,
                            self.extra_options)
        if self.response_check:
            if not self.response_check(response):
                raise AirflowException("Response check returned False.")
class RunCloudFunctionsWithOptOperator(SimpleHttpOperator):
    def execute(self, context):
        http = HttpHook(self.method, http_conn_id=self.http_conn_id)
        self.log.info("Calling HTTP method")
        # Fetch an identity token for the target Cloud Function from the metadata server
        target_audience = cloud_functions_url + self.endpoint
        fetch_instance_id_token_url = metadata_url + target_audience
        r = requests.get(fetch_instance_id_token_url, headers={"Metadata-Flavor": "Google"}, verify=False)
        idt = r.text
        # Pull the directory holding the prediction CSVs from XCom
        gcs_output_dir = context['ti'].xcom_pull(key='predicted output directory', task_ids='predict')
        endpoint = self.endpoint + f'?dir={gcs_output_dir}'
        self.headers = {'Authorization': "Bearer " + idt}
        response = http.run(endpoint,
                            self.data,
                            self.headers,
                            self.extra_options)
        if self.response_check:
            if not self.response_check(response):
                raise AirflowException("Response check returned False.")
csv_sensor = GoogleCloudStoragePrefixSensor(
    task_id='csv_sensor',
    bucket='test',
    prefix='data/{}-'.format(datetime.now().strftime('%Y%m%d')),
    timeout=60 * 60 * 24 * 2,
    pool='csv_sensor',
    dag=dag
)
preprocessing = RunCloudFunctionsOperator(
    task_id='preprocessing',
    method='GET',
    http_conn_id='http_default',
    endpoint='/preprocessing',
    headers={},
    xcom_push=False,
    response_check=lambda response: response.status_code == 200,
    dag=dag,
)
import_bq = GoogleCloudStorageToBigQueryOperator(
    task_id='import_bq',
    bucket='test',
    source_objects=['preprocess_data/*.csv'],
    source_format='CSV',
    allow_quoted_newlines=True,
    skip_leading_rows=1,
    destination_project_dataset_table='test.data',
    schema_fields=[
        {'name': 'id', 'type': 'INTEGER'},
    ],
    write_disposition='WRITE_TRUNCATE',
    dag=dag
)
postprocessing = RunCloudFunctionsWithOptOperator(
    task_id='postprocessing',
    method='GET',
    http_conn_id='http_default',
    endpoint='/postprocessing',
    headers={},
    xcom_push=False,
    response_check=lambda response: response.status_code == 200,
    dag=dag,
)
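# Python callables for the AutoML Tables model: deploy, batch predict, undeploy.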
def do_deploy_model():
    client = automl.AutoMlClient()
    model_full_id = client.model_path(project_id, automl_tables_region, model_id)
    response = client.deploy_model(model_full_id)
    print(u'Model deployment finished. {}'.format(response.result()))
    return
def do_predict(**kwargs):
    # Create the AutoML client
    client = automl.AutoMlClient()
    model_full_id = client.model_path(project_id, automl_tables_region, model_id)
    # Create the prediction client
    predict_client = automl.PredictionServiceClient()
    # BigQuery table used as the prediction input
    input_uri = 'bq://inference-pipeline.test.data'
    input_config = {"bigquery_source": {"input_uri": input_uri}}
    # Cloud Storage location used for the prediction output
    output_uri = 'gs://test/predicted_data'
    output_config = {"gcs_destination": {"output_uri_prefix": output_uri}}
    # Run the batch prediction
    response = predict_client.batch_predict(model_full_id, input_config, output_config)
    response.result()
    result = response.metadata
    # Read the Cloud Storage directory the predictions were written to and push it to XCom
    gcs_output_dir = result.batch_predict_details.output_info.gcs_output_directory
    kwargs['ti'].xcom_push(key='predicted output directory', value=gcs_output_dir)
    print(u'Predict finished. {}'.format(response.result()))
    return
def do_delete_model():
    client = automl.AutoMlClient()
    model_full_id = client.model_path(project_id, automl_tables_region, model_id)
    response = client.undeploy_model(model_full_id)
    print(u'Model undeploy finished. {}'.format(response.result()))
    return
deploy_model = PythonOperator(
    task_id='deploy_model',
    dag=dag,
    python_callable=do_deploy_model,
)
predict = PythonOperator(
    task_id='predict',
    dag=dag,
    provide_context=True,
    python_callable=do_predict,
)
delete_model = PythonOperator(
    task_id='delete_model',
    trigger_rule='all_done',
    dag=dag,
    python_callable=do_delete_model,
)
# Set the task dependencies
csv_sensor >> preprocessing >> import_bq >> deploy_model >> predict >> delete_model >> postprocessing