Skip to content

Instantly share code, notes, and snippets.

@luketn
Created March 7, 2019 06:06
Show Gist options
  • Save luketn/15266eb13e87746491c61e5a5a22c8d2 to your computer and use it in GitHub Desktop.
Save luketn/15266eb13e87746491c61e5a5a22c8d2 to your computer and use it in GitHub Desktop.
'''
Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
http://aws.amazon.com/apache2.0/
or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
'''
import boto3
import hashlib
import json
import urllib2
# Service whose ranges are extracted, as spelled in the "service" field of ip-ranges.json
SERVICE = "CLOUDFRONT"

# Application ports that must accept inbound traffic from the service's ranges
INGRESS_PORTS = dict(Http=80, Https=443)

# Tag sets selecting the security groups to keep in sync. update_security_groups
# feeds the GLOBAL ranges to the "cloudfront_g" groups and the regional ranges to
# the "cloudfront_r" groups; the *_HTTP / *_HTTPS variants are opened on
# INGRESS_PORTS['Http'] / INGRESS_PORTS['Https'] respectively.
SECURITY_GROUP_TAG_FOR_GLOBAL_HTTP = dict(Name='cloudfront_g', AutoUpdate='true', Protocol='http')
SECURITY_GROUP_TAG_FOR_GLOBAL_HTTPS = dict(Name='cloudfront_g', AutoUpdate='true', Protocol='https')
SECURITY_GROUP_TAG_FOR_REGION_HTTP = dict(Name='cloudfront_r', AutoUpdate='true', Protocol='http')
SECURITY_GROUP_TAG_FOR_REGION_HTTPS = dict(Name='cloudfront_r', AutoUpdate='true', Protocol='https')

# Bucket tag key; buckets carrying it get a policy that limits s3:GetObject to the service's ranges
S3_BUCKET_TAG = 'cloudfront_only'
def lambda_handler(event, context):
    """Handler entry point: refresh security groups and bucket policies from ip-ranges.

    Event sources:
      * EC2 CloudTrail events: processed only when is_scalinggroup_or_sg_tag
        accepts them (otherwise returns None immediately); the default
        ip-ranges URL is fetched without hash verification.
      * SNS notifications (events with 'Records'): the URL and MD5 hash are
        taken from the JSON message of the first record.
      * Anything else: the default URL, without hash verification.

    Returns:
        The list of status strings produced by update_security_groups, or
        None when a CloudTrail event is ignored.
    """
    print("Received event: " + json.dumps(event, indent=2))

    # Defaults, overridden only by an SNS notification.
    url = "https://ip-ranges.amazonaws.com/ip-ranges.json"
    expected_hash = None

    if is_cloudtrail_event(event):
        if not is_scalinggroup_or_sg_tag(event):
            return
    elif 'Records' in event:
        sns_message = json.loads(event['Records'][0]['Sns']['Message'])
        url = sns_message['url']
        expected_hash = sns_message['md5']

    published = json.loads(get_ip_groups_json(url, expected_hash))

    # GLOBAL is resolved before REGION, matching the order of the log lines.
    ranges_by_subset = {
        subset: get_ranges_for_service(published, SERVICE, subset)
        for subset in ("GLOBAL", "REGION")
    }

    outcome = update_security_groups(ranges_by_subset)
    update_s3_buckets(ranges_by_subset)
    return outcome
def get_ip_groups_json(url, expected_hash):
    """Download the ip-ranges document from ``url``, optionally verifying its MD5.

    Args:
        url: address of the ip-ranges JSON document.
        expected_hash: hex MD5 digest to compare against (lambda_handler takes
            it from the SNS message), or None to skip verification.

    Returns:
        The raw response body exactly as returned by the response's ``read()``.

    Raises:
        Exception: if ``expected_hash`` is given and does not match the body.
    """
    # Prefer Python 3's urllib.request; fall back to urllib2 on Python 2.
    try:
        from urllib.request import urlopen
    except ImportError:
        from urllib2 import urlopen

    print("Updating from " + url)
    response = urlopen(url)
    try:
        ip_json = response.read()
    finally:
        # Always release the connection, even if reading fails.
        response.close()

    if expected_hash is not None:
        actual_hash = hashlib.md5(ip_json).hexdigest()
        if actual_hash != expected_hash:
            raise Exception('MD5 Mismatch: got ' + actual_hash + ' expected ' + expected_hash)
    return ip_json
def get_ranges_for_service(ranges, service, subset):
    """Collect the ``ip_prefix`` values published for ``service``.

    ``subset`` chooses which of the service's entries are kept: "GLOBAL"
    keeps entries whose region is "GLOBAL"; any other value keeps entries
    whose region is *not* "GLOBAL". Every kept prefix is logged, and the
    result preserves the order of ``ranges['prefixes']``.
    """
    wants_global = subset == "GLOBAL"
    selected = []
    for entry in ranges['prefixes']:
        if entry['service'] != service:
            continue
        region = entry['region']
        # Keep the entry only when its GLOBAL-ness agrees with the requested subset.
        if (region == 'GLOBAL') != wants_global:
            continue
        print('Found ' + service + ' region: ' + region + ' range: ' + entry['ip_prefix'])
        selected.append(entry['ip_prefix'])
    return selected
def _cloudfront_only_policy(bucket_name, new_ranges):
    """Build the bucket policy allowing s3:GetObject only from the given ranges.

    ``new_ranges`` is the {"GLOBAL": [...], "REGION": [...]} mapping built by
    lambda_handler; both lists are concatenated into one aws:SourceIp condition.
    """
    return {
        "Version": "2012-10-17",
        "Id": "S3PolicyForBucket",
        "Statement": [
            {
                "Sid": "IPAllow",
                "Effect": "Allow",
                "Principal": "*",
                "Action": "s3:GetObject",
                "Resource": [
                    "arn:aws:s3:::" + bucket_name + "/*",
                    "arn:aws:s3:::" + bucket_name
                ],
                "Condition": {
                    "IpAddress": {
                        "aws:SourceIp": new_ranges["GLOBAL"] + new_ranges["REGION"]
                    }
                }
            }
        ]
    }


def update_s3_buckets(new_ranges):
    """Overwrite the policy of every bucket tagged with S3_BUCKET_TAG.

    Buckets are handled independently. A bucket with no tag set, or whose tags
    or policy cannot be read/written, is logged and skipped so the remaining
    buckets are still processed.

    Args:
        new_ranges: dict with "GLOBAL" and "REGION" lists of CIDR strings.
    """
    client = boto3.client('s3')
    # Service-side S3 errors (missing tag set, access denied, ...) surface as this type.
    client_error = client.exceptions.ClientError

    for bucket in client.list_buckets()['Buckets']:
        bucket_name = bucket['Name']
        try:
            tag_set = client.get_bucket_tagging(Bucket=bucket_name)['TagSet']
        except client_error as e:
            if e.response.get('Error', {}).get('Code') == 'NoSuchTagSet':
                print('No tags for ' + bucket_name)
            else:
                print('Failed to read tags for ' + bucket_name + ':')
                print(e)
            continue

        if not any(tag['Key'] == S3_BUCKET_TAG for tag in tag_set):
            continue

        policy_string = json.dumps(_cloudfront_only_policy(bucket_name, new_ranges))
        print('Applying policy to bucket ' + bucket_name + ':')
        print(policy_string)
        try:
            client.put_bucket_policy(Bucket=bucket_name, Policy=policy_string)
        except client_error as e:
            print('Failed to update bucket policy.')
            print(e)
def update_security_groups(new_ranges):
    """Sync every tagged security-group family and return a log of the changes.

    Four families are handled in a fixed order: global HTTP, global HTTPS,
    regional HTTP and regional HTTPS. Each is located by its tag set, and each
    group in it is synced against new_ranges["GLOBAL"] or new_ranges["REGION"]
    on the matching INGRESS_PORTS entry.

    Returns:
        A list of strings: one 'Updated <GroupId>' per changed group (in
        processing order), followed by one summary line per family.
    """
    ec2 = boto3.client('ec2')

    # (tag filter, summary label, range subset, INGRESS_PORTS key)
    families = [
        (SECURITY_GROUP_TAG_FOR_GLOBAL_HTTP, ' CloudFront_g HttpSecurityGroups', "GLOBAL", 'Http'),
        (SECURITY_GROUP_TAG_FOR_GLOBAL_HTTPS, ' CloudFront_g HttpsSecurityGroups', "GLOBAL", 'Https'),
        (SECURITY_GROUP_TAG_FOR_REGION_HTTP, ' CloudFront_r HttpSecurityGroups', "REGION", 'Http'),
        (SECURITY_GROUP_TAG_FOR_REGION_HTTPS, ' CloudFront_r HttpsSecurityGroups', "REGION", 'Https'),
    ]

    # Look up every family before touching any of them.
    found = [get_security_groups_for_update(ec2, tags) for tags, _, _, _ in families]
    for groups, (_, label, _, _) in zip(found, families):
        print('Found ' + str(len(groups)) + label + ' to update')

    log = []
    changed_counts = []
    for groups, (_, _, subset, port_key) in zip(found, families):
        changed = 0
        for group in groups:
            if update_security_group(ec2, group, new_ranges[subset], INGRESS_PORTS[port_key]):
                changed += 1
                log.append('Updated ' + group['GroupId'])
        changed_counts.append(changed)

    for changed, groups, (_, label, _, _) in zip(changed_counts, found, families):
        log.append('Updated ' + str(changed) + ' of ' + str(len(groups)) + label)
    return log
def update_security_group(client, group, new_ranges, port):
    """Make the TCP ingress for ``port`` on ``group`` allow exactly ``new_ranges``.

    Every TCP permission whose FromPort..ToPort span covers ``port`` is synced:
    CIDRs missing from ``new_ranges`` are revoked, and missing ranges are
    authorized on that same permission. If the group has no such permission
    (including a group with other, unrelated rules), a new single-port TCP
    permission carrying all of ``new_ranges`` is added.

    Args:
        client: boto3 EC2 client.
        group: a security group dict as returned by describe_security_groups.
        new_ranges: iterable of IPv4 CIDR strings that should be allowed.
        port: the port to open.

    Returns:
        True if any range was added or revoked, False otherwise.
    """
    # De-duplicate (keeping first-seen order) so a repeated CIDR is not sent
    # twice in one authorize request or double-counted in ``added``.
    wanted = []
    wanted_set = set()
    for cidr in new_ranges:
        if cidr not in wanted_set:
            wanted_set.add(cidr)
            wanted.append(cidr)

    added = 0
    removed = 0
    port_rule_found = False
    for permission in group['IpPermissions']:
        from_port = permission.get('FromPort')
        to_port = permission.get('ToPort')
        # Only TCP rules carry real port numbers here: all-protocol ('-1') rules
        # have no FromPort/ToPort, and ICMP rules store type/code in those fields.
        if (permission.get('IpProtocol') not in ('tcp', '6')
                or from_port is None or to_port is None
                or not from_port <= port <= to_port):
            continue
        port_rule_found = True

        existing = set()
        to_revoke = []
        for ip_range in permission['IpRanges']:
            cidr = ip_range['CidrIp']
            existing.add(cidr)
            if cidr not in wanted_set:
                to_revoke.append(ip_range)
                print(group['GroupId'] + ": Revoking " + cidr + ":" + str(permission['ToPort']))

        to_add = []
        for cidr in wanted:
            if cidr not in existing:
                to_add.append({'CidrIp': cidr})
                print(group['GroupId'] + ": Adding " + cidr + ":" + str(permission['ToPort']))

        removed += revoke_permissions(client, group, permission, to_revoke)
        added += add_permissions(client, group, permission, to_add)

    if not port_rule_found:
        to_add = []
        for cidr in wanted:
            to_add.append({'CidrIp': cidr})
            print(group['GroupId'] + ": Adding " + cidr + ":" + str(port))
        permission = {'ToPort': port, 'FromPort': port, 'IpProtocol': 'tcp'}
        added += add_permissions(client, group, permission, to_add)

    print(group['GroupId'] + ": Added " + str(added) + ", Revoked " + str(removed))
    return added > 0 or removed > 0
def revoke_permissions(client, group, permission, to_revoke):
    """Revoke ``to_revoke`` from ``group`` using the protocol/port span of ``permission``.

    No API call is made when ``to_revoke`` is empty.

    Returns:
        The number of ranges requested for revocation.
    """
    if not to_revoke:
        return 0
    rule = dict(
        ToPort=permission['ToPort'],
        FromPort=permission['FromPort'],
        IpRanges=to_revoke,
        IpProtocol=permission['IpProtocol'],
    )
    client.revoke_security_group_ingress(GroupId=group['GroupId'], IpPermissions=[rule])
    return len(to_revoke)
def add_permissions(client, group, permission, to_add):
    """Authorize ``to_add`` on ``group`` using the protocol/port span of ``permission``.

    No API call is made when ``to_add`` is empty.

    Returns:
        The number of ranges requested for authorization.
    """
    count = len(to_add)
    if count:
        rule = dict(
            ToPort=permission['ToPort'],
            FromPort=permission['FromPort'],
            IpRanges=to_add,
            IpProtocol=permission['IpProtocol'],
        )
        client.authorize_security_group_ingress(GroupId=group['GroupId'], IpPermissions=[rule])
    return count
def get_security_groups_for_update(client, security_group_tag):
    """Return the security groups that carry every tag in ``security_group_tag``.

    Args:
        client: boto3 EC2 client.
        security_group_tag: mapping of tag key -> required tag value.

    Returns:
        The 'SecurityGroups' list from describe_security_groups.
    """
    # One 'tag:<key>' filter per pair ties each value to its own key. Separate
    # 'tag-key' / 'tag-value' filters could match a key from one tag and a
    # value from a different tag. EC2 ANDs multiple filters together.
    filters = [
        {'Name': 'tag:' + key, 'Values': [value]}
        for key, value in security_group_tag.items()  # .items() works on Python 2 and 3
    ]
    response = client.describe_security_groups(Filters=filters)
    return response['SecurityGroups']
def is_cloudtrail_event(event):
    """Return True for an EC2 "AWS API Call via CloudTrail" event that has a 'detail' key."""
    required = (
        ("source", "aws.ec2"),
        ("detail-type", "AWS API Call via CloudTrail"),
    )
    for key, expected in required:
        if key not in event or event[key] != expected:
            return False
    return "detail" in event
def is_scalinggroup_or_sg_tag(event):
    """Return True if a CloudTrail EC2 event should trigger a security-group refresh.

    Despite the name, only security-group events are recognised:
      * CreateTags, when at least one tagged resource id starts with "sg-";
      * CreateSecurityGroup, always (even if the group name cannot be read).

    A malformed CreateTags payload is logged and treated as "no trigger".
    Non-CloudTrail events and all other event names return False.
    """
    if not is_cloudtrail_event(event):
        return False
    detail = event["detail"]
    if "eventName" not in detail:
        return False
    event_name = detail["eventName"]

    if event_name == "CreateTags":
        print('Tag created event...')
        try:
            for resource in detail["requestParameters"]["resourcesSet"]["items"]:
                print('Tag created on resource: ' + resource["resourceId"])
                if resource["resourceId"].startswith("sg-"):
                    return True
        except (KeyError, TypeError, AttributeError) as e:
            # Missing keys, null requestParameters or non-string ids.
            print('Error attempting to parse CreateTags event:')
            print(e)
    elif event_name == "CreateSecurityGroup":
        print('Security group created...')
        try:
            print('Security group name: ' + detail["requestParameters"]["groupName"])
        except (KeyError, TypeError) as e:
            print('Error attempting to parse CreateSecurityGroup event:')
            print(e)
        # The event name alone confirms a new group, so refresh even if the
        # name was unreadable.
        return True
    return False
'''
Sample Event From SNS:
{
"Records": [
{
"EventVersion": "1.0",
"EventSubscriptionArn": "arn:aws:sns:EXAMPLE",
"EventSource": "aws:sns",
"Sns": {
"SignatureVersion": "1",
"Timestamp": "1970-01-01T00:00:00.000Z",
"Signature": "EXAMPLE",
"SigningCertUrl": "EXAMPLE",
"MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
"Message": "{\"create-time\": \"yyyy-mm-ddThh:mm:ss+00:00\", \"synctoken\": \"0123456789\", \"md5\": \"45be1ba64fe83acb7ef247bccbc45704\", \"url\": \"https://ip-ranges.amazonaws.com/ip-ranges.json\"}",
"Type": "Notification",
"UnsubscribeUrl": "EXAMPLE",
"TopicArn": "arn:aws:sns:EXAMPLE",
"Subject": "TestInvoke"
}
}
]
}
'''
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment