Skip to content

Instantly share code, notes, and snippets.

@mixja
Last active March 13, 2018 07:32
Show Gist options
  • Save mixja/2242aa357fdfde62146a13b8a7c168b2 to your computer and use it in GitHub Desktop.
Save mixja/2242aa357fdfde62146a13b8a7c168b2 to your computer and use it in GitHub Desktop.
Network Load Balancer Security Group Provisioner
import sys, os
parent_dir = os.path.abspath(os.path.dirname(__file__))
vendor_dir = os.path.join(parent_dir, 'vendor')
sys.path.append(vendor_dir)
import logging, datetime, json
import boto3
import backoff
from botocore.exceptions import ClientError
from cfn_lambda_handler import Handler
from voluptuous import Schema, Required, All, Coerce
from urllib.parse import urlparse, unquote
# Logging: install a default handler on the root logger and take its level
# from the LOG_LEVEL environment variable, defaulting to INFO.
logging.basicConfig()
log = logging.getLogger()
log.setLevel(os.getenv('LOG_LEVEL', 'INFO'))
def format_json(data):
    """Serialise *data* to a JSON string, primarily for log output.

    ``datetime.datetime`` values are rendered with ``isoformat()``; any other
    value the json module cannot encode natively is rendered with ``str()``.
    """
    def _fallback(value):
        if isinstance(value, datetime.datetime):
            return value.isoformat()
        return str(value)

    return json.dumps(data, default=_fallback)
# Set handler as the entry point for Lambda.
# The create/update/delete functions below register themselves on it via
# @handler.create / @handler.update / @handler.delete; presumably Handler
# dispatches each CloudFormation request type to the matching function.
handler = Handler()
# EC2 Client, created once at module import and shared by every function here
client = boto3.client('ec2')
# Input validation for the custom resource's ResourceProperties.
# Keys not listed here are tolerated (extra=True). The two port lists default
# to empty lists and every entry in them is coerced to int.
validator = Schema(
    {
        Required('LoadBalancerFullName'): All(str),
        Required('VpcId'): All(str),
        Required('TcpPorts', default=[]): All([Coerce(int)]),
        Required('UdpPorts', default=[]): All([Coerce(int)]),
    },
    extra=True,
)
class SortedPermissionsEncoder(json.JSONEncoder):
    """JSON encoder that puts every list into a canonical order before encoding.

    Lists are sorted recursively. Dict elements are ordered by their
    ``CidrIp`` value (a missing key sorts as ``''``, so dicts without CidrIp
    keep their relative order because ``sorted`` is stable). Non-dict
    elements are ordered by their own value. Used with ``sort_keys=True`` so
    that equivalent security-group permission structures encode identically.
    """

    @staticmethod
    def _list_sort_key(entry):
        # Dicts compare by CidrIp (as a string); other entries by themselves.
        if isinstance(entry, dict):
            return str(entry.get('CidrIp', ''))
        return entry

    @classmethod
    def _canonicalise(cls, value):
        # Rebuild lists (sorted) and dicts (as plain dicts) recursively;
        # every other value is returned unchanged.
        if isinstance(value, list):
            return sorted(map(cls._canonicalise, value), key=cls._list_sort_key)
        if isinstance(value, dict):
            return {key: cls._canonicalise(child) for key, child in value.items()}
        return value

    def encode(self, obj):
        return super().encode(self._canonicalise(obj))
def get_stack_name(event):
    """Return the name of the CloudFormation stack that sent *event*.

    Prefers the request's ``StackId`` field (a stack ARN containing
    ``:stack/<name>/<id>``). For events without a StackId, falls back to the
    stack ARN carried URL-encoded in the path of ``ResponseURL``.

    Raises:
        KeyError: if the event has neither StackId nor ResponseURL.
        IndexError: if the chosen value contains no ``:stack/`` segment.
    """
    stack_arn = event.get('StackId')
    if not stack_arn:
        # The pre-signed response URL's path embeds the URL-encoded stack ARN.
        stack_arn = unquote(urlparse(event['ResponseURL']).path)
    return stack_arn.split(':stack/')[1].split('/')[0]
def generate_permissions(lb_name, tcp_ports, udp_ports):
    """Build egress permissions from the security group to a load balancer's private IPs.

    The load balancer's network interfaces are found by their description,
    ``"ELB <lb_name>"``. Each interface's ``PrivateIpAddress`` becomes a /32
    destination. One permission is produced per port: TCP ports first, then
    UDP ports, each covering all of the load balancer's IPs.

    Args:
        lb_name: the load balancer name used in the ENI description.
        tcp_ports: iterable of TCP port numbers.
        udp_ports: iterable of UDP port numbers.

    Returns:
        A list of dicts in the ``IpPermissions`` shape passed to
        ``authorize_security_group_egress`` / ``revoke_security_group_egress``.

    Raises:
        ValueError: if no network interface matches the load balancer. Rules
            with no destination are useless, so this fails loudly instead of
            silently producing them.
    """
    lb_description = 'ELB %s' % lb_name
    # Get private IP addresses
    response = client.describe_network_interfaces(
        Filters=[{'Name': 'description', 'Values': [lb_description]}]
    )
    lb_private_ips = [eni['PrivateIpAddress'] for eni in response['NetworkInterfaces']]
    if not lb_private_ips:
        raise ValueError(
            'No network interfaces found with description "%s" - check the LoadBalancerFullName property'
            % lb_description
        )
    lb_ip_ranges = [
        {'CidrIp': '%s/32' % ip, 'Description': '%s private IP address' % lb_name}
        for ip in lb_private_ips
    ]
    # Generate egress rules, TCP before UDP
    permissions = []
    for protocol, ports in (('tcp', tcp_ports), ('udp', udp_ports)):
        for port in ports:
            permissions.append({
                'FromPort': port,
                'ToPort': port,
                'IpRanges': lb_ip_ranges,
                'IpProtocol': protocol,
                'Ipv6Ranges': [],
                'PrefixListIds': [],
                'UserIdGroupPairs': [],
            })
    return permissions
def create_security_group(lb_name, stack_name, vpc_id):
    """Create the load balancer's security group and strip its allow-all egress rule.

    The group is created in *vpc_id* with the name
    ``"<stack_name>-<lb_name>-sg"`` and is tagged with that name. The
    all-protocol 0.0.0.0/0 egress rule is then revoked, so only egress that
    is authorized explicitly afterwards remains.

    Args:
        lb_name: load balancer name, used in the group name and description.
        stack_name: CloudFormation stack name, used as the name prefix.
        vpc_id: VPC to create the group in.

    Returns:
        The ``create_security_group`` response (contains ``GroupId``).

    Raises:
        botocore.exceptions.ClientError: if any EC2 call fails. If tagging or
            revoking fails, a best-effort delete of the just-created group is
            attempted first so it is not left orphaned.
    """
    sg = client.create_security_group(
        Description="%s %s Security Group" % (stack_name, lb_name),
        GroupName="%s-%s-sg" % (stack_name, lb_name),
        VpcId=vpc_id,
    )
    try:
        tag_security_group(sg, lb_name, stack_name)
        client.revoke_security_group_egress(
            GroupId=sg['GroupId'],
            IpPermissions=[{'FromPort': -1, 'ToPort': -1, 'IpRanges': [{'CidrIp': '0.0.0.0/0'}], 'IpProtocol': '-1'}],
        )
    except ClientError:
        # The group was created just above and nothing references it yet, so
        # remove it rather than leak it. The original error is re-raised either way.
        log.error("Failed to initialise security group %s; deleting it", sg['GroupId'])
        try:
            client.delete_security_group(GroupId=sg['GroupId'])
        except ClientError:
            log.exception("Unable to delete security group %s", sg['GroupId'])
        raise
    return sg
def _stop_unless_group_not_found(error):
    """Backoff give-up predicate: retry only on InvalidGroup.NotFound."""
    return error.response['Error']['Code'] != 'InvalidGroup.NotFound'


# Sets the Name tag on a security group. InvalidGroup.NotFound is retried every
# 5 seconds (presumably so a just-created group has time to become visible);
# any other ClientError is raised immediately. No retry limit is configured.
@backoff.on_exception(backoff.constant, ClientError, interval=5, giveup=_stop_unless_group_not_found)
def tag_security_group(sg, lb_name, stack_name):
    name_tag = {'Key': 'Name', 'Value': "%s-%s-sg" % (stack_name, lb_name)}
    client.create_tags(Resources=[sg['GroupId']], Tags=[name_tag])
def _stop_unless_dependency_violation(error):
    """Backoff give-up predicate: retry only on DependencyViolation."""
    return error.response['Error']['Code'] != 'DependencyViolation'


# Deletes a security group. Despite its name, this function performs the
# deletion itself. DependencyViolation (e.g. the group is still referenced) is
# retried every 5 seconds; any other ClientError is raised immediately. No
# max_tries/max_time is set, so retries continue until success, another error,
# or the invocation ends.
@backoff.on_exception(backoff.constant, ClientError, interval=5, giveup=_stop_unless_dependency_violation)
def check_sg_deleted(sg_id):
    client.delete_security_group(GroupId=sg_id)
# Create requests
@handler.create
def create(event, context):
    """Handle the CloudFormation Create request for the security group resource.

    Validates ResourceProperties, builds egress rules toward the load
    balancer's private IPs, creates the security group and authorizes those
    rules on it.

    Args:
        event: CloudFormation custom resource request.
        context: Lambda context (unused).

    Returns:
        The event, with PhysicalResourceId set to the new group id.

    Raises:
        ValueError: if neither TcpPorts nor UdpPorts is specified.
        botocore.exceptions.ClientError: on EC2 failures. If authorizing the
            egress rules fails, the new group is deleted (best effort) before
            the error is re-raised, because no PhysicalResourceId pointing at
            it has been recorded yet.
    """
    log.info("Received create event: %s", format_json(event))
    # Schema violations raise from the validator itself
    data = validator(event['ResourceProperties'])
    if not data['TcpPorts'] and not data['UdpPorts']:
        raise ValueError('Invalid input - you must specify at least one of TcpPorts or UdpPorts properties')
    # Generate permissions
    lb_name = data['LoadBalancerFullName']
    permissions = generate_permissions(lb_name, data['TcpPorts'], data['UdpPorts'])
    # Create security group
    stack_name = get_stack_name(event)
    sg = create_security_group(lb_name, stack_name, data['VpcId'])
    group_id = sg['GroupId']
    try:
        client.authorize_security_group_egress(GroupId=group_id, IpPermissions=permissions)
    except ClientError:
        log.error("Failed to authorize egress rules on %s; deleting it", group_id)
        try:
            client.delete_security_group(GroupId=group_id)
        except ClientError:
            log.exception("Unable to delete security group %s", group_id)
        raise
    # Set physical resource Id
    log.info("Successfully created security group with id: %s", group_id)
    event['PhysicalResourceId'] = group_id
    return event
# Update requests
@handler.update
def update(event, context):
    """Handle the CloudFormation Update request for the security group resource.

    Re-validates ResourceProperties, regenerates the desired egress rules and
    reconciles the group identified by PhysicalResourceId with them. New
    rules are authorized first and obsolete rules are revoked afterwards, to
    avoid packet loss. If that group no longer exists, a replacement is
    created and PhysicalResourceId is updated to the new group id.

    Args:
        event: CloudFormation custom resource request.
        context: Lambda context (unused).

    Returns:
        The event, with PhysicalResourceId replaced if a new group was created.

    Raises:
        ValueError: if neither TcpPorts nor UdpPorts is specified.
        botocore.exceptions.ClientError: for EC2 failures other than the
            existing group being missing or its id being malformed.
    """
    log.info("Received update event: %s", format_json(event))
    data = validator(event['ResourceProperties'])
    if not data['TcpPorts'] and not data['UdpPorts']:
        raise ValueError('Invalid input - you must specify at least one of TcpPorts or UdpPorts properties')
    # Generate permissions
    lb_name = data['LoadBalancerFullName']
    permissions = generate_permissions(lb_name, data['TcpPorts'], data['UdpPorts'])
    # Look up the current group. EC2 reports an unknown or malformed group id
    # as a ClientError whose code starts with "InvalidGroup", not as an empty
    # result. Without this handler the "recreate" branch below could never run.
    try:
        response = client.describe_security_groups(GroupIds=[event['PhysicalResourceId']])
        existing_sg = next(iter(response['SecurityGroups']), None)
    except ClientError as e:
        if not e.response['Error']['Code'].startswith('InvalidGroup'):
            raise
        log.info("Security group %s not found - creating a replacement", event['PhysicalResourceId'])
        existing_sg = None
    stack_name = get_stack_name(event)
    if existing_sg is None:
        # Create sg as it doesn't currently exist for some reason
        existing_sg = create_security_group(lb_name, stack_name, data['VpcId'])
        event['PhysicalResourceId'] = existing_sg['GroupId']
        existing_sg['IpPermissionsEgress'] = []
    # Tag security group (refreshes Name if the stack or LB name changed)
    tag_security_group(existing_sg, lb_name, stack_name)
    # Determine if any changes are required. Nested lists (e.g. IpRanges) are
    # put in a canonical order via SortedPermissionsEncoder, and the JSON round
    # trip yields plain structures that compare by value.
    sorted_permissions = json.loads(json.dumps(permissions, sort_keys=True, cls=SortedPermissionsEncoder))
    sorted_existing = json.loads(json.dumps(existing_sg['IpPermissionsEgress'], sort_keys=True, cls=SortedPermissionsEncoder))
    log.info("New permissions %s", sorted_permissions)
    log.info("Existing permissions %s", sorted_existing)
    if sorted_existing != sorted_permissions:
        # Desired permissions already present are dropped from sorted_existing,
        # so what remains there afterwards is exactly the set to revoke.
        # NOTE(review): this diff works per whole permission. If a port's set of
        # destination IPs only partly changes, the full permission is
        # re-authorized and EC2 may reject the CIDRs already present as
        # duplicates. Confirm, and consider diffing per CIDR.
        changes = []
        for permission in sorted_permissions:
            if permission in sorted_existing:
                sorted_existing.remove(permission)
            else:
                changes.append(permission)
        # Authorize new permissions
        if changes:
            log.info("Applying new permissions %s", changes)
            client.authorize_security_group_egress(GroupId=existing_sg['GroupId'], IpPermissions=changes)
        # Revoke old permissions
        if sorted_existing:
            log.info("Revoking old permissions %s", sorted_existing)
            client.revoke_security_group_egress(GroupId=existing_sg['GroupId'], IpPermissions=sorted_existing)
    return event
# Delete requests
@handler.delete
def delete(event, context):
    """Handle the CloudFormation Delete request for the security group resource.

    Deletes the group named by PhysicalResourceId via check_sg_deleted, which
    retries while EC2 reports DependencyViolation. Any ClientError whose code
    starts with "InvalidGroup" (group missing, or the id is malformed) is
    treated as already deleted and skipped.

    Args:
        event: CloudFormation custom resource request.
        context: Lambda context (unused).

    Returns:
        The event, unchanged.

    Raises:
        botocore.exceptions.ClientError: for any other EC2 failure.
    """
    log.info("Received delete event: %s", format_json(event))
    group_id = event['PhysicalResourceId']
    try:
        response = client.describe_security_groups(GroupIds=[group_id])
        existing_sg = next(iter(response['SecurityGroups']), None)
        if existing_sg is not None:
            check_sg_deleted(existing_sg['GroupId'])
    except ClientError as e:
        if not e.response['Error']['Code'].startswith('InvalidGroup'):
            raise
        log.info("Skipping as security group %s not found", group_id)
    return event
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment