Last active
November 21, 2022 12:54
-
-
Save subudear/8f9d8c05a2261cceb0ea704d6a5546c9 to your computer and use it in GitHub Desktop.
Lambda function to update security group rules with CloudFront IPs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Protocol label -> TCP port the service must be allowed to reach.
# Each label is also matched against a security group's "Protocol" tag,
# so one tagged group set is managed per entry. For HTTPS-only origins the
# single entry below is all that is needed.
INGRESS_PORTS = dict(https=443)

# Tag key/value pairs a security group must carry to be managed by this function.
GLOBAL_SG_TAGS = dict(Name='Update_CloudFront_IPs', AutoUpdate='true')
import boto3 | |
import hashlib | |
import json | |
import logging | |
import urllib.request, urllib.error, urllib.parse | |
import os | |
# AWS region whose security groups are updated. Override it with the REGION
# environment variable; only REGION is read here (AWS_REGION is not consulted).
REGION = os.environ.get('REGION', 'ap-southeast-2')
def lambda_handler(event, context):
    """Refresh tagged security groups from an ip-ranges SNS notification.

    Args:
        event: SNS-triggered Lambda event. ``Records[0]['Sns']['Message']``
            must be a JSON string containing ``url`` and ``md5`` keys.
        context: Lambda context object (unused).

    Returns:
        The list of status strings produced by ``update_security_groups``.

    Raises:
        ValueError: If the downloaded document lists no prefixes for the
            selected service. Continuing would revoke every managed rule.
    """
    global NRANGES
    root_logger = logging.getLogger()
    # If a handler is already installed (as in the Lambda Python runtime),
    # only raise the threshold; otherwise (e.g. a local run) install one.
    if root_logger.handlers:
        root_logger.setLevel(logging.ERROR)
    else:
        logging.basicConfig(level=logging.DEBUG)
    # Set the environment variable DEBUG to 'true' for more detail in
    # CloudWatch Logs. Note that this sets the level to INFO, not DEBUG.
    if os.getenv('DEBUG') == 'true':
        root_logger.setLevel(logging.INFO)

    # SERVICE selects which ip-ranges "service" value to allow (default
    # CLOUDFRONT). List the possible services with:
    # curl -s 'https://ip-ranges.amazonaws.com/ip-ranges.json' | jq -r '.prefixes[] | .service' | sort -u
    service = os.getenv('SERVICE', 'CLOUDFRONT')

    message = json.loads(event['Records'][0]['Sns']['Message'])
    # Download the range document and verify it against the published MD5.
    ip_ranges = json.loads(get_ip_groups_json(message['url'], message['md5']))
    cf_ranges = get_ranges_for_service(ip_ranges, service)
    if not cf_ranges:
        # An empty list (e.g. a mistyped SERVICE) would make the sync revoke
        # every CIDR on the managed rules, so stop before touching anything.
        raise ValueError(
            'No {} ranges found in {}; refusing to update security groups'.format(
                service, message['url']))

    # Total number of rules required across all ports. It is stored at
    # module level, but nothing in this file reads it.
    NRANGES = len(cf_ranges) * len(INGRESS_PORTS)

    return update_security_groups(cf_ranges)
def get_ip_groups_json(url, expected_hash):
    """Download the IP-range document at ``url`` and verify its MD5 digest.

    Args:
        url: Location of the ip-ranges JSON document (taken from the SNS message).
        expected_hash: Hex MD5 digest the downloaded bytes must match.

    Returns:
        The raw response body as bytes.

    Raises:
        Exception: If the computed MD5 digest differs from ``expected_hash``.
    """
    logging.debug("Updating from %s", url)
    # Close the response deterministically. Lambda containers may be reused
    # across invocations, so leaked connections could otherwise accumulate.
    with urllib.request.urlopen(url) as response:
        ip_json = response.read()
    # The notification publishes an MD5 digest. It is used here as an
    # integrity check on the download, not as a security boundary.
    digest = hashlib.md5(ip_json).hexdigest()
    if digest != expected_hash:
        raise Exception('MD5 Mismatch: got ' + digest + ' expected ' + expected_hash)
    return ip_json
def get_ranges_for_service(ranges, service):
    """Return the IPv4 prefixes published for ``service``.

    Args:
        ranges: Parsed ip-ranges document; only its ``prefixes`` list is read.
        service: Service name, compared exactly against each prefix's ``service``.

    Returns:
        A list of CIDR strings in first-seen order, with duplicates removed.
    """
    service_ranges = []
    seen = set()
    for prefix in ranges['prefixes']:
        if prefix['service'] != service:
            continue
        cidr = prefix['ip_prefix']
        logging.info('Found %s region: %s range: %s', service, prefix['region'], cidr)
        # A CIDR listed more than once would otherwise be requested twice
        # when rules are added and counted twice toward the rule total.
        if cidr not in seen:
            seen.add(cidr)
            service_ranges.append(cidr)
    return service_ranges
def update_security_groups(new_ranges):
    """Sync every tagged security group with ``new_ranges``, once per ingress port.

    For each protocol label in INGRESS_PORTS, the security groups in REGION
    that carry every GLOBAL_SG_TAGS pair plus ``Protocol=<label>`` are looked
    up. Each one is then reconciled against ``new_ranges`` on that label's port.

    Args:
        new_ranges: CIDR strings that should be allowed in.

    Returns:
        A list of human-readable status strings: one per group processed, or a
        single "No groups ..." entry for a label with no matching groups.
    """
    client = boto3.client('ec2', region_name=REGION)
    result = []
    for protocol, port in INGRESS_PORTS.items():
        # Build a new dict rather than assigning into GLOBAL_SG_TAGS. That
        # dict is shared module state and survives across warm invocations.
        tag_to_find = dict(GLOBAL_SG_TAGS, Protocol=protocol)
        groups = get_security_groups_for_update(client, tag_to_find)
        msg = 'tagged Name: {}, Protocol: {} to update'.format(tag_to_find["Name"], protocol)
        logging.info('Found {} groups {}'.format(len(groups), msg))
        if not groups:
            result.append('No groups {}'.format(msg))
            logging.warning('No groups {}'.format(msg))
            continue
        for group in groups:
            changed = update_security_group(client, group, new_ranges, port)
            status = 'updated' if changed else 'unchanged'
            result.append('Security Group {} {}.'.format(group['GroupId'], status))
    return result
def update_security_group(client, group, new_ranges, port):
    """Reconcile one security group's TCP ingress CIDRs on ``port`` with ``new_ranges``.

    For every existing TCP rule whose port range covers ``port``, CIDRs not in
    ``new_ranges`` are revoked and missing ones are authorized. If the group
    has no such rule, all of ``new_ranges`` is authorized as a new single-port
    TCP rule.

    Args:
        client: boto3 EC2 client used for the revoke/authorize calls.
        group: One entry of the ``SecurityGroups`` list returned by
            ``describe_security_groups``.
        new_ranges: IPv4 CIDR strings that should be allowed in.
        port: TCP port the ranges should be allowed on.

    Returns:
        True if any range was added or revoked, otherwise False.
    """
    description = 'CloudFront CIDR-Range'
    # dict.fromkeys removes duplicates while keeping the caller's order. It
    # also gives O(1) membership tests instead of repeated list.count() scans.
    wanted = dict.fromkeys(new_ranges)
    added = 0
    removed = 0
    covering_rule_found = False
    for permission in group['IpPermissions']:
        # Only TCP rules are managed. This also skips all-traffic ('-1')
        # rules, which have no FromPort/ToPort keys.
        if permission.get('IpProtocol') not in ('tcp', '6'):
            continue
        if not permission['FromPort'] <= port <= permission['ToPort']:
            continue
        covering_rule_found = True
        current = set()
        to_revoke = []
        for ip_range in permission.get('IpRanges', []):
            cidr = ip_range['CidrIp']
            current.add(cidr)
            if cidr not in wanted:
                to_revoke.append(ip_range)
                logging.debug(group['GroupId'] + ": Revoking " + cidr + ":" + str(permission['ToPort']))
        to_add = []
        for cidr in wanted:
            if cidr not in current:
                to_add.append({'CidrIp': cidr, 'Description': description})
                logging.debug(group['GroupId'] + ": Adding " + cidr + ":" + str(permission['ToPort']))
        removed += revoke_permissions(client, group, permission, to_revoke)
        added += add_permissions(client, group, permission, to_add)
    if not covering_rule_found:
        # Either the group has no ingress rules or none covers the port.
        # Create a dedicated rule carrying every wanted range.
        to_add = []
        for cidr in wanted:
            to_add.append({'CidrIp': cidr, 'Description': description})
            logging.info(group['GroupId'] + ": Adding " + cidr + ":" + str(port))
        permission = {'ToPort': port, 'FromPort': port, 'IpProtocol': 'tcp'}
        added += add_permissions(client, group, permission, to_add)
    logging.debug(group['GroupId'] + ": Added " + str(added) + ", Revoked " + str(removed))
    return added > 0 or removed > 0
def revoke_permissions(client, group, permission, to_revoke):
    """Revoke ``to_revoke`` from ``group`` on ``permission``'s protocol and ports.

    No API call is made when ``to_revoke`` is empty.

    Args:
        client: boto3 EC2 client.
        group: Security group description; only ``GroupId`` is read.
        permission: Rule supplying ``IpProtocol``, ``FromPort`` and ``ToPort``.
        to_revoke: ``IpRanges`` entries (dicts with ``CidrIp``) to remove.

    Returns:
        The number of ranges requested for revocation.
    """
    count = len(to_revoke)
    if count:
        rule = {
            'ToPort': permission['ToPort'],
            'FromPort': permission['FromPort'],
            'IpRanges': to_revoke,
            'IpProtocol': permission['IpProtocol'],
        }
        client.revoke_security_group_ingress(GroupId=group['GroupId'], IpPermissions=[rule])
    return count
def add_permissions(client, group, permission, to_add):
    """Authorize ``to_add`` on ``group`` with ``permission``'s protocol and ports.

    No API call is made when ``to_add`` is empty.

    Args:
        client: boto3 EC2 client.
        group: Security group description; only ``GroupId`` is read.
        permission: Rule supplying ``IpProtocol``, ``FromPort`` and ``ToPort``.
        to_add: ``IpRanges`` entries (dicts with ``CidrIp``) to authorize.

    Returns:
        The number of ranges requested for authorization.
    """
    request_count = len(to_add)
    if request_count == 0:
        return request_count
    ingress_rule = dict(
        ToPort=permission['ToPort'],
        FromPort=permission['FromPort'],
        IpRanges=to_add,
        IpProtocol=permission['IpProtocol'],
    )
    client.authorize_security_group_ingress(GroupId=group['GroupId'], IpPermissions=[ingress_rule])
    return request_count
def get_security_groups_for_update(client, security_group_tag):
    """Return every security group carrying all of the given tag key/value pairs.

    Args:
        client: boto3 EC2 client.
        security_group_tag: Mapping of tag key -> required tag value.

    Returns:
        The ``SecurityGroups`` entries from ``describe_security_groups``,
        accumulated across all result pages.
    """
    # A ``tag:<key>`` filter binds the value to that specific key. Separate
    # ``tag-key`` and ``tag-value`` filters are each satisfied independently,
    # so a group could match using values that belong to different keys.
    filters = [
        {'Name': 'tag:' + key, 'Values': [value]}
        for key, value in security_group_tag.items()
    ]
    request = {'Filters': filters}
    groups = []
    while True:
        response = client.describe_security_groups(**request)
        groups.extend(response['SecurityGroups'])
        next_token = response.get('NextToken')
        if not next_token:
            return groups
        request['NextToken'] = next_token
# This is a handy test event you can use when testing your lambda function.
# NOTE(review): the "md5" inside "Message" is a fixed sample value.
# lambda_handler compares it against the MD5 of the live ip-ranges.json
# download, so a test invoke will most likely fail with "MD5 Mismatch" unless
# it is replaced with the file's current MD5.
'''
Sample Event From SNS:
{
"Records": [
{
"EventVersion": "1.0",
"EventSubscriptionArn": "arn:aws:sns:EXAMPLE",
"EventSource": "aws:sns",
"Sns": {
"SignatureVersion": "1",
"Timestamp": "1970-01-01T00:00:00.000Z",
"Signature": "EXAMPLE",
"SigningCertUrl": "EXAMPLE",
"MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
"Message": "{\"create-time\": \"yyyy-mm-ddThh:mm:ss+00:00\", \"synctoken\": \"0123456789\", \"md5\": \"45be1ba64fe83acb7ef247bccbc45704\", \"url\": \"https://ip-ranges.amazonaws.com/ip-ranges.json\"}",
"Type": "Notification",
"UnsubscribeUrl": "EXAMPLE",
"TopicArn": "arn:aws:sns:EXAMPLE",
"Subject": "TestInvoke"
}
}
]
}
'''
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment