Created
May 25, 2022 16:37
-
-
Save zyd14/d7b1d0278270ed549977d7be7d42fb0a to your computer and use it in GitHub Desktop.
Dagster ECS Op
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from dataclasses import dataclass, asdict | |
import os | |
from time import sleep | |
from typing import List | |
import boto3 | |
from dagster import ( | |
op, | |
Field, | |
StringSource, | |
Array, | |
OpExecutionContext, | |
Failure, | |
graph, | |
MetadataEntry, | |
MetadataValue, | |
Shape, resource, | |
) | |
from dagster.core.errors import DagsterExecutionInterruptedError | |
from dagster.core.events import EngineEventData | |
from dagster_aws.s3 import s3_pickle_io_manager, s3_resource | |
class ConfigurationException(Exception):
    """Raised when op/task configuration is invalid, e.g. an unsupported Fargate cpu/memory pair."""
@resource(
    config_schema={
        "region_name": Field(
            str,
            is_required=False,
            default_value="us-east-1",
            description="AWS region for the ECS client. Defaults to 'us-east-1'.",
        )
    }
)
def ecs_client(init_context):
    """Dagster resource that returns a boto3 ECS client.

    The region comes from the optional ``region_name`` resource config and
    defaults to ``us-east-1``, so jobs that supply no resource config behave
    exactly as before.
    """
    return boto3.client("ecs", region_name=init_context.resource_config["region_name"])
class EcsTimeoutException(Exception):
    """Raised by EcsTask.poll_until_stopping when a task is still running past the polling timeout."""
@dataclass
class EnvironmentVar:
    """One ``name``/``value`` environment variable entry for an ECS container override."""

    # variable name as seen inside the container
    name: str
    # variable value
    value: str
@dataclass
class TaskOverrides:
    """Values for overriding ECS Task specifications.

    ``cpu`` and ``memory`` are validated on construction against the Fargate
    task-size table; an unsupported pair raises ConfigurationException.
    """

    name: str
    cpu: int
    memory: int
    command: List[str] = None
    # forward reference keeps this annotation from being evaluated at class creation
    environment: List["EnvironmentVar"] = None

    def __post_init__(self):
        if not self._valid_cpu_and_memory(str(self.cpu), str(self.memory)):
            raise ConfigurationException(
                "Invalid cpu/memory combination. ECS Fargate places constraints on cpu/memory combinations. For details see https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html"
            )

    def _valid_cpu_and_memory(self, cpu: str, memory: str) -> bool:
        """Return True if ``memory`` is a supported Fargate value for ``cpu``.

        Both arguments are string forms of the integer values; see
        https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html
        """
        # {cpu: [memory]}
        constraints = {
            "256": ["512", "1024", "2048"],
            "512": [str(i) for i in range(1024, 4096 + 1, 1024)],
            "1024": [str(i) for i in range(2048, 8192 + 1, 1024)],
            "2048": [str(i) for i in range(4096, 16384 + 1, 1024)],
            "4096": [str(i) for i in range(8192, 30720 + 1, 1024)],
            # 8 vCPU: 16 GB - 60 GB in 4 GB steps (Linux platform version 1.4.0+)
            "8192": [str(i) for i in range(16384, 61440 + 1, 4096)],
            # 16 vCPU: 32 GB - 120 GB in 8 GB steps (Linux platform version 1.4.0+)
            "16384": [str(i) for i in range(32768, 122880 + 1, 8192)],
        }
        return memory in constraints.get(cpu, [])

    def to_container_override(self) -> dict:
        """Return this override as a ``containerOverrides`` entry, omitting unset (None) fields.

        ``cpu`` and ``memory`` are emitted as ints.
        """
        # Build a fresh dict: popping keys while iterating the original dict
        # raises "RuntimeError: dictionary changed size during iteration".
        container_override = {}
        for key, value in asdict(self).items():
            if value is None:
                continue
            container_override[key] = int(value) if key in {"cpu", "memory"} else value
        return container_override

    def to_task_override(self) -> dict:
        """Return the task-level ``cpu``/``memory`` override, as strings."""
        return {"cpu": str(self.cpu), "memory": str(self.memory)}
@dataclass
class TaskState:
    """Simplified snapshot of an ECS task, as produced by EcsTask.get_task_state.

    ``state`` is "RUNNING" or "STOPPING"; the remaining fields are filled in
    from the describe_tasks response once the task is stopping.
    """

    state: str
    # exit code of the first container; None if not reported
    exit_code: int = None
    # container-level reason string
    reason: str = None
    # task-level stopCode
    stop_code: str = None
    # task-level stoppedReason
    stopped_reason: str = None

    def is_success(self):
        """True only when the reported exit code is exactly 0 (None counts as failure)."""
        return self.exit_code is not None and self.exit_code == 0
# ECS lastStatus values that EcsTask.get_task_state collapses into "STOPPING".
STOPPING_STATES = [
    "DEACTIVATING",
    "DEPROVISIONING",
    "STOPPING",
    "STOPPED",
]
class EcsTask:
    """Basic class for wrapping functionality around managing an individual ECS Task.

    Status values retrieved by EcsTask.get_task_state() are simplified to RUNNING
    or STOPPING and do not expose the various other states reported by ECS: any
    ``lastStatus`` in STOPPING_STATES, or a missing ``lastStatus``, maps to STOPPING.
    """

    def __init__(self, run_task_response: dict, ecs_client, cloudwatch_url: str):
        """Track the first task in ``run_task_response``.

        Args:
            run_task_response: Response returned by ``ecs.run_task``.
            ecs_client: boto3 ECS client used for describe/stop calls.
            cloudwatch_url: Console URL for the task's logs.
        """
        first_task = run_task_response["tasks"][0]
        self.ecs = ecs_client
        self.task_arn = first_task["taskArn"]
        self.cluster_arn = first_task["clusterArn"]
        self.cloudwatch_url = cloudwatch_url

    def get_task_state(self) -> "TaskState":
        """Describe the task and return its simplified TaskState.

        Raises:
            Exception: if describe_tasks returns no tasks for this ARN.
        """
        tasks = self.ecs.describe_tasks(
            tasks=[self.task_arn], cluster=self.cluster_arn
        ).get("tasks")
        if not tasks:
            raise Exception(f"no tasks found for {self.task_arn}")
        task = tasks[0]
        status = task.get("lastStatus")
        if status and status not in STOPPING_STATES:
            return TaskState("RUNNING")
        # Guard against a missing/empty containers list instead of raising IndexError.
        container = (task.get("containers") or [{}])[0]
        # NOTE(review): in early stopping states exitCode may not be reported yet,
        # which TaskState.is_success() treats as failure — confirm whether polling
        # should continue until STOPPED.
        return TaskState(
            "STOPPING",
            exit_code=container.get("exitCode"),
            reason=container.get("reason", ""),
            stop_code=task.get("stopCode"),
            stopped_reason=task.get("stoppedReason"),
        )

    def poll_until_stopping(self, polling_delay: float, timeout: float = None) -> "TaskState":
        """Poll every ``polling_delay`` seconds until the task leaves RUNNING.

        Args:
            polling_delay: Seconds to sleep between polls.
            timeout: Optional limit on total sleep time; falsy means no limit.

        Returns:
            The first non-RUNNING TaskState observed.

        Raises:
            EcsTimeoutException: if the task is still RUNNING after more than
                ``timeout`` seconds of polling.
        """
        task_state = self.get_task_state()
        elapsed = 0
        while task_state.state == "RUNNING":
            # Check the timeout only while still RUNNING, so a task that stopped
            # on the latest poll is returned rather than reported as timed out.
            if timeout and elapsed > timeout:
                raise EcsTimeoutException("ECS Task has exceeded polling timeout")
            sleep(polling_delay)
            elapsed += polling_delay
            task_state = self.get_task_state()
        return task_state

    def terminate(
        self,
        cluster: str = None,
    ):
        """Stop the task; ``cluster`` defaults to the cluster the task was launched on."""
        self.ecs.stop_task(
            cluster=cluster or self.cluster_arn,
            task=self.task_arn,
            reason="terminated by dagster",
        )
class EcsTaskRunner:
    """Launch Fargate tasks on ECS and build CloudWatch console links for them."""

    def __init__(
        self,
        ecs_client,
        log: logging.Logger,
        poll_delay: float = 15,
    ):
        self.ecs = ecs_client
        self.log = log
        self.poll_delay = poll_delay

    def run_task(
        self,
        task_definition: str,
        cluster: str,
        network_configuration: dict,
        dagster_run_id: str,
        overrides: "TaskOverrides" = None,
    ) -> "EcsTask":
        """Submit a Fargate task and return an EcsTask tracking it.

        When ``overrides`` is given, a DAGSTER_RUN_ID environment variable is
        added to the container override. The caller's ``overrides`` object is not
        modified.

        Raises:
            RuntimeError: if ECS accepted the request but started no task.
        """
        from dataclasses import replace

        run_task_request = dict(taskDefinition=task_definition, cluster=cluster)
        if overrides is not None:
            # Add DAGSTER_RUN_ID env var to ECS container. Build a new list rather
            # than appending to overrides.environment so repeated calls with the
            # same TaskOverrides don't accumulate duplicate entries.
            environment = list(overrides.environment or [])
            environment.append(EnvironmentVar(name="DAGSTER_RUN_ID", value=dagster_run_id))
            overrides = replace(overrides, environment=environment)
            run_task_request.update(
                overrides={
                    "containerOverrides": [overrides.to_container_override()],
                    **overrides.to_task_override(),
                }
            )
        run_task_request.update(
            networkConfiguration={"awsvpcConfiguration": network_configuration},
            launchType="FARGATE",
            startedBy="dagster",
            tags=[{"key": "dagster_run_id", "value": dagster_run_id}],
            propagateTags="TASK_DEFINITION",
        )
        self.log.info(f"Submitting ECS task with request: {run_task_request}")
        response = self.ecs.run_task(**run_task_request)
        tasks = response.get("tasks") or []
        if not tasks:
            # run_task reports launch problems in "failures" with an empty "tasks" list.
            raise RuntimeError(
                f"ECS run_task did not start a task; failures: {response.get('failures', [])}"
            )
        cw_url = self.get_cloudwatch_log_info(task_definition, tasks[0]["taskArn"])
        return EcsTask(
            run_task_response=response, ecs_client=self.ecs, cloudwatch_url=cw_url
        )

    def get_cloudwatch_log_info(
        self,
        task_definition: str,
        task_arn: str,
    ) -> str:
        """Return the CloudWatch console URL for this task's log stream.

        Uses the ``awslogs-*`` options and name of the task definition's first
        container definition; a missing key raises KeyError.
        """
        response = self.ecs.describe_task_definition(taskDefinition=task_definition)
        container_def = response["taskDefinition"]["containerDefinitions"][0]
        options = container_def["logConfiguration"]["options"]

        def encode(part: str) -> str:
            # The console URL fragment escapes "/" as "$252F".
            return part.replace("/", "$252F")

        region = options["awslogs-region"]
        log_group = encode(options["awslogs-group"])
        stream_prefix = encode(options["awslogs-stream-prefix"])
        container_name = container_def["name"]
        task_id = os.path.basename(task_arn)
        return (
            f"https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}"
            f"#logsV2:log-groups/log-group/{log_group}"
            f"/log-events/{stream_prefix}$252F{container_name}$252F{task_id}"
        )
# Config shape for one {"name": ..., "value": ...} environment variable entry.
EnvironmentVarConfObject = Shape(
    fields={
        "name": Field(str),
        "value": Field(str),
    }
)

# Passed to ecs.run_task as networkConfiguration["awsvpcConfiguration"].
NetworkConfigShape = Shape(
    fields={
        "subnets": Field(Array(str)),
        "assignPublicIp": Field(str, default_value="ENABLED"),
        "securityGroups": Field(Array(str)),
    }
)
# Config schema for run_ecs_task.
configurable_fields = {
    # Passed to run_task as taskDefinition and used to look up log settings.
    "task_definition": Field(
        StringSource,
        description="The task definition to use when launching tasks for this op",
    ),
    # Becomes the "name" of the container override.
    "container_name": Field(
        StringSource,
        is_required=False,
        default_value="run",
        description="The container name to use when launching new tasks. Defaults to 'run'.",
    ),
    # Fargate task size; TaskOverrides rejects unsupported cpu/memory pairs.
    "cpu": Field(int, is_required=False, default_value=1024),
    "memory": Field(int, is_required=False, default_value=8192),
    # NOTE(review): run_ecs_task takes `command` as an op input and never reads this field.
    "command": Field(Array(str), default_value=[]),
    # Merged over the op's `environment` input; config values win on conflicts.
    "environment": Field(
        dict,
        default_value={},
        description="Environment variables to inject into the ECS task",
    ),
    "cluster": Field(
        str,
        default_value="dagster",
        description="ECS Cluster to run task on",
    ),
    "network_configuration": Field(NetworkConfigShape),
    "polling_delay": Field(
        int,
        default_value=15,
        description="Seconds between calls to check status of ECS task",
    ),
}
def resolve_task_environment(
    run_config_env: dict, inputs_env: dict
) -> List[EnvironmentVar]:
    """Merge environment variables from op inputs and run config into EnvironmentVar entries.

    Run-config values take precedence over input values for the same key. Key
    order is input keys first, then any keys only present in run config.
    Neither argument is modified, and ``None`` is treated as an empty dict.

    Values are converted with ``str()`` because ECS environment overrides accept
    only string values and ``EnvironmentVar.value`` is declared ``str``.
    """
    # Merge into a new dict instead of calling inputs_env.update(), which
    # mutated the op's input object in place.
    merged = {**(inputs_env or {}), **(run_config_env or {})}
    return [EnvironmentVar(name=name, value=str(value)) for name, value in merged.items()]
@op(
    config_schema=configurable_fields,
    required_resource_keys={"ecs"},
    tags={"kind": "ECS"},
)
def run_ecs_task(context: OpExecutionContext, command: List[str], environment: dict):
    """Launch an ECS Fargate task, wait for it to stop, and return its task ARN.

    Args:
        command: Container command override.
        environment: Environment variables; run-config ``environment`` values
            override these.

    Raises:
        Failure: if polling times out, or if the task stops without a zero exit code.
    """
    cluster = context.op_config["cluster"]
    task_runner = EcsTaskRunner(log=context.log, ecs_client=context.resources.ecs)
    overrides = TaskOverrides(
        name=context.op_config["container_name"],
        cpu=context.op_config["cpu"],
        memory=context.op_config["memory"],
        command=command,
        environment=resolve_task_environment(context.op_config["environment"], environment),
    )
    task = task_runner.run_task(
        task_definition=context.op_config["task_definition"],
        cluster=cluster,
        overrides=overrides,
        network_configuration=context.op_config["network_configuration"],
        dagster_run_id=context.run_id,
    )
    # Surface the CloudWatch link and task ARN as an engine event.
    context.instance.report_engine_event(
        message="Launched ECS Task",
        pipeline_run=context.pipeline_run,
        engine_event_data=EngineEventData(
            [
                MetadataEntry(
                    "cloudwatch_url",
                    description="",
                    entry_data=MetadataValue.url(task.cloudwatch_url),
                ),
                MetadataEntry(
                    "task_arn",
                    description="",
                    entry_data=MetadataValue.text(task.task_arn),
                ),
            ]
        ),
    )
    try:
        # NOTE(review): no timeout is passed, so the EcsTimeoutException branch
        # below can only fire if a timeout is wired in here.
        task_state = task.poll_until_stopping(context.op_config["polling_delay"])
    except EcsTimeoutException as timeout_exc:
        task.terminate(cluster=cluster)
        raise Failure(description="ECS Task exceeded polling timeout") from timeout_exc
    except DagsterExecutionInterruptedError:
        # Stop the remote task so it doesn't outlive the interrupted run.
        task.terminate(cluster=cluster)
        raise

    if not task_state.is_success():
        context.log.error(
            f"ECS task has failed with stopped code {task_state.stop_code}, "
            f"reason: {task_state.stopped_reason}, exit code: {task_state.exit_code}, "
            f"container reason: {task_state.reason}"
        )
        raise Failure(
            description=f"ECS task {task.task_arn} failed with exit code {task_state.exit_code}"
        )
    context.log.info("ECS task has finished successfully")
    return task.task_arn
@graph
def ecs_testing():
    """Single-op graph that launches one ECS task via run_ecs_task.

    The op's ``command`` and ``environment`` inputs are not wired here; presumably
    they are supplied through run config — confirm against the job's config.
    """
    run_ecs_task()
# Job binding for ecs_testing: S3 pickle IO manager for outputs, plus the
# "ecs" resource that run_ecs_task declares in required_resource_keys.
ecs_testing_job = ecs_testing.to_job(
    name="ecs_testing",
    resource_defs={
        "io_manager": s3_pickle_io_manager,
        # presumably required by s3_pickle_io_manager — verify against dagster_aws
        "s3": s3_resource,
        "ecs": ecs_client,
    },
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment