Skip to content

Instantly share code, notes, and snippets.

@gzzsound
Last active May 23, 2024 05:29
Show Gist options
  • Save gzzsound/8fc1bb1b2a6f854e6bce746ee424ac3a to your computer and use it in GitHub Desktop.
Save gzzsound/8fc1bb1b2a6f854e6bce746ee424ac3a to your computer and use it in GitHub Desktop.
A Python script that runs a task on an ECS cluster, waits for it to stop, and checks its exit code. Handy for database migrations.
#!/usr/bin/env python3
import logging
import sys
import time
import boto3
import click
# Configure the root logger once for the whole script: INFO and above,
# timestamped, written to stdout rather than logging's default stderr.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    stream=sys.stdout,
)
def get_task_definition(client, family_name):
    """Return the ARN of the newest task definition revision in a family.

    Lists the family's task definitions sorted ``DESC`` (newest revision
    first) and keeps only the first result.

    :param client: A boto3 ECS client (anything exposing
        ``list_task_definitions``).
    :param family_name: Task definition family name, sent as ``familyPrefix``.
    :returns: The task definition ARN of the newest listed revision.
    :raises Exception: If no task definition is listed for the family.
    """
    # ListTaskDefinitions (not DescribeTaskDefinition): sorted descending, the
    # newest revision is listed first, so a single result is enough.
    response = client.list_task_definitions(
        familyPrefix=family_name, sort="DESC", maxResults=1
    )
    arns = response.get("taskDefinitionArns", [])
    if not arns:
        raise Exception(f"No task definition found for family {family_name}")
    return arns[0]
def run_task(client, cluster, task_arn, private_subnet_ids, security_group_id):
    """Start one Fargate task from *task_arn* on *cluster*.

    :param client: A boto3 ECS client.
    :param cluster: Cluster name or ARN to run the task on.
    :param task_arn: Task definition ARN (or family[:revision]) to run.
    :param private_subnet_ids: List of subnet IDs for the task's awsvpc ENI.
    :param security_group_id: A single security group ID for the task.
    :returns: The raw ``run_task`` response; ``response["tasks"]`` is
        guaranteed to be non-empty.
    :raises Exception: If ECS started no task. ``run_task`` reports placement
        and capacity problems in ``response["failures"]`` rather than raising,
        so without this check callers would only fail later, indexing an
        empty ``tasks`` list.
    """
    response = client.run_task(
        cluster=cluster,
        taskDefinition=task_arn,
        launchType="FARGATE",
        networkConfiguration={
            "awsvpcConfiguration": {
                "subnets": private_subnet_ids,
                "securityGroups": [security_group_id],
            }
        },
    )
    if not response.get("tasks"):
        raise Exception(
            f"Failed to start task {task_arn} on cluster {cluster}: "
            f"{response.get('failures', [])}"
        )
    return response
def wait_until_task_finished(client, cluster, task, timeout, sleep_seconds=10):
    """Poll ECS until *task* reports ``lastStatus == "STOPPED"``.

    :param client: A boto3 ECS client.
    :param cluster: Cluster name or ARN the task runs on.
    :param task: Task ARN to poll.
    :param timeout: Approximate wait budget in seconds. The task is polled
        ``timeout // sleep_seconds`` times, but always at least once.
    :param sleep_seconds: Delay between polls in seconds; must be positive.
    :returns: ``True`` if the task stopped within the budget, else ``False``.
    :raises ValueError: If *sleep_seconds* is not positive.
    """
    if sleep_seconds <= 0:
        raise ValueError("sleep_seconds must be positive")
    # Poll at least once: a timeout shorter than one interval used to yield
    # zero iterations and report "not finished" without ever checking.
    attempts = max(1, int(timeout // sleep_seconds))
    for attempt in range(attempts):
        response = client.describe_tasks(cluster=cluster, tasks=[task])
        tasks = response.get("tasks", [])
        logging.info(f"Number of tasks: {len(tasks)} for task {task}")
        if tasks:
            task_status = tasks[0]["lastStatus"]
            logging.info(f"Current task status: {task_status}")
            # wait until the task is stopped
            if task_status == "STOPPED":
                return True
        else:
            # describe_tasks lists tasks it could not find under "failures";
            # keep polling instead of crashing on tasks[0].
            logging.warning(
                f"Task {task} not found: {response.get('failures', [])}"
            )
        # No point sleeping after the final poll.
        if attempt < attempts - 1:
            time.sleep(sleep_seconds)
    return False
def exit_code_from_task(client, cluster, task):
    """Return the exit code of the first container of *task*.

    :param client: A boto3 ECS client.
    :param cluster: Cluster name or ARN the task ran on.
    :param task: Task ARN to inspect.
    :returns: The ``exitCode`` reported for the task's first container.
    :raises Exception: If the task cannot be described, has no containers,
        or its first container has no ``exitCode``. ECS may omit the exit
        code when the container never ran. The message includes the task's
        ``stoppedReason`` and the container's ``reason`` so the cause is
        visible instead of a bare ``KeyError``.
    """
    response = client.describe_tasks(cluster=cluster, tasks=[task])
    tasks = response.get("tasks", [])
    logging.info(f"Number of tasks: {len(tasks)} for task {task}")
    if not tasks:
        raise Exception(f"Task {task} not found: {response.get('failures', [])}")
    description = tasks[0]
    containers = description.get("containers") or []
    if not containers:
        raise Exception(
            f"Task {task} has no containers: "
            f"stoppedReason={description.get('stoppedReason')!r}"
        )
    container = containers[0]
    exit_code = container.get("exitCode")
    if exit_code is None:
        raise Exception(
            f"Task {task} has no exit code for container {container.get('name')!r}: "
            f"stoppedReason={description.get('stoppedReason')!r}, "
            f"containerReason={container.get('reason')!r}"
        )
    return exit_code
@click.command()
@click.option("--cluster-name", required=True, help="ECS cluster name")
@click.option(
    "--task-definition-family-name",
    required=True,
    help="Task definition family name",
)
@click.option("--task-security-group", required=True, help="Task security group")
@click.option(
    "--task-private-subnets",
    required=True,
    help="Comma-separated task private subnets e.g. 'subnet-xx,subnet-yyy,subnet-zzz'",
)
@click.option("--timeout", default=1800, help="Timeout in seconds")
@click.option(
    "--region",
    default="eu-west-1",
    show_default=True,
    help="AWS region of the ECS cluster",
)
def main(
    cluster_name,
    task_definition_family_name,
    task_security_group,
    task_private_subnets,
    timeout,
    region,
):
    """Run a one-off Fargate task on an ECS cluster and wait for it to stop.

    Exits with status 1 if the task cannot be started, does not reach
    STOPPED within --timeout seconds, or its first container exits non-zero.
    """
    ecs_client = boto3.client("ecs", region_name=region)
    task_arn = get_task_definition(
        ecs_client, family_name=task_definition_family_name
    )
    logging.info(
        f"Running task {task_arn} on cluster {cluster_name}, security group {task_security_group}, private subnet {task_private_subnets}"
    )
    # Accept both "a,b,c" and the quoted/spaced "'a', 'b'" form by trimming
    # whitespace and surrounding quotes from every entry.
    split_subnets = [
        subnet.strip().strip("'\"").strip()
        for subnet in task_private_subnets.split(",")
    ]
    split_subnets = [subnet for subnet in split_subnets if subnet]
    if not split_subnets:
        raise click.BadParameter(
            "at least one subnet is required", param_hint="--task-private-subnets"
        )
    new_task = run_task(
        ecs_client, cluster_name, task_arn, split_subnets, task_security_group
    )
    # run_task reports placement problems in "failures" with no tasks.
    started = new_task.get("tasks") or []
    if not started:
        logging.error("Task was not started: %s", new_task.get("failures", []))
        sys.exit(1)
    arn = started[0]["taskArn"]
    finished = wait_until_task_finished(
        ecs_client,
        cluster=cluster_name,
        task=arn,
        timeout=timeout,
    )
    if not finished:
        logging.error("Did not stabilize ...")
        sys.exit(1)
    exit_code = exit_code_from_task(ecs_client, cluster_name, arn)
    logging.info("Exit code: %s", exit_code)
    if exit_code is None or int(exit_code) != 0:
        logging.error("Task failed, check logs what was the problem")
        sys.exit(1)
    logging.info("Done!")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment