Skip to content

Instantly share code, notes, and snippets.

@tommyblue
Created December 6, 2019 15:47

Revisions

  1. tommyblue created this gist Dec 6, 2019.
    159 changes: 159 additions & 0 deletions replace_instances.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,159 @@
    #!/usr/bin/env python3
    """
    Replaces all instances of a cluster in a region.
    This only works if all instances are managed by an autoscaling group.
    """
    import argparse
    import functools
    import logging
    import os
    import time

    import boto3

    ### START CONFIGURATIONS
    # Module-level settings read by the functions below. All three are empty by
    # default and must be filled in before the script is run.
    CLUSTER = "" # Name of the cluster (passed as `cluster=` to the ECS calls)
    REGION = "" # The region where the cluster is running (passed to boto3.client)
    ASG_NAME = "" # Name of the autoscaling group whose desired capacity is changed
    ### END CONFIGURATIONS

    def main():
        """Replace all instances of the cluster by temporarily doubling the ASG.

        Sequence:
          1. read the desired capacity and abort if it differs from the number
             of registered container instances;
          2. mark the current container instances as DRAINING;
          3. double the desired capacity and wait for the new instances;
          4. wait until the draining instances run no tasks;
          5. restore the original desired capacity and wait for the instance
             count to return to it.

        Raises:
            SystemExit: with code 1 when the precondition fails or when waiting
                for the instance count times out.
        """
        # The root logger defaults to WARNING, so without this none of the
        # progress messages below would ever be shown.
        logging.basicConfig(level=logging.INFO)

        ecs_client = boto3.client("ecs", REGION)
        asg_client = boto3.client("autoscaling", REGION)

        logging.info("Finding the number of desired instances in the autoscaling group")
        # Read once and reuse, so the precondition check and the doubling below
        # are guaranteed to work on the same value.
        desired_instances = _get_desired_capacity(asg_client)

        # Can't replace if desired instances is different from running
        if desired_instances != _get_running_instances():
            logging.critical("Can't replace instances if autoscaling activity is ongoing")
            # SystemExit (unlike os._exit) lets the interpreter flush buffers
            # and run cleanup handlers before terminating.
            raise SystemExit(1)

        logging.info("Set the running instances status as 'DRAINING'")
        _set_running_instances_as_draining(ecs_client)

        logging.info("Modifying the autoscaling group doubling the desired instances")
        _set_desired_capacity(asg_client, desired_instances * 2)

        logging.info("Waiting for the new instances to be launched")
        if not _wait_instances(desired_instances * 2):
            raise SystemExit(1)

        logging.info("Waiting all tasks in the draining instances to be stopped")
        _wait_draining_instances_are_empty(ecs_client)

        logging.info("Bringing back the desired count in the asg to its initial value")
        _set_desired_capacity(asg_client, desired_instances)

        logging.info("Waiting for the drained instances to be shutdown")
        if not _wait_instances(desired_instances):
            raise SystemExit(1)


    def _get_desired_capacity(asg_client) -> int:
        """Return the DesiredCapacity of the autoscaling group named ASG_NAME.

        Args:
            asg_client: boto3 "autoscaling" client.

        Returns:
            The "DesiredCapacity" value of the single matching group.

        Raises:
            SystemExit: with code 1 when the lookup does not return exactly one
                group (none found, or more than one).
        """
        resp = asg_client.describe_auto_scaling_groups(AutoScalingGroupNames=[ASG_NAME])
        groups = resp['AutoScalingGroups']
        if len(groups) != 1:
            # The count can also be zero (e.g. wrong or empty ASG_NAME), so the
            # message reports the actual number instead of claiming "too many".
            logging.critical(
                "Expected exactly one ASG named %r, found %d: %s",
                ASG_NAME, len(groups), resp)
            raise SystemExit(1)
        return groups[0]["DesiredCapacity"]


    def _set_running_instances_as_draining(ecs_client):
        """Put every container instance registered in CLUSTER into DRAINING state.

        Args:
            ecs_client: boto3 "ecs" client.

        Does nothing when the cluster has no registered container instances.
        """
        # list_container_instances is paginated; walk every page so clusters
        # larger than one page are fully covered.
        paginator = ecs_client.get_paginator("list_container_instances")
        all_instances = [
            arn
            for page in paginator.paginate(cluster=CLUSTER)
            for arn in page['containerInstanceArns']
        ]

        # update_container_instances_state is documented as accepting up to 10
        # container instances per call, so send the ARNs in batches. An empty
        # list results in no call at all.
        batch_size = 10
        for start in range(0, len(all_instances), batch_size):
            ecs_client.update_container_instances_state(
                cluster=CLUSTER,
                containerInstances=all_instances[start:start + batch_size],
                status='DRAINING'
            )


    def _set_desired_capacity(asg_client, desired_instances):
        """Set the desired capacity of the autoscaling group ASG_NAME.

        Args:
            asg_client: boto3 "autoscaling" client.
            desired_instances: value sent as DesiredCapacity.

        The request is sent with HonorCooldown=False.
        """
        request = {
            "AutoScalingGroupName": ASG_NAME,
            "DesiredCapacity": desired_instances,
            "HonorCooldown": False,
        }
        asg_client.set_desired_capacity(**request)


    @with_sleep(sleep_time=30, max_attempts=20)
    def _wait_instances(desired_instances):
        """Check whether the cluster has exactly `desired_instances` container instances.

        Returns True (after logging "Done!") when the count matches and None
        otherwise; the `with_sleep` decorator keeps polling while the result
        is None.
        """
        if _get_running_instances() != desired_instances:
            # Count not reached yet: report "no result" for this attempt.
            return None
        logging.info("Done!")
        return True

    def _get_running_instances() -> int:
        """Return how many container instances `_describe_container_instances` reports."""
        described = _describe_container_instances()['containerInstances']
        return len(described)


    @with_sleep(sleep_time=30, max_attempts=20)
    def _wait_draining_instances_are_empty(ecs_client):
        """Check whether the DRAINING instances have no tasks left.

        Returns True when the task counts of all DRAINING instances add up to
        zero and None otherwise; the `with_sleep` decorator keeps polling while
        the result is None.
        """
        counts = _get_tasks_per_instance(ecs_client, status=["DRAINING"])
        remaining = sum(counts.values())
        return True if remaining == 0 else None


    def _get_tasks_per_instance(ecs_client, status=None):
        """List the tasks of CLUSTER and return their count per instance.

        Args:
            ecs_client: boto3 "ecs" client.
            status: forwarded unchanged to `_tasks_per_instance`.

        Returns:
            Whatever `_tasks_per_instance` returns for the listed task ARNs.
        """
        task_arns = ecs_client.list_tasks(cluster=CLUSTER)['taskArns']
        return _tasks_per_instance(ecs_client, task_arns, status=status)


    def _describe_container_instances():
        """Describe every container instance listed for CLUSTER.

        Creates its own ECS client for REGION, lists the container instance
        ARNs of the cluster and returns the describe_container_instances
        response for them.

        Returns:
            A dict with a 'containerInstances' list. When the cluster lists no
            container instances, an equivalent response with an empty list is
            returned without calling the describe API.
        """
        ecs_client = boto3.client("ecs", REGION)

        containers_response = ecs_client.list_container_instances(cluster=CLUSTER)
        arns = containers_response['containerInstanceArns']

        if not arns:
            # NOTE(review): describe_container_instances is expected to reject
            # an empty list, so short-circuit with the same response shape.
            return {'containerInstances': []}

        # Call the ECS client here. The previous code called this function
        # itself with keyword arguments it does not accept.
        return ecs_client.describe_container_instances(
            cluster=CLUSTER, containerInstances=arns)

    def _tasks_per_instance(ecs_client, tasks_list: list, status=None) -> dict:
        """Count the tasks placed on each container instance of CLUSTER.

        Args:
            ecs_client: boto3 "ecs" client.
            tasks_list: task ARNs to describe and count.
            status: container instance statuses to keep in the result;
                defaults to ["ACTIVE"].

        Returns:
            A dict mapping the EC2 instance id of every container instance whose
            status is in `status` to the number of tasks from `tasks_list`
            placed on it (0 when it has none).
        """
        if status is None:
            status = ["ACTIVE"]

        # Start every registered container instance at zero tasks.
        # NOTE(review): this replaces a call to `get_instances_dict`, which is
        # not defined in this file; the {arn: 0} shape follows the original
        # docstring. Confirm against the intended helper.
        arns = ecs_client.list_container_instances(cluster=CLUSTER)['containerInstanceArns']
        if not arns:
            return {}
        instances = dict.fromkeys(arns, 0)

        # Skip the describe call when there is nothing to describe.
        if tasks_list:
            tasks_desc = ecs_client.describe_tasks(cluster=CLUSTER, tasks=tasks_list)
            for t in tasks_desc['tasks']:
                # Ignore tasks without an instance ARN, or whose instance is
                # not in the list fetched above, rather than raising KeyError.
                arn = t.get('containerInstanceArn')
                if arn in instances:
                    instances[arn] += 1

        response = ecs_client.describe_container_instances(
            cluster=CLUSTER, containerInstances=arns)

        # Re-key by EC2 instance id, keeping only instances in a wanted status.
        return {
            i['ec2InstanceId']: instances[i['containerInstanceArn']]
            for i in response['containerInstances'] if i['status'] in status
        }

    def with_sleep(sleep_time=5, max_attempts=3):
        """Build a decorator that polls a function until it returns a result.

        The wrapped function is called up to `max_attempts` times, sleeping
        `sleep_time` seconds before every call (including the first one).

        Args:
            sleep_time: seconds to sleep before each attempt.
            max_attempts: maximum number of calls to the wrapped function.

        Returns:
            A decorator. The decorated function returns the first value that is
            not None, or False (after logging an error) when every attempt
            returned None.
        """
        # NOTE(review): this decorator is applied to `_wait_instances` and
        # `_wait_draining_instances_are_empty` earlier in the file than this
        # definition; it must be moved above them for the module to import.
        def sleep_decorator(func):
            @functools.wraps(func)
            def wrapper_sleep(*args, **kwargs):
                # Exactly max_attempts tries, so the total wait matches the
                # sleep_time * max_attempts figure reported below.
                for _ in range(max_attempts):
                    time.sleep(sleep_time)

                    ret = func(*args, **kwargs)
                    if ret is not None:
                        return ret

                logging.error(
                    "Still not ready after %s seconds, please investigate.",
                    sleep_time * max_attempts)
                return False
            return wrapper_sleep
        return sleep_decorator


    # Run the replacement only when executed as a script, not on import.
    if __name__ == "__main__":
        main()