Created
March 5, 2018 21:31
-
-
Save cwells/293b1d0da73943de59242c0ef5e29ae2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import os | |
import re | |
import logging | |
import boto3 | |
from queue import Queue | |
from threading import Thread | |
from datetime import datetime, timedelta, timezone | |
# Send INFO-and-above records from this module through the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Raise the threshold of the urllib3 connection-pool logger to WARNING.
logging.getLogger(
    "requests.packages.urllib3.connectionpool"
).setLevel(logging.WARNING)

# Policy spec used when an instance has no SnapshotPolicy tag.
DEFAULT_SNAPSHOT_POLICY = 'daily: 7'
def boto_client(*args, **kwargs):
    '''create a boto3 client and try to raise its retry limit to 20 attempts

    All positional and keyword arguments are forwarded unchanged to
    ``boto3.client`` and the resulting client is returned.

    The retry limit is raised by writing to private botocore attributes
    under the hard-coded 'retry-config-ec2' handler id.  That layout is not
    a public API and the id only exists for EC2 clients, so when it is
    missing the client is still returned with whatever retry settings it
    already has, and a warning is logged instead of failing the caller.
    '''
    client = boto3.client(*args, **kwargs)
    try:
        handlers = client.meta.events._unique_id_handlers
        checker = handlers['retry-config-ec2']['handler']._checker
        checker.__dict__['_max_attempts'] = 20
    except (AttributeError, KeyError, TypeError):
        # Best effort only: a client with default retries beats no client.
        logger.warning(
            "Could not raise the retry limit on the boto3 client; "
            "keeping its existing retry settings",
            exc_info=True,
        )
    return client
def tag_dict(obj):
    '''flatten the API's list of {'Key': ..., 'Value': ...} tags into a dict

    a resource with no 'Tags' entry yields an empty dictionary
    '''
    flattened = {}
    for entry in obj.get('Tags', []):
        flattened[entry['Key']] = entry['Value']
    return flattened
def parse_policy(policy):
    '''convert a policy spec into a dictionary

    A spec is a list of ``name: count`` pairs, e.g. ``'daily: 7, weekly: 4'``,
    and is returned as ``{'daily': 7, 'weekly': 4}``.

    Any amount of whitespace is accepted around the colon, and interval
    names are matched case-insensitively and returned in lower case, so
    ``'Daily:  7'`` parses to ``{'daily': 7}``.  Text that does not form a
    ``name: count`` pair is ignored; an empty spec yields an empty dict.
    '''
    pairs = re.findall(r'([a-z]+)\s*:\s*(\d+)', policy, flags=re.IGNORECASE)
    return {name.lower(): int(count) for name, count in pairs}
def instances_for_region(region, interval):
    '''get all instances for this region that have a policy for the interval

    This is a generator of instance dictionaries as returned by
    ``describe_instances``.  The server-side filter is only a wildcard match
    of the interval name inside the SnapshotPolicy tag, so each candidate's
    policy is parsed and the instance is yielded only if the interval is
    really one of its keys.

    Results are read through the client's paginator so instances beyond the
    first response page are not missed.

    NOTE(review): the tag filter means untagged instances are never returned,
    so the DEFAULT_SNAPSHOT_POLICY fallback below is not expected to be hit.
    '''
    ec2 = boto_client('ec2', region_name=region)
    filters = [
        {'Name': 'tag:SnapshotPolicy', 'Values': ['*{}*'.format(interval)]},
    ]
    pages = ec2.get_paginator('describe_instances').paginate(Filters=filters)
    for page in pages:
        for reservation in page.get('Reservations', []):
            for instance in reservation['Instances']:
                tags = tag_dict(instance)
                spec = tags.get('SnapshotPolicy', DEFAULT_SNAPSHOT_POLICY)
                if interval in parse_policy(spec):
                    yield instance
def _expired_volume_snapshots(paginator, volume_id, interval, keep):
    '''return this volume's snapshots for the interval beyond the newest `keep`
    '''
    snapshots = []
    pages = paginator.paginate(
        OwnerIds=['self'],
        Filters=[
            {'Name': 'volume-id', 'Values': [volume_id]},
            {'Name': 'tag:SnapshotInterval', 'Values': [interval]},
        ],
    )
    for page in pages:
        snapshots.extend(page.get('Snapshots', []))
    # Newest first, so everything after the first `keep` entries is expired.
    snapshots.sort(key=lambda s: s['StartTime'], reverse=True)
    return snapshots[keep:]


def expired_snapshots_for_region(region, interval):
    '''get all snapshots in this region that exceed the retention count

    This is a generator of snapshot dictionaries as returned by
    ``describe_snapshots``.  For every instance whose SnapshotPolicy covers
    the interval, each attached EBS volume's snapshots tagged with
    ``SnapshotInterval == interval`` are sorted newest first; the policy's
    count for the interval is how many are kept and the rest are yielded.

    The server-side tag filter is only a wildcard match, so instances whose
    parsed policy does not contain the interval are skipped rather than
    raising KeyError.  Both listings are read through paginators so results
    beyond the first response page are included.
    '''
    ec2 = boto_client('ec2', region_name=region)
    filters = [
        {'Name': 'tag:SnapshotPolicy', 'Values': ['*{}*'.format(interval)]},
    ]
    instance_pages = ec2.get_paginator('describe_instances').paginate(Filters=filters)
    snapshot_paginator = ec2.get_paginator('describe_snapshots')
    for page in instance_pages:
        for reservation in page.get('Reservations', []):
            for instance in reservation['Instances']:
                tags = tag_dict(instance)
                spec = tags.get('SnapshotPolicy', DEFAULT_SNAPSHOT_POLICY)
                policy = parse_policy(spec)
                if interval not in policy:
                    continue
                keep = policy[interval]
                for dev in instance.get('BlockDeviceMappings', []):
                    ebs = dev.get('Ebs')
                    if ebs is None:
                        continue
                    expired = _expired_volume_snapshots(
                        snapshot_paginator, ebs['VolumeId'], interval, keep
                    )
                    for snapshot in expired:
                        yield snapshot
def get_instances(regions, interval, queue):
    '''producer for the create queue

    puts one (region, instance) tuple on the queue for every instance that
    instances_for_region yields, region by region
    '''
    for region_name in regions:
        matching = instances_for_region(region_name, interval=interval)
        for found in matching:
            work_item = (region_name, found)
            queue.put(work_item)
def get_snapshots(regions, interval, queue):
    '''producer for the delete queue

    puts one (region, snapshot) tuple on the queue for every snapshot that
    expired_snapshots_for_region yields, region by region
    '''
    for region_name in regions:
        expired = expired_snapshots_for_region(region_name, interval=interval)
        for old_snapshot in expired:
            work_item = (region_name, old_snapshot)
            queue.put(work_item)
def _snapshot_instance(region, instance, interval):
    '''snapshot and tag every EBS volume attached to one instance
    '''
    ec2 = boto_client('ec2', region_name=region)
    tags = tag_dict(instance)
    instance_id = instance['InstanceId']
    logger.info("Creating snapshot of instance {} in {}".format(instance_id, region))
    snapshot_description = '{} snapshot of {} on {}'.format(
        interval.capitalize(),
        instance_id,
        datetime.now(timezone.utc).strftime('%Y-%m-%d')
    )
    # Selected instance tags are copied onto the snapshot; SnapshotInterval
    # is the tag the expiry pass filters on.
    snapshot_tags = [
        {'Key': 'Name', 'Value': tags.get('Name', '')},
        {'Key': 'ClientName', 'Value': tags.get('ClientName', '')},
        {'Key': 'Platform', 'Value': tags.get('Platform', '')},
        {'Key': 'Tenancy', 'Value': tags.get('Tenancy', '')},
        {'Key': 'SnapshotInterval', 'Value': interval},
    ]
    for dev in instance['BlockDeviceMappings']:
        if dev.get('Ebs', None) is None:
            continue
        try:
            snapshot = ec2.create_snapshot(
                VolumeId=dev['Ebs']['VolumeId'],
                Description=snapshot_description,
            )
        except Exception:
            logger.exception("Failed to create snapshot for {} in {}".format(instance_id, region))
            continue
        logger.info("Created snapshot for {} in {}".format(instance_id, region))
        try:
            ec2.create_tags(Resources=[snapshot['SnapshotId']], Tags=snapshot_tags)
        except Exception:
            # Without the SnapshotInterval tag the expiry filter will not
            # find this snapshot, so make the failure visible.
            logger.exception("Failed to tag snapshot {} for {} in {}".format(
                snapshot.get('SnapshotId'), instance_id, region))


def create_snapshot(queue, interval):
    '''consumer that drains the create queue

    Loops forever taking (region, instance) tuples from the queue and
    snapshotting each EBS volume of the instance.  Failures are logged with
    their traceback and the worker moves on to the next item.

    ``queue.task_done()`` is called in a ``finally`` clause so a failing item
    can never leave ``queue.join()`` waiting forever.
    '''
    while True:
        region, instance = queue.get()
        try:
            _snapshot_instance(region, instance, interval)
        except Exception:
            logger.exception("Unexpected error while snapshotting an instance in {}".format(region))
        finally:
            queue.task_done()
def delete_snapshot(queue, interval):
    '''consumer that drains the delete queue

    Loops forever taking (region, snapshot) tuples from the queue and
    deleting each snapshot.  A failed deletion is logged as a warning with
    its traceback and the worker moves on to the next item.

    ``queue.task_done()`` is called in a ``finally`` clause so a failing item
    can never leave ``queue.join()`` waiting forever.  ``interval`` is not
    used here; it is accepted so both consumers share one signature.
    '''
    while True:
        region, snapshot = queue.get()
        try:
            ec2 = boto_client('ec2', region_name=region)
            logger.info("Deleting snapshot {} created on {} in {}".format(snapshot['SnapshotId'], snapshot['StartTime'], region))
            ec2.delete_snapshot(SnapshotId=snapshot['SnapshotId'])
        except Exception:
            logger.warning(
                "Can't delete snapshot {}".format(snapshot.get('SnapshotId')),
                exc_info=True,
            )
        finally:
            queue.task_done()
def lambda_handler(event, context):
    '''start up threadpool, and populate the queue with snapshots

    Configuration is read from environment variables: ``interval`` (default
    'daily') names the policy interval to process and ``threads`` (default 4)
    is the number of create workers and of delete workers.  ``event`` and
    ``context`` are not used.

    The two producers run to completion first, filling unbounded queues, and
    only then are the consumers started.  Consumers loop forever, so they
    are daemon threads and completion is detected by joining the queues.
    '''
    interval = os.environ.get('interval', 'daily')
    # At least one worker per queue; with none, Queue.join() never returns.
    num_threads = max(1, int(os.environ.get('threads', 4)))
    logger.info("Processing {} interval".format(interval))

    ec2 = boto3.client('ec2')
    regions = [region['RegionName'] for region in ec2.describe_regions()['Regions']]

    create_queue = Queue(maxsize=0)
    delete_queue = Queue(maxsize=0)

    producers = [
        Thread(target=get_instances, args=(regions, interval, create_queue)),
        Thread(target=get_snapshots, args=(regions, interval, delete_queue)),
    ]
    for producer in producers:
        producer.start()
    for producer in producers:
        producer.join()

    workers = (
        (create_snapshot, create_queue),
        (delete_snapshot, delete_queue),
    )
    for _ in range(num_threads):
        for target, work_queue in workers:
            # daemon=True replaces the deprecated Thread.setDaemon(True).
            Thread(target=target, args=(work_queue, interval), daemon=True).start()

    create_queue.join()
    delete_queue.join()
# Run the handler directly when this file is executed as a script.
if __name__ == "__main__":
    lambda_handler(event=None, context=None)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment