Skip to content

Instantly share code, notes, and snippets.

@subudear
Last active November 21, 2022 12:54
Show Gist options
  • Save subudear/8f9d8c05a2261cceb0ea704d6a5546c9 to your computer and use it in GitHub Desktop.
Save subudear/8f9d8c05a2261cceb0ea704d6a5546c9 to your computer and use it in GitHub Desktop.
Lambda function to update security group rules with CloudFront IPs
# Protocol label -> TCP port that must accept inbound traffic from the
# service's IP ranges. Each label is also looked up as the value of the
# security group's 'Protocol' tag. For HTTPS-only setups this is just
# one entry mapping 'https' to 443.
INGRESS_PORTS = dict(https=443)

# Tag key/value pairs a security group must carry to be managed by this
# function.
GLOBAL_SG_TAGS = dict(Name='Update_CloudFront_IPs', AutoUpdate='true')
import boto3
import hashlib
import json
import logging
import urllib.request, urllib.error, urllib.parse
import os
# AWS region whose security groups are managed; override with the REGION
# environment variable.
REGION = os.environ.get("REGION", "ap-southeast-2")
def lambda_handler(event, context):
    """Handle an SNS notification announcing updated AWS IP ranges.

    Downloads the ip-ranges document referenced by the notification, verifies
    its MD5, extracts the prefixes for the configured service and reconciles
    the tagged security groups with them.

    Environment variables:
        DEBUG: when 'true', the root log level is set to INFO. Otherwise it
            is ERROR if a handler is already installed, else DEBUG via
            basicConfig.
        SERVICE: service name to select from ip-ranges.json (default
            'CLOUDFRONT'). List the possible values with:
            curl -s 'https://ip-ranges.amazonaws.com/ip-ranges.json' | jq -r '.prefixes[] | .service' | sort -u

    Args:
        event: SNS event. event['Records'][0]['Sns']['Message'] must be a
            JSON string containing 'url' and 'md5'.
        context: invocation context (unused).

    Returns:
        A list of status strings describing what was (or was not) changed.
    """
    global NRANGES

    root_logger = logging.getLogger()
    if root_logger.handlers:
        # A handler is already installed (e.g. by the hosting runtime):
        # only surface errors by default.
        root_logger.setLevel(logging.ERROR)
    else:
        logging.basicConfig(level=logging.DEBUG)
    if os.environ.get('DEBUG') == 'true':
        root_logger.setLevel(logging.INFO)

    service = os.getenv('SERVICE', "CLOUDFRONT")

    message = json.loads(event['Records'][0]['Sns']['Message'])
    ip_ranges = json.loads(get_ip_groups_json(message['url'], message['md5']))
    cf_ranges = get_ranges_for_service(ip_ranges, service)

    # Number of security group rules required for the full range set.
    NRANGES = len(cf_ranges) * len(INGRESS_PORTS)

    if not cf_ranges:
        # update_security_group revokes every CIDR of a matching rule that is
        # not in the new list, so an empty list would strip those rules
        # completely. An empty result almost always means a misconfiguration
        # (e.g. a wrong SERVICE value), so leave the groups untouched.
        msg = 'No IP ranges found for service {}; security groups left unchanged'.format(service)
        logging.error(msg)
        return [msg]

    return update_security_groups(cf_ranges)
def get_ip_groups_json(url, expected_hash):
    """Download the ip-ranges document and verify its MD5 checksum.

    Args:
        url: location of the JSON document (taken from the SNS message).
        expected_hash: hex MD5 digest the downloaded bytes must match.

    Returns:
        The raw response body as bytes.

    Raises:
        Exception: if the MD5 of the body differs from *expected_hash*.
    """
    # NOTE(review): *url* comes from the incoming event and is passed to
    # urlopen unchecked; consider restricting it to https:// if the SNS topic
    # is not fully trusted.
    logging.debug("Updating from " + url)
    # Close the response deterministically instead of leaving the
    # connection to the garbage collector.
    with urllib.request.urlopen(url) as response:
        ip_json = response.read()
    digest = hashlib.md5(ip_json).hexdigest()
    if digest != expected_hash:
        raise Exception('MD5 Mismatch: got ' + digest + ' expected ' + expected_hash)
    return ip_json
def get_ranges_for_service(ranges, service):
    """Collect the IPv4 prefixes that belong to *service*.

    Args:
        ranges: parsed ip-ranges document; its 'prefixes' entries must carry
            'service', 'region' and 'ip_prefix'.
        service: service name to match exactly (e.g. 'CLOUDFRONT').

    Returns:
        List of 'ip_prefix' strings in document order.
    """
    matches = []
    for entry in ranges['prefixes']:
        if entry['service'] != service:
            continue
        logging.info('Found ' + service + ' region: ' + entry['region'] + ' range: ' + entry['ip_prefix'])
        matches.append(entry['ip_prefix'])
    return matches
def update_security_groups(new_ranges):
    """Sync every tagged security group with *new_ranges*.

    For each protocol label in INGRESS_PORTS, look up the security groups that
    carry all GLOBAL_SG_TAGS plus a 'Protocol' tag equal to that label, then
    reconcile their ingress rules for the label's port.

    Args:
        new_ranges: list of IPv4 CIDR strings that should be allowed.

    Returns:
        List of status strings: one per group processed, or a single
        "No groups ..." entry for a protocol with no matching groups.
    """
    client = boto3.client('ec2', region_name=REGION)
    result = list()
    for protocol, port in INGRESS_PORTS.items():
        # Build a fresh filter dict per protocol. Assigning GLOBAL_SG_TAGS and
        # then adding 'Protocol' would mutate the module-level constant for
        # the remainder of the process.
        tag_to_find = {**GLOBAL_SG_TAGS, 'Protocol': protocol}
        groups = get_security_groups_for_update(client, tag_to_find)
        msg = 'tagged Name: {}, Protocol: {} to update'.format(tag_to_find["Name"], protocol)
        logging.info('Found {} groups {}'.format(str(len(groups)), msg))
        if not groups:
            result.append('No groups {}'.format(msg))
            logging.warning('No groups {}'.format(msg))
            continue
        for group in groups:
            if update_security_group(client, group, new_ranges, port):
                result.append('Security Group {} updated.'.format(group['GroupId']))
            else:
                result.append('Security Group {} unchanged.'.format(group['GroupId']))
    return result
def update_security_group(client, group, new_ranges, port):
    """Reconcile one security group's ingress rules for *port* with *new_ranges*.

    Every existing ingress permission whose port range includes *port* is
    synced: CIDRs not in *new_ranges* are revoked and missing ones are
    authorized. If no existing permission includes *port* (including a group
    with no permissions at all), a new TCP permission for exactly *port* is
    created with all of *new_ranges*.

    Args:
        client: EC2 client used for the authorize/revoke calls.
        group: security group dict as returned by describe_security_groups;
            'GroupId' and 'IpPermissions' are read.
        new_ranges: IPv4 CIDR strings that should be allowed.
        port: port the ranges must be able to reach.

    Returns:
        True if any CIDR was added or revoked, False otherwise.
    """
    # Ordered, de-duplicated copy so a repeated CIDR is not submitted twice,
    # plus a set for O(1) membership checks.
    wanted = list(dict.fromkeys(new_ranges))
    wanted_set = set(wanted)

    added = 0
    removed = 0
    port_covered = False

    for permission in group['IpPermissions']:
        from_port = permission.get('FromPort')
        to_port = permission.get('ToPort')
        # Permissions without port bounds (e.g. IpProtocol '-1') cannot be
        # matched against a port; skip them instead of raising KeyError.
        if from_port is None or to_port is None:
            continue
        if not (from_port <= port <= to_port):
            continue
        port_covered = True

        ip_ranges = permission.get('IpRanges', [])
        old_prefixes = {ip_range['CidrIp'] for ip_range in ip_ranges}

        to_revoke = []
        for ip_range in ip_ranges:
            if ip_range['CidrIp'] not in wanted_set:
                to_revoke.append(ip_range)
                logging.debug(group['GroupId'] + ": Revoking " + ip_range['CidrIp'] + ":" + str(to_port))

        to_add = []
        for cidr in wanted:
            if cidr not in old_prefixes:
                to_add.append({'CidrIp': cidr, 'Description': 'CloudFront CIDR-Range'})
                logging.debug(group['GroupId'] + ": Adding " + cidr + ":" + str(to_port))

        removed += revoke_permissions(client, group, permission, to_revoke)
        added += add_permissions(client, group, permission, to_add)

    if not port_covered:
        # No rule reaches *port* yet (previously this only happened for a
        # group with zero permissions; a group with unrelated rules was never
        # given access). Create a dedicated TCP rule for it.
        to_add = []
        for cidr in wanted:
            to_add.append({'CidrIp': cidr, 'Description': 'CloudFront CIDR-Range'})
            logging.info(group['GroupId'] + ": Adding " + cidr + ":" + str(port))
        permission = {'ToPort': port, 'FromPort': port, 'IpProtocol': 'tcp'}
        added += add_permissions(client, group, permission, to_add)

    logging.debug(group['GroupId'] + ": Added " + str(added) + ", Revoked " + str(removed))
    return added > 0 or removed > 0
def revoke_permissions(client, group, permission, to_revoke):
    """Revoke *to_revoke* IP ranges from *group* for *permission*'s ports/protocol.

    When *to_revoke* is empty the EC2 API is not called.

    Returns:
        The number of ranges submitted for revocation (0 if none).
    """
    if len(to_revoke) == 0:
        return 0
    client.revoke_security_group_ingress(
        GroupId=group['GroupId'],
        IpPermissions=[{
            'ToPort': permission['ToPort'],
            'FromPort': permission['FromPort'],
            'IpRanges': to_revoke,
            'IpProtocol': permission['IpProtocol'],
        }],
    )
    return len(to_revoke)
def add_permissions(client, group, permission, to_add):
    """Authorize *to_add* IP ranges on *group* using *permission*'s ports/protocol.

    When *to_add* is empty the EC2 API is not called.

    Returns:
        The number of ranges submitted for authorization (0 if none).
    """
    if len(to_add) == 0:
        return 0
    ip_permission = dict(
        ToPort=permission['ToPort'],
        FromPort=permission['FromPort'],
        IpRanges=to_add,
        IpProtocol=permission['IpProtocol'],
    )
    client.authorize_security_group_ingress(GroupId=group['GroupId'], IpPermissions=[ip_permission])
    return len(to_add)
def get_security_groups_for_update(client, security_group_tag):
    """Return every security group that carries all tags in *security_group_tag*.

    Args:
        client: EC2 client exposing describe_security_groups.
        security_group_tag: mapping of tag key -> required tag value.

    Returns:
        List of security group dicts, gathered across all result pages.
    """
    # 'tag:<key>' filters bind each value to its own key. Separate 'tag-key'
    # and 'tag-value' filters only require the key to exist and the value to
    # appear on *some* tag, so values could match against the wrong keys.
    filters = [
        {'Name': 'tag:' + key, 'Values': [value]}
        for key, value in security_group_tag.items()
    ]
    groups = []
    request = {'Filters': filters}
    while True:
        response = client.describe_security_groups(**request)
        groups.extend(response['SecurityGroups'])
        next_token = response.get('NextToken')
        if not next_token:
            return groups
        # Follow pagination so groups beyond the first page are not missed.
        request['NextToken'] = next_token
# Sample SNS event for manually testing lambda_handler. The triple-quoted
# string below is never referenced by the code; it only stores the example.
# NOTE: the "md5" inside the embedded Message must equal the MD5 of the
# document served at "url" (get_ip_groups_json raises 'MD5 Mismatch'
# otherwise), so replace it with the current digest before using this event.
'''
Sample Event From SNS:
{
"Records": [
{
"EventVersion": "1.0",
"EventSubscriptionArn": "arn:aws:sns:EXAMPLE",
"EventSource": "aws:sns",
"Sns": {
"SignatureVersion": "1",
"Timestamp": "1970-01-01T00:00:00.000Z",
"Signature": "EXAMPLE",
"SigningCertUrl": "EXAMPLE",
"MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
"Message": "{\"create-time\": \"yyyy-mm-ddThh:mm:ss+00:00\", \"synctoken\": \"0123456789\", \"md5\": \"45be1ba64fe83acb7ef247bccbc45704\", \"url\": \"https://ip-ranges.amazonaws.com/ip-ranges.json\"}",
"Type": "Notification",
"UnsubscribeUrl": "EXAMPLE",
"TopicArn": "arn:aws:sns:EXAMPLE",
"Subject": "TestInvoke"
}
}
]
}
'''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment