Skip to content

Instantly share code, notes, and snippets.

@manhtai
Created May 21, 2018 13:45
Show Gist options
  • Save manhtai/66dfdae56ebce7b6270788018516a409 to your computer and use it in GitHub Desktop.
Save manhtai/66dfdae56ebce7b6270788018516a409 to your computer and use it in GitHub Desktop.
Lambda function for draining ECS container instances before they are terminated
from __future__ import print_function
import base64
import json
import logging
import os
import re

import boto3
# Log through a module-named logger instead of the root logger. Setting the
# *root* logger to DEBUG also lowers the effective level of every library
# logger that has no level of its own (e.g. botocore), flooding the logs with
# their DEBUG output. Records from this logger still propagate to whatever
# handlers the root logger has (such as the one basicConfig() adds).
logging.basicConfig()
logger = logging.getLogger(__name__)
# The LOG_LEVEL environment variable (e.g. "INFO") overrides the level;
# DEBUG is kept as the default.
logger.setLevel(os.environ.get('LOG_LEVEL', 'DEBUG').upper())

# Establish boto3 session; the clients below are created once, at import time.
session = boto3.session.Session()
logger.debug("Session is in region %s ", session.region_name)
ec2Client = session.client(service_name='ec2')
ecsClient = session.client(service_name='ecs')
asgClient = session.client('autoscaling')
snsClient = session.client('sns')
# NOTE(review): lambdaClient is not referenced anywhere in this file.
lambdaClient = session.client('lambda')
def publishToSNS(message, topicARN):
    """Re-post a lifecycle message to SNS so that this Lambda is triggered again.

    :param message: The complete original message (JSON-serialisable), re-posted
        as received when the ASG terminating event arrived.
    :param topicARN: ARN of the SNS topic the message is published to.
    """
    logger.info("Publish to SNS topic %s", topicARN)
    publishArgs = {
        'TopicArn': topicARN,
        'Message': json.dumps(message),
        'Subject': 'Publishing SNS message to invoke lambda again..',
    }
    snsClient.publish(**publishArgs)
# Cluster the ECS container agent joins when ECS_CLUSTER is not configured.
_DEFAULT_CLUSTER_NAME = 'default'

# Matches an ``ECS_CLUSTER=<name>`` assignment anywhere in the user data.
# An opening quote is tolerated, and the match stops at the first character
# that cannot be part of a cluster name or ARN (closing quote, whitespace,
# ``>>`` redirection, ...).
_ECS_CLUSTER_RE = re.compile(r'ECS_CLUSTER=["\']?([A-Za-z0-9_:/-]+)')


def _getClusterName(Ec2InstanceId):
    """Return the ECS cluster name set via ECS_CLUSTER in the instance's user data.

    If several assignments are present, the last one wins. Falls back to
    ``_DEFAULT_CLUSTER_NAME`` when the instance has no user data or none of it
    sets ECS_CLUSTER.
    """
    ec2Resp = ec2Client.describe_instance_attribute(
        InstanceId=Ec2InstanceId, Attribute='userData')
    logger.debug("Describe instance attributes response %s", ec2Resp)
    userdataEncoded = ec2Resp.get('UserData', {}).get('Value', '')
    # b64decode returns bytes on Python 3; decode to text before matching.
    userdataDecoded = base64.b64decode(userdataEncoded).decode('utf-8', 'replace')
    matches = _ECS_CLUSTER_RE.findall(userdataDecoded)
    if not matches:
        # NOTE(review): if the cluster is configured some other way (e.g. baked
        # into the AMI's agent config), this guess will be wrong.
        logger.warning("No ECS_CLUSTER found in user data of %s; assuming cluster %s",
                       Ec2InstanceId, _DEFAULT_CLUSTER_NAME)
        return _DEFAULT_CLUSTER_NAME
    return matches[-1]


def _findContainerInstance(clusterName, Ec2InstanceId):
    """Return the ECS container-instance description for Ec2InstanceId, or None."""
    paginator = ecsClient.get_paginator('list_container_instances')
    for containerListResp in paginator.paginate(cluster=clusterName):
        containerArns = containerListResp['containerInstanceArns']
        # Nothing to describe on an empty page (e.g. an empty cluster).
        if not containerArns:
            continue
        containerDetResp = ecsClient.describe_container_instances(
            cluster=clusterName,
            containerInstances=containerArns,
        )
        logger.debug("describe container instances response %s", containerDetResp)
        for containerInstance in containerDetResp['containerInstances']:
            logger.debug(
                "Container Instance ARN: %s and ec2 Instance ID %s",
                containerInstance['containerInstanceArn'],
                containerInstance['ec2InstanceId'],
            )
            if containerInstance['ec2InstanceId'] == Ec2InstanceId:
                return containerInstance
    return None


def checkContainerInstanceTaskStatus(Ec2InstanceId):
    """Set the ECS container instance on Ec2InstanceId to DRAINING and report its tasks.

    :param Ec2InstanceId: The EC2 instance ID. It is used to locate the cluster
        (from the instance's user data) and the container instance in that cluster.
    :returns: ``(tasksRunning, tmpMsgAppend)``. ``tasksRunning`` is 1 if ListTasks
        still reports tasks on the container instance, and 0 otherwise. It is also
        0 when no container instance for the EC2 instance is found.
        ``tmpMsgAppend`` is ``{"containerInstanceId": <arn>}``, to be merged into
        the lifecycle message, or None when no container instance was found.
    """
    clusterName = _getClusterName(Ec2InstanceId)
    logger.info("Cluster name %s", clusterName)

    containerInstance = _findContainerInstance(clusterName, Ec2InstanceId)
    if containerInstance is None:
        logger.info("NO tasks are on this instance....%s", Ec2InstanceId)
        return 0, None

    containerInstanceId = containerInstance['containerInstanceArn']
    logger.info("Container instance ID of interest : %s", containerInstanceId)
    # Setting DRAINING lets the ECS cluster handle draining the tasks off this
    # instance; the call is only made when the instance is not draining already.
    if containerInstance['status'] == 'DRAINING':
        logger.info(
            "Container ID %s with EC2 instance-id %s is draining tasks",
            containerInstanceId,
            Ec2InstanceId,
        )
    else:
        logger.info("Make ECS API call to set the container status to DRAINING...")
        ecsClient.update_container_instances_state(
            cluster=clusterName,
            containerInstances=[containerInstanceId],
            status='DRAINING',
        )
    tmpMsgAppend = {"containerInstanceId": containerInstanceId}

    # Only emptiness is checked, so the first ListTasks page is sufficient.
    # NOTE(review): ListTasks filters on desiredStatus=RUNNING by default, so
    # tasks already told to stop but still shutting down are not counted;
    # confirm that is acceptable.
    listTaskResp = ecsClient.list_tasks(cluster=clusterName,
                                        containerInstance=containerInstanceId)
    logger.debug("Container instance task list %s", listTaskResp['taskArns'])
    if listTaskResp['taskArns']:
        logger.info("Tasks are on this instance...%s", Ec2InstanceId)
        return 1, tmpMsgAppend
    logger.info("NO tasks are on this instance...%s", Ec2InstanceId)
    return 0, tmpMsgAppend
def lambda_handler(event, context):
    """Drain an ECS container instance before Auto Scaling terminates it.

    Expects an SNS-triggered event whose message is a JSON Auto Scaling
    lifecycle notification. For an ``autoscaling:EC2_INSTANCE_TERMINATING``
    transition, the instance is set to DRAINING via
    checkContainerInstanceTaskStatus().
    - While tasks remain, the augmented message is re-published to the same SNS
      topic so this function runs again.
    - Once no tasks remain, the lifecycle action is completed with ``CONTINUE``.

    :param event: Lambda event delivered by SNS; only ``Records[0]`` is read.
    :param context: Lambda context object (not used).
    :returns: None.
    """
    record = event['Records'][0]
    sns = record['Sns']
    message = json.loads(sns['Message'])
    # Nothing to act on without an instance id (presumably e.g. the
    # autoscaling:TEST_NOTIFICATION message; confirm).
    if not message.get('EC2InstanceId'):
        return

    Ec2InstanceId = message['EC2InstanceId']
    asgGroupName = message['AutoScalingGroupName']
    snsArn = record['EventSubscriptionArn']
    TopicArn = sns['TopicArn']

    logger.info("Lambda received the event %s", event)
    logger.debug("records: %s", record)
    logger.debug("sns: %s", sns)
    logger.debug("Message: %s", message)
    logger.debug("Ec2 Instance Id %s ,%s", Ec2InstanceId, asgGroupName)
    logger.debug("SNS ARN %s", snsArn)

    # Only instance-terminating lifecycle transitions are handled.
    transition = message.get('LifecycleTransition')
    if not transition:
        return
    logger.debug("message autoscaling %s", transition)
    if 'autoscaling:EC2_INSTANCE_TERMINATING' not in transition:
        return

    lifecycleHookName = message['LifecycleHookName']
    logger.debug("Setting lifecycle hook name %s ", lifecycleHookName)

    # Sets DRAINING if needed and reports whether tasks are still on the instance.
    tasksRunning, tmpMsgAppend = checkContainerInstanceTaskStatus(Ec2InstanceId)
    logger.debug("Returned values received: %s ", tasksRunning)
    if tmpMsgAppend is not None:
        message.update(tmpMsgAppend)

    if tasksRunning == 1:
        # Re-invoke this function through SNS to check again.
        # NOTE(review): the message is re-published immediately, with no delay,
        # so invocations repeat back-to-back until the tasks are gone; consider
        # a back-off and/or recording a lifecycle heartbeat.
        publishToSNS(message, TopicArn)
    elif tasksRunning == 0:
        logger.debug("Setting lifecycle to complete; No tasks are running on instance, "
                     "completing lifecycle action....")
        try:
            response = asgClient.complete_lifecycle_action(
                LifecycleHookName=lifecycleHookName,
                AutoScalingGroupName=asgGroupName,
                LifecycleActionResult='CONTINUE',
                InstanceId=Ec2InstanceId)
        except Exception:
            # Best effort, as before: the failure is not re-raised, but it is now
            # logged with its traceback instead of being printed.
            logger.exception("complete_lifecycle_action failed for instance %s",
                             Ec2InstanceId)
        else:
            logger.info("Response received from complete_lifecycle_action %s", response)
            logger.info("Completedlifecycle hook action")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment