Skip to content

Instantly share code, notes, and snippets.

@cwells
Created March 5, 2018 21:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cwells/293b1d0da73943de59242c0ef5e29ae2 to your computer and use it in GitHub Desktop.
Save cwells/293b1d0da73943de59242c0ef5e29ae2 to your computer and use it in GitHub Desktop.
#!/usr/bin/python3
import os
import re
import logging
import threading
from queue import Queue
from threading import Thread
from datetime import datetime, timedelta, timezone

import boto3
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# In AWS Lambda the runtime has already attached a handler to the root
# logger, so basicConfig() above is a no-op there and the root level stays
# at WARNING; set the level on this module's logger so INFO messages are
# actually emitted.
logger.setLevel(logging.INFO)
# Quieten per-connection chatter from the HTTP stack.  The first name is the
# historical requests-vendored logger; current botocore logs through the
# standalone urllib3 package.
logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARNING)
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)
# Policy assumed for instances that carry no SnapshotPolicy tag
# (format: "interval: count" pairs, optionally comma-separated).
DEFAULT_SNAPSHOT_POLICY = 'daily: 7'
# One boto3 Session per thread: the default session behind boto3.client()
# is not thread-safe, and clients are created from several threads here.
_boto_sessions = threading.local()


def boto_client(*args, **kwargs):
    '''Create a boto3 client with a raised retry limit for the EC2 API.

    Arguments are passed straight through to ``Session.client`` (same
    signature as ``boto3.client``).  The client is built from a Session
    private to the calling thread, so concurrent producer/worker threads do
    not share mutable session state.

    The retry limit is raised to 20 attempts by patching botocore's private
    retry handler registered under ``'retry-config-ec2'``.  That handler only
    exists for EC2 clients using the legacy retry machinery and its
    internals are not a public API, so if the patch cannot be applied the
    client is returned with the SDK's default retry settings instead of
    raising.  TODO(review): prefer
    ``config=botocore.config.Config(retries={'max_attempts': 20})``.
    '''
    session = getattr(_boto_sessions, 'session', None)
    if session is None:
        session = _boto_sessions.session = boto3.session.Session()
    client = session.client(*args, **kwargs)
    try:
        handler = client.meta.events._unique_id_handlers['retry-config-ec2']['handler']
        handler._checker.__dict__['_max_attempts'] = 20
    except (KeyError, AttributeError, TypeError):
        logger.warning("Could not raise the retry limit; using default retry settings")
    return client
def tag_dict(obj):
    '''Flatten an AWS resource's ``Tags`` list into a plain ``{Key: Value}`` dict.

    ``obj`` is an API response item (e.g. an instance) that may contain a
    ``Tags`` entry shaped like ``[{'Key': ..., 'Value': ...}, ...]``.  A
    missing ``Tags`` entry yields an empty dict; if a key repeats, the last
    value wins.
    '''
    pairs = ((tag['Key'], tag['Value']) for tag in obj.get('Tags', []))
    return dict(pairs)
def parse_policy(policy):
    '''Convert a policy spec such as ``'daily: 7, weekly: 4'`` into a dict.

    The spec is a sequence of ``name: count`` pairs, optionally separated by
    commas.  Names are matched case-insensitively and returned lower-cased,
    any amount of whitespace is allowed around ``:`` and ``,``, and text that
    does not fit the pattern is ignored.  If a name repeats, the last count
    wins.  ``None`` or an empty string yields ``{}``.

    >>> parse_policy('daily: 7, weekly: 4')
    {'daily': 7, 'weekly': 4}
    '''
    if not policy:
        return {}
    # [A-Za-z] (not [a-z]) so 'Daily: 7' is not mis-read as {'aily': 7};
    # \s* (not \s?) so doubled spaces do not drop an entry.
    matches = re.findall(r'\s*([A-Za-z]+)\s*:\s*(\d+)\s*,?', policy)
    return {name.lower(): int(count) for name, count in matches}
def instances_for_region(region, interval):
    '''Yield every instance in ``region`` whose SnapshotPolicy covers ``interval``.

    The EC2 tag filter ``*<interval>*`` is only a coarse server-side
    pre-filter (``*weekly*`` also matches ``biweekly: 2``), so each candidate
    is re-checked by parsing its policy.  ``describe_instances`` results are
    paginated, so every page is walked and large regions are processed
    completely.

    Note: the server-side filter requires a ``SnapshotPolicy`` tag to be
    present, so untagged instances are never returned and the
    ``DEFAULT_SNAPSHOT_POLICY`` fallback below is effectively unreachable
    unless that filter is relaxed.
    '''
    def has_interval(instance):
        snapshot_policy = tag_dict(instance).get('SnapshotPolicy', DEFAULT_SNAPSHOT_POLICY)
        return interval in parse_policy(snapshot_policy)

    ec2 = boto_client('ec2', region_name=region)
    filters = [
        {'Name': 'tag:SnapshotPolicy', 'Values': ['*{}*'.format(interval)]},
    ]
    paginator = ec2.get_paginator('describe_instances')
    for page in paginator.paginate(Filters=filters):
        for reservation in page.get('Reservations', []):
            for instance in reservation['Instances']:
                if has_interval(instance):
                    yield instance
def _expired_snapshots_for_volume(ec2, volume_id, interval, keep):
    '''Return ``volume_id``'s ``interval``-tagged snapshots except the newest ``keep``.'''
    paginator = ec2.get_paginator('describe_snapshots')
    pages = paginator.paginate(
        OwnerIds=['self'],
        Filters=[
            {'Name': 'volume-id', 'Values': [volume_id]},
            {'Name': 'tag:SnapshotInterval', 'Values': [interval]},
        ],
    )
    snapshots = [snapshot for page in pages for snapshot in page['Snapshots']]
    newest_first = sorted(snapshots, key=lambda s: s['StartTime'], reverse=True)
    return newest_first[keep:]


def expired_snapshots_for_region(region, interval):
    '''Yield snapshots in ``region`` that exceed the retention count for ``interval``.

    For each instance whose SnapshotPolicy tag contains ``interval``, the
    policy's count for that interval is the number of snapshots to keep.
    For every attached EBS volume, the snapshots owned by this account and
    tagged ``SnapshotInterval=<interval>`` are sorted newest-first, and all
    but the newest ``count`` are yielded for deletion.

    Instances whose tag only contains the interval name as a substring
    (e.g. ``biweekly: 2`` when processing ``weekly``) are skipped rather
    than raising ``KeyError``.  Both ``describe_instances`` and
    ``describe_snapshots`` are paginated.
    '''
    ec2 = boto_client('ec2', region_name=region)
    filters = [
        {'Name': 'tag:SnapshotPolicy', 'Values': ['*{}*'.format(interval)]},
    ]
    pages = ec2.get_paginator('describe_instances').paginate(Filters=filters)
    for page in pages:
        for reservation in page.get('Reservations', []):
            for instance in reservation['Instances']:
                policy = parse_policy(tag_dict(instance).get('SnapshotPolicy', ''))
                keep = policy.get(interval)
                if keep is None:
                    # The wildcard filter matched, but the policy has no entry
                    # for this exact interval.
                    continue
                for dev in instance['BlockDeviceMappings']:
                    ebs = dev.get('Ebs')
                    if ebs is None:
                        continue
                    yield from _expired_snapshots_for_volume(ec2, ebs['VolumeId'], interval, keep)
def get_instances(regions, interval, queue):
    '''Producer: fill the create queue with ``(region, instance)`` tuples.

    Regions are scanned in order.  An error in one region is logged with its
    traceback and the scan moves on to the next region, so one failing region
    cannot block snapshots everywhere else.  Instances already queued from a
    region before the error stay queued.
    '''
    for region in regions:
        try:
            for instance in instances_for_region(region, interval=interval):
                queue.put((region, instance))
        except Exception:
            logger.exception("Failed to list instances in {}".format(region))
def get_snapshots(regions, interval, queue):
    '''Producer: fill the delete queue with ``(region, snapshot)`` tuples.

    Regions are scanned in order.  An error in one region is logged with its
    traceback and the scan moves on to the next region, so one failing region
    cannot stop expiry everywhere else.  Snapshots already queued from a
    region before the error stay queued.
    '''
    for region in regions:
        try:
            for snapshot in expired_snapshots_for_region(region, interval=interval):
                queue.put((region, snapshot))
        except Exception:
            logger.exception("Failed to list expired snapshots in {}".format(region))
def _snapshot_instance(region, instance, interval):
    '''Create and tag one snapshot per EBS volume attached to ``instance``.

    A failure on one volume is logged and the remaining volumes are still
    processed.  Each snapshot is tagged with selected instance tags plus
    ``SnapshotInterval=<interval>``, which the expiry scan filters on.
    '''
    ec2 = boto_client('ec2', region_name=region)
    tags = tag_dict(instance)
    instance_id = instance['InstanceId']
    logger.info("Creating snapshot of instance {} in {}".format(instance_id, region))
    # Built once so every volume of the instance gets the same description
    # and tag set.
    snapshot_description = '{} snapshot of {} on {}'.format(
        interval.capitalize(),
        instance_id,
        datetime.now(timezone.utc).strftime('%Y-%m-%d')
    )
    snapshot_tags = [
        {'Key': 'Name', 'Value': tags.get('Name', '')},
        {'Key': 'ClientName', 'Value': tags.get('ClientName', '')},
        {'Key': 'Platform', 'Value': tags.get('Platform', '')},
        {'Key': 'Tenancy', 'Value': tags.get('Tenancy', '')},
        {'Key': 'SnapshotInterval', 'Value': interval},
    ]
    for dev in instance['BlockDeviceMappings']:
        ebs = dev.get('Ebs')
        if ebs is None:
            continue
        try:
            snapshot = ec2.create_snapshot(VolumeId=ebs['VolumeId'], Description=snapshot_description)
        except Exception:
            logger.exception("Failed to create snapshot for {} in {}".format(instance_id, region))
            continue
        logger.info("Created snapshot for {} in {}".format(instance_id, region))
        try:
            ec2.create_tags(Resources=[snapshot['SnapshotId']], Tags=snapshot_tags)
        except Exception:
            # Without the SnapshotInterval tag the expiry scan never selects
            # this snapshot, so make the failure visible.
            logger.exception("Failed to tag snapshot {} of {} in {}".format(
                snapshot['SnapshotId'], instance_id, region))


def create_snapshot(queue, interval):
    '''Consumer loop: snapshot each ``(region, instance)`` item taken from ``queue``.

    Runs forever.  ``task_done()`` is called for every item, including ones
    that fail, so ``queue.join()`` in the handler always returns; unexpected
    errors are logged with a traceback instead of killing the worker thread.
    '''
    while True:
        region, instance = queue.get()
        try:
            _snapshot_instance(region, instance, interval)
        except Exception:
            logger.exception("Unexpected error while snapshotting {} in {}".format(
                instance.get('InstanceId'), region))
        finally:
            queue.task_done()
def delete_snapshot(queue, interval):
    '''Consumer loop: delete each ``(region, snapshot)`` item taken from ``queue``.

    Runs forever.  A failed deletion is logged as a warning with the
    underlying error, and the worker moves on.  ``task_done()`` is called for
    every item, including ones that fail, so ``queue.join()`` in the handler
    always returns.  ``interval`` is unused; it keeps the signature parallel
    to ``create_snapshot``.
    '''
    while True:
        region, snapshot = queue.get()
        try:
            ec2 = boto_client('ec2', region_name=region)
            logger.info("Deleting snapshot {} created on {} in {}".format(
                snapshot['SnapshotId'], snapshot['StartTime'], region))
            ec2.delete_snapshot(SnapshotId=snapshot['SnapshotId'])
        except Exception:
            logger.warning("Can't delete snapshot {}".format(snapshot.get('SnapshotId')), exc_info=True)
        finally:
            queue.task_done()
def lambda_handler(event, context):
    '''AWS Lambda entry point: snapshot tagged instances and prune old snapshots.

    Configuration comes from environment variables:

    * ``interval``: policy interval to process (default ``'daily'``).
    * ``threads``: number of create workers, and separately of delete
      workers (default 4).  Values below 1 are raised to 1; with zero workers
      the queues would never drain and the joins below would block forever.

    Two producer threads first scan every region to fill the create and
    delete queues.  Daemon worker threads then drain both queues, and the
    handler returns once every queued item has been processed.  ``event``
    and ``context`` are unused.

    NOTE(review): the workers loop forever and are never stopped, so a reused
    (warm) Lambda process keeps the idle threads from earlier invocations.
    '''
    interval = os.environ.get('interval', 'daily')
    num_threads = max(1, int(os.environ.get('threads', 4)))
    logger.info("Processing {} interval".format(interval))
    ec2 = boto3.client('ec2')
    regions = [region['RegionName'] for region in ec2.describe_regions()['Regions']]
    create_queue = Queue(maxsize=0)
    delete_queue = Queue(maxsize=0)
    producers = [
        Thread(target=get_instances, args=(regions, interval, create_queue)),
        Thread(target=get_snapshots, args=(regions, interval, delete_queue)),
    ]
    for producer in producers:
        producer.start()
    for producer in producers:
        producer.join()
    for _ in range(num_threads):
        # daemon=True replaces the deprecated Thread.setDaemon(True).
        Thread(target=create_snapshot, args=(create_queue, interval), daemon=True).start()
        Thread(target=delete_snapshot, args=(delete_queue, interval), daemon=True).start()
    create_queue.join()
    delete_queue.join()
if __name__ == '__main__':
    # Local run: call the handler directly; it ignores both arguments.
    lambda_handler(event=None, context=None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment