Skip to content

Instantly share code, notes, and snippets.

@daniyel
Created October 17, 2018 07:20
Show Gist options
  • Save daniyel/f85754541ccd05bff89bb4940d8ae375 to your computer and use it in GitHub Desktop.
Save daniyel/f85754541ccd05bff89bb4940d8ae375 to your computer and use it in GitHub Desktop.
Lambda function for creating SRV records in Route53 (AWS) for container services
import boto3
import pprint
import os
import re
# Deployment environment name, read from ENVIRONMENT (defaults to 'development').
environment = os.environ.get('ENVIRONMENT', 'development')
# DNS domain the SRV records live under; DOMAIN overrides the default that is
# derived from the environment name.
domain = os.environ.get('DOMAIN', f'sd-{environment}.internal')
# Name of the ECS cluster to inspect, derived from the environment name.
cluster = f'{environment}-ECSCluster'
# TTL value used for the record sets this module creates.
DEFAULT_TTL = 0
# Default weight constant for weighted record sets.
DEFAULT_WEIGHT = 1
# boto3 service clients, created once at import time and shared by all
# functions in this module.
route53 = boto3.client('route53')
ecs = boto3.client('ecs')
ec2 = boto3.client('ec2')
def get_dns_hosted_zone_id(domain):
    """Return the Route53 hosted zone id whose name equals *domain*.

    Args:
        domain: Zone name, with or without the trailing dot.

    Returns:
        The ``Id`` of the first hosted zone whose ``Name`` matches, or an
        empty string when no zone matches.
    """
    # Zone names are compared in fully-qualified form ("example.internal."),
    # so normalise the argument to end with exactly one dot.
    fqdn = f"{domain.rstrip('.')}."
    response = route53.list_hosted_zones_by_name(DNSName=domain)
    for zone in response['HostedZones']:
        if zone['Name'] == fqdn:
            return zone['Id']
    return ''
def get_srv_record_sets(hosted_zone_id):
    """Collect the SRV records that already exist in a hosted zone.

    Args:
        hosted_zone_id: Id of the Route53 hosted zone to read.

    Returns:
        A dict keyed by the first two labels of each SRV record name
        (``'<label0>.<label1>'``). Each value is a list of dicts with the keys
        ``'hostPort'`` and ``'ec2PrivateDns'``, taken from the third and
        fourth space-separated fields of the record value (both strings).
    """
    record_sets = {}
    # A single list call returns only one page of results; walk every page so
    # that existing records are not mistaken for missing ones in large zones.
    paginator = route53.get_paginator('list_resource_record_sets')
    for page in paginator.paginate(HostedZoneId=hosted_zone_id):
        for rec_set in page['ResourceRecordSets']:
            if rec_set['Type'] != 'SRV':
                continue
            name_parts = rec_set['Name'].split('.')
            service_dns = f'{name_parts[0]}.{name_parts[1]}'
            # Record sets without a 'ResourceRecords' list are skipped.
            for res_rec in rec_set.get('ResourceRecords', []):
                value_parts = res_rec['Value'].split(' ')
                record_sets.setdefault(service_dns, []).append({
                    'hostPort': value_parts[2],
                    'ec2PrivateDns': value_parts[3],
                })
    return record_sets
def create_srv_record(service_dns, host_port, ec2_private_dns, hosted_zone_id, domain, container_arn):
    """Create one weighted SRV record set in Route53 and print the result.

    Args:
        service_dns: Service part of the record name; the record is created
            as ``'<service_dns>.<domain>'``.
        host_port: Port placed in the SRV value.
        ec2_private_dns: Target host name placed in the SRV value.
        hosted_zone_id: Id of the hosted zone that receives the record.
        domain: Domain suffix appended to ``service_dns``.
        container_arn: Container identifier; combined with
            ``ec2_private_dns`` to form the record's ``SetIdentifier``.
    """
    name = f'{service_dns}.{domain}'
    # SRV value layout: "<priority> <weight> <port> <target>".
    srv_value = f'1 1 {host_port} {ec2_private_dns}'
    response = route53.change_resource_record_sets(
        HostedZoneId=hosted_zone_id,
        ChangeBatch={
            'Changes': [
                {
                    'Action': 'CREATE',
                    'ResourceRecordSet': {
                        'Name': name,
                        'Type': 'SRV',
                        'SetIdentifier': f'{ec2_private_dns}:{container_arn}',
                        # The routing weight must come from DEFAULT_WEIGHT;
                        # the TTL constant was used here by mistake.
                        'Weight': DEFAULT_WEIGHT,
                        'TTL': DEFAULT_TTL,
                        'ResourceRecords': [
                            {'Value': srv_value},
                        ],
                    },
                },
            ],
        },
    )
    pprint.pprint(f'Record {name} created, resolves to "{srv_value}"')
    pprint.pprint(response)
def get_container_instances(cluster):
    """Describe the port-bound containers on every instance of an ECS cluster.

    Args:
        cluster: Name of the ECS cluster to inspect.

    Returns:
        A dict keyed by EC2 instance id. Each value is a dict with
        ``'privateDns'`` (the instance's private DNS name) and
        ``'containers'``, a list of dicts with the keys ``'name'``,
        ``'hostPort'``, ``'serviceDns'`` and ``'containerArn'``.
    """
    services = {}
    # Listings are paged; describe each page separately so neither call is
    # handed more identifiers than one list page contains.
    instance_pages = ecs.get_paginator('list_container_instances').paginate(cluster=cluster)
    for instance_page in instance_pages:
        instance_arns = instance_page['containerInstanceArns']
        if not instance_arns:
            # The describe call rejects an empty identifier list.
            continue
        described = ecs.describe_container_instances(
            cluster=cluster,
            containerInstances=instance_arns
        )
        for ci in described['containerInstances']:
            instance_id = ci['ec2InstanceId']
            services[instance_id] = {
                'privateDns': get_ec2_instance_private_dns(instance_id),
                'containers': _list_bound_containers(cluster, ci['containerInstanceArn'])
            }
    return services


def _list_bound_containers(cluster, container_instance_arn):
    """Return the containers with a host port binding on one container instance."""
    bound = []
    task_pages = ecs.get_paginator('list_tasks').paginate(
        cluster=cluster,
        containerInstance=container_instance_arn
    )
    for task_page in task_pages:
        task_arns = task_page['taskArns']
        if not task_arns:
            # An instance with no tasks would otherwise make describe_tasks
            # fail on an empty task list and abort the whole run.
            continue
        described = ecs.describe_tasks(cluster=cluster, tasks=task_arns)
        for task in described['tasks']:
            service_dns_name = get_service_dns_name(task['taskDefinitionArn'])
            for container in task['containers']:
                bindings = container.get('networkBindings')
                if 'name' not in container or not bindings:
                    continue
                bound.append({
                    'name': container['name'],
                    # Only the first network binding is published.
                    'hostPort': bindings[0]['hostPort'],
                    'serviceDns': service_dns_name,
                    # Keep the trailing id segment of the ARN. The text before
                    # the first '/' is a prefix shared by every container, so
                    # it cannot tell two containers on the same host apart.
                    'containerArn': container['containerArn'].split('/')[-1]
                })
    return bound
def get_ec2_instance_private_dns(instance_id):
    """Look up the private DNS name of an EC2 instance.

    Args:
        instance_id: Id of the EC2 instance to describe.

    Returns:
        The ``PrivateDnsName`` of the first instance listed in a reservation
        (the last reservation wins if several are returned), or an empty
        string when the response contains no instances.
    """
    response = ec2.describe_instances(InstanceIds=[instance_id])
    dns_name = ''
    for reservation in response['Reservations']:
        instances = reservation['Instances']
        if instances:
            dns_name = instances[0]['PrivateDnsName']
    return dns_name
# Environment variable names of the form SERVICE_<4-5 digits>_NAME carry the
# service DNS name. Raw string so that ``\d`` is not an invalid str escape.
_SERVICE_NAME_ENV = re.compile(r'SERVICE_\d{4,5}_NAME')


def get_service_dns_name(task_def_arn):
    """Read the service DNS name from a task definition's environment.

    Args:
        task_def_arn: ARN of the ECS task definition to describe.

    Returns:
        The value of the first environment variable in a container definition
        whose name starts with ``SERVICE_<4-5 digits>_NAME``. When several
        container definitions carry such a variable, the last container
        definition wins. Returns an empty string when none is found.
    """
    service_dns_name = ''
    response = ecs.describe_task_definition(
        taskDefinition=task_def_arn
    )
    for cont_def in response['taskDefinition']['containerDefinitions']:
        # A container definition without an 'environment' list is skipped.
        for env_var in cont_def.get('environment', []):
            if _SERVICE_NAME_ENV.match(env_var['name']):
                service_dns_name = env_var['value']
                break
    return service_dns_name
def record_exists(service_dns, host_port, srv_record_sets, private_dns):
    """Tell whether an SRV record for this service, host and port is known.

    Args:
        service_dns: Service key to look up in ``srv_record_sets``.
        host_port: Port to match; compared with ``int(record['hostPort'])``.
        srv_record_sets: Mapping of service key to a list of dicts with the
            keys ``'hostPort'`` and ``'ec2PrivateDns'``.
        private_dns: Host name to match against ``record['ec2PrivateDns']``.

    Returns:
        True when the records stored under ``service_dns`` contain one with
        the given host and port, otherwise False.
    """
    # Only this service's records are inspected: a matching host/port pair
    # registered under another service must not count as existing.
    return any(
        record['ec2PrivateDns'] == private_dns and int(record['hostPort']) == host_port
        for record in srv_record_sets.get(service_dns, ())
    )
def generate_srv_records(cluster, hosted_zone_id, domain):
    """Create an SRV record for every discovered container that lacks one.

    Args:
        cluster: Name of the ECS cluster to scan for containers.
        hosted_zone_id: Id of the Route53 hosted zone holding the records.
        domain: Domain suffix used when creating record names.

    Containers whose service DNS name is empty are ignored. For every other
    container a message is printed saying whether the record was created or
    skipped.
    """
    instances = get_container_instances(cluster)
    known_records = get_srv_record_sets(hosted_zone_id)
    for instance in instances.values():
        host = instance['privateDns']
        for entry in instance['containers']:
            svc = entry['serviceDns']
            port = entry['hostPort']
            arn = entry['containerArn']
            if not svc:
                continue
            if record_exists(svc, port, known_records, host):
                print(f'Record {svc} with port {port} on {host} already exists. Skip creating record.')
                continue
            print(f'Record {svc} with port {port} on {host} does not exist. Creating one.')
            create_srv_record(svc, port, host, hosted_zone_id, domain, arn)
def lambda_handler(event, context):
    """AWS Lambda entry point: sync SRV records for the configured cluster.

    Args:
        event: Lambda invocation event (not used).
        context: Lambda runtime context (not used).

    Raises:
        RuntimeError: If no hosted zone id is found for the configured domain.
    """
    hosted_zone_id = get_dns_hosted_zone_id(domain)
    if not hosted_zone_id:
        # Fail early with a clear message instead of passing an empty zone id
        # on to the Route53 calls made further down.
        raise RuntimeError(f'No Route53 hosted zone found for domain {domain!r}')
    generate_srv_records(cluster, hosted_zone_id, domain)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment