@DanielDaCosta
Created December 20, 2020 18:45
import copy
from datetime import datetime, timedelta

from airflow import DAG
from airflow.contrib.operators.ecs_operator import ECSOperator
from airflow.models import Variable
# Airflow Variables
awsRegionName = Variable.get('AwsRegionName')
awsCluster = Variable.get('AwsCluster')
awsTaskDefinition = Variable.get('AwsTaskDefinition')
awsNetworkSubnet = Variable.get('AwsNetworkSubnet')
awsContainerName = Variable.get('AwsContainerName')
AIRFLOW_ECS_OPERATOR_RETRIES = 2
# Default arguments applied to every task in the DAG
default_args = {
    'owner': 'ml-pipeline',
    'depends_on_past': False,
    'retries': 0,
    'start_date': datetime(2020, 12, 13)
}
# DAG base information
dag = DAG(
    dag_id='ml-pipeline',
    default_args=default_args,
    schedule_interval=None,  # no schedule: the DAG is triggered manually
)
# ECS operator arguments, kept as a template so each task gets its own copy
ecs_operator_args_template = {
    'aws_conn_id': 'aws_default',
    'region_name': awsRegionName,
    'launch_type': 'FARGATE',
    'cluster': awsCluster,
    'task_definition': awsTaskDefinition,
    'network_configuration': {
        'awsvpcConfiguration': {
            'assignPublicIp': 'ENABLED',
            'subnets': [awsNetworkSubnet]
        }
    },
    # CloudWatch Logs location configured in the ECS task definition
    'awslogs_group': '/ecs/' + awsTaskDefinition,
    'awslogs_stream_prefix': 'ecs/' + awsContainerName,
    'overrides': {
        'containerOverrides': [
            {
                'name': awsContainerName,
                'memoryReservation': 500,
            },
        ],
    },
}
# Deep-copy the template so per-task overrides never mutate the shared dictionary
ecs_operator_args = copy.deepcopy(ecs_operator_args_template)

# Run the ML container as a one-off Fargate task, retrying transient ECS failures
ecs_operator = ECSOperator(
    task_id='run_ml',
    dag=dag,
    retries=AIRFLOW_ECS_OPERATOR_RETRIES,
    retry_delay=timedelta(seconds=10),
    **ecs_operator_args
)
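
The DAG reads five Airflow Variables (AwsRegionName, AwsCluster, AwsTaskDefinition, AwsNetworkSubnet, AwsContainerName), so they must exist before the file is parsed. A minimal setup sketch, with placeholder values that are not part of the original gist; the same variables can be created from the Airflow UI or the airflow variables CLI:

from airflow.models import Variable

# Placeholder values; replace with your own AWS resources
Variable.set('AwsRegionName', 'us-east-1')
Variable.set('AwsCluster', 'ml-cluster')
Variable.set('AwsTaskDefinition', 'ml-task')
Variable.set('AwsNetworkSubnet', 'subnet-0123456789abcdef0')
Variable.set('AwsContainerName', 'ml-container')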