Dagster ECS Op

import logging
import os
from dataclasses import asdict, dataclass
from time import sleep
from typing import List

import boto3
from dagster import (
    op,
    Field,
    StringSource,
    Array,
    OpExecutionContext,
    Failure,
    graph,
    MetadataEntry,
    MetadataValue,
    Shape,
    resource,
)
from dagster.core.errors import DagsterExecutionInterruptedError
from dagster.core.events import EngineEventData
from dagster_aws.s3 import s3_pickle_io_manager, s3_resource


class ConfigurationException(Exception):
    pass


class EcsTimeoutException(Exception):
    pass


@resource
def ecs_client(_):
    return boto3.client("ecs", region_name="us-east-1")


@dataclass
class EnvironmentVar:
    name: str
    value: str


@dataclass
class TaskOverrides:
    """Values for overriding ECS task specifications."""

    name: str
    cpu: int
    memory: int
    command: List[str] = None
    environment: List[EnvironmentVar] = None

    def __post_init__(self):
        if not self._valid_cpu_and_memory(str(self.cpu), str(self.memory)):
            raise ConfigurationException(
                "Invalid cpu/memory combination. ECS Fargate places constraints "
                "on cpu/memory combinations. For details see "
                "https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html"
            )

    def _valid_cpu_and_memory(self, cpu: str, memory: str):
        # Valid Fargate pairings, {cpu: [allowed memory values]}:
        # https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html
        constraints = {
            "256": ["512", "1024", "2048"],
            "512": [str(i) for i in range(1024, 4096 + 1, 1024)],
            "1024": [str(i) for i in range(2048, 8192 + 1, 1024)],
            "2048": [str(i) for i in range(4096, 16384 + 1, 1024)],
            "4096": [str(i) for i in range(8192, 30720 + 1, 1024)],
        }
        return memory in constraints.get(cpu, [])
    def to_container_override(self):
        container_override = asdict(self)
        # Iterate over a copy, since we mutate the dict while looping: unset
        # values are dropped, and cpu/memory are ints at the container level.
        for key, value in list(container_override.items()):
            if value is None:
                container_override.pop(key)
            elif key in {"cpu", "memory"}:
                container_override[key] = int(value)
        return container_override

    def to_task_override(self):
        # Task-level cpu/memory are strings in the ECS API.
        return {"cpu": str(self.cpu), "memory": str(self.memory)}


@dataclass
class TaskState:
    state: str
    exit_code: int = None
    reason: str = None
    stop_code: str = None
    stopped_reason: str = None

    def is_success(self):
        return self.exit_code == 0


STOPPING_STATES = ["DEACTIVATING", "DEPROVISIONING", "STOPPING", "STOPPED"]


class EcsTask:
    """Wraps management of an individual ECS task. Status values retrieved by
    EcsTask.get_task_state() are simplified to RUNNING or STOPPING and do not
    expose the various other states reported by ECS.
    """

    def __init__(self, run_task_response: dict, ecs_client, cloudwatch_url: str):
        self.ecs = ecs_client
        self.task_arn = run_task_response["tasks"][0]["taskArn"]
        self.cluster_arn = run_task_response["tasks"][0]["clusterArn"]
        self.cloudwatch_url = cloudwatch_url

    def get_task_state(self) -> TaskState:
        tasks = self.ecs.describe_tasks(
            tasks=[self.task_arn], cluster=self.cluster_arn
        ).get("tasks")
        if not tasks:
            raise Exception(f"No ECS tasks found for ARN {self.task_arn}")
        status = tasks[0].get("lastStatus")
        if status and status not in STOPPING_STATES:
            state = "RUNNING"
            exit_code = None
            reason = None
            stop_code = None
            stopped_reason = None
        else:
            state = "STOPPING"
            exit_code = tasks[0]["containers"][0].get("exitCode")
            reason = tasks[0]["containers"][0].get("reason", "")
            stop_code = tasks[0].get("stopCode")
            stopped_reason = tasks[0].get("stoppedReason")
        return TaskState(state, exit_code, reason, stop_code, stopped_reason)
    def poll_until_stopping(self, polling_delay: float, timeout: float = None) -> TaskState:
        """Block until the task leaves RUNNING, polling every `polling_delay`
        seconds. Raises EcsTimeoutException once `timeout` seconds of polling
        have elapsed."""
        task_state = self.get_task_state()
        total_time = 0
        while task_state.state == "RUNNING":
            sleep(polling_delay)
            total_time += polling_delay
            task_state = self.get_task_state()
            if timeout and total_time > timeout:
                raise EcsTimeoutException("ECS task has exceeded polling timeout")
        return task_state

    def terminate(self, cluster: str):
        self.ecs.stop_task(
            cluster=cluster, task=self.task_arn, reason="terminated by dagster"
        )
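

# Sketch of the polling contract (assumes `task` came from
# EcsTaskRunner.run_task below): poll_until_stopping blocks until the task
# leaves RUNNING, then the caller inspects the returned TaskState.
#
#   state = task.poll_until_stopping(polling_delay=15, timeout=3600)
#   if not state.is_success():
#       print(state.stop_code, state.stopped_reason)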


class EcsTaskRunner:
    def __init__(
        self,
        ecs_client,
        log: logging.Logger,
        poll_delay: float = 15,
    ):
        self.ecs = ecs_client
        self.log = log
        self.poll_delay = poll_delay

    def run_task(
        self,
        task_definition: str,
        cluster: str,
        network_configuration: dict,
        dagster_run_id: str,
        overrides: TaskOverrides = None,
    ) -> EcsTask:
        run_task_request = dict(taskDefinition=task_definition, cluster=cluster)
        if overrides:
            # Add DAGSTER_RUN_ID env var to the ECS container
            if overrides.environment:
                overrides.environment.append(
                    EnvironmentVar(name="DAGSTER_RUN_ID", value=dagster_run_id)
                )
            else:
                overrides.environment = [
                    EnvironmentVar(name="DAGSTER_RUN_ID", value=dagster_run_id)
                ]
            run_task_request.update(
                overrides={
                    "containerOverrides": [overrides.to_container_override()],
                    **overrides.to_task_override(),
                }
            )
        run_task_request.update(
            networkConfiguration={"awsvpcConfiguration": network_configuration},
            launchType="FARGATE",
            startedBy="dagster",
            tags=[{"key": "dagster_run_id", "value": dagster_run_id}],
            propagateTags="TASK_DEFINITION",
        )
        self.log.info(f"Submitting ECS task with request: {run_task_request}")
        response = self.ecs.run_task(**run_task_request)
        arn = response["tasks"][0]["taskArn"]
        cw_url = self.get_cloudwatch_log_info(task_definition, arn)
        return EcsTask(
            run_task_response=response, ecs_client=self.ecs, cloudwatch_url=cw_url
        )
    def get_cloudwatch_log_info(
        self,
        task_definition: str,
        task_arn: str,
    ) -> str:
        """Build a CloudWatch console URL for the task's awslogs log stream."""
        response = self.ecs.describe_task_definition(taskDefinition=task_definition)
        log_config = response["taskDefinition"]["containerDefinitions"][0][
            "logConfiguration"
        ]["options"]
        log_group = log_config["awslogs-group"]
        log_stream = log_config["awslogs-stream-prefix"]
        log_region = log_config["awslogs-region"]
        container_name = response["taskDefinition"]["containerDefinitions"][0]["name"]
        # The CloudWatch console fragment encodes "/" as "$252F": a
        # URL-encoded slash ("%2F") with its "%" escaped as "$25".
        return (
            f"https://{log_region}.console.aws.amazon.com/cloudwatch/home"
            f"?region={log_region}#logsV2:log-groups/log-group/"
            f"{log_group.replace('/', '$252F')}/log-events/"
            f"{log_stream.replace('/', '$252F')}$252F{container_name}"
            f"$252F{os.path.basename(task_arn)}"
        )


EnvironmentVarConfObject = Shape(fields={"name": Field(str), "value": Field(str)})

NetworkConfigShape = Shape(
    fields={
        "subnets": Field(Array(str)),
        "assignPublicIp": Field(str, default_value="ENABLED"),
        "securityGroups": Field(Array(str)),
    }
)

configurable_fields = {
    "task_definition": Field(
        StringSource,
        description="The task definition to use when launching tasks for this op",
    ),
    "container_name": Field(
        StringSource,
        is_required=False,
        default_value="run",
        description="The container name to use when launching new tasks. Defaults to 'run'.",
    ),
    "cpu": Field(int, is_required=False, default_value=1024),
    "memory": Field(int, is_required=False, default_value=8192),
    "command": Field(Array(str), default_value=[]),
    "environment": Field(
        dict,
        default_value={},
        description="Environment variables to inject into the ECS task",
    ),
    "cluster": Field(
        str, default_value="dagster", description="ECS cluster to run the task on"
    ),
    "network_configuration": Field(NetworkConfigShape),
    "polling_delay": Field(
        int,
        default_value=15,
        description="Seconds between calls to check the status of the ECS task",
    ),
}


def resolve_task_environment(
    run_config_env: dict, inputs_env: dict
) -> List[EnvironmentVar]:
    """Merge environment variables; run-config values take precedence over
    values passed in as op inputs."""
    merged = {**inputs_env, **run_config_env}
    return [EnvironmentVar(name=k, value=v) for k, v in merged.items()]
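

# Worked example of the precedence rule: run-config values win over op-input
# values for the same key.
#
#   resolve_task_environment({"A": "from_config"}, {"A": "from_input", "B": "x"})
#   -> [EnvironmentVar(name="A", value="from_config"), EnvironmentVar(name="B", value="x")]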


@op(
    config_schema=configurable_fields,
    required_resource_keys={"ecs"},
    tags={"kind": "ECS"},
)
def run_ecs_task(context: OpExecutionContext, command: List[str], environment: dict):
    """Launch an ECS task, report its CloudWatch log URL as engine-event
    metadata, and poll until the task stops, failing the op if the task did."""
    task_runner = EcsTaskRunner(log=context.log, ecs_client=context.resources.ecs)
    env_out = resolve_task_environment(context.op_config["environment"], environment)
    overrides = TaskOverrides(
        name=context.op_config["container_name"],
        cpu=context.op_config["cpu"],
        memory=context.op_config["memory"],
        command=command,
        environment=env_out,
    )
    task_definition = context.op_config["task_definition"]
    task = task_runner.run_task(
        task_definition=task_definition,
        cluster=context.op_config["cluster"],
        overrides=overrides,
        network_configuration=context.op_config["network_configuration"],
        dagster_run_id=context.run_id,
    )
    context.instance.report_engine_event(
        message="Launched ECS Task",
        pipeline_run=context.pipeline_run,
        engine_event_data=EngineEventData(
            [
                MetadataEntry(
                    "cloudwatch_url",
                    description="",
                    entry_data=MetadataValue.url(task.cloudwatch_url),
                ),
                MetadataEntry(
                    "task_arn",
                    description="",
                    entry_data=MetadataValue.text(task.task_arn),
                ),
            ]
        ),
    )
    try:
        task_state = task.poll_until_stopping(context.op_config["polling_delay"])
    except EcsTimeoutException as timeout_exc:
        task.terminate(cluster=context.op_config["cluster"])
        raise Failure(description="ECS task exceeded polling timeout") from timeout_exc
    except DagsterExecutionInterruptedError:
        # The Dagster run was interrupted (e.g. cancelled); stop the remote task too.
        task.terminate(cluster=context.op_config["cluster"])
        raise
    if task_state.is_success():
        context.log.info("ECS task has finished successfully")
    else:
        context.log.error(
            f"ECS task failed with stop code {task_state.stop_code}, "
            f"reason: {task_state.stopped_reason}"
        )
        raise Failure(description=f"ECS task {task.task_arn} failed")
    return task.task_arn


@graph
def ecs_testing():
    run_ecs_task()


ecs_testing_job = ecs_testing.to_job(
    name="ecs_testing",
    resource_defs={
        "io_manager": s3_pickle_io_manager,
        "s3": s3_resource,
        "ecs": ecs_client,
    },
)
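

if __name__ == "__main__":
    # Example invocation, a sketch only: the task definition, subnet, security
    # group, and S3 bucket below are hypothetical placeholders, and the op's
    # `command`/`environment` inputs are supplied via run_config because the
    # graph leaves them unconnected.
    ecs_testing_job.execute_in_process(
        run_config={
            "resources": {
                "io_manager": {"config": {"s3_bucket": "my-dagster-bucket"}},
            },
            "ops": {
                "run_ecs_task": {
                    "config": {
                        "task_definition": "my-task-definition",
                        "network_configuration": {
                            "subnets": ["subnet-0123456789abcdef0"],
                            "securityGroups": ["sg-0123456789abcdef0"],
                        },
                    },
                    "inputs": {
                        "command": ["echo", "hello"],
                        "environment": {"FOO": "bar"},
                    },
                }
            },
        }
    )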