Recursively walk S3 and tag items based upon a mapping
#!/usr/bin/env python
import argparse
import boto3
import fnmatch
import json
import logging
import re
from collections import OrderedDict
from datetime import datetime
from io import open

LOGGER_NAME = None


class Config(object):
    mapping_list = None
    mapping_list_bucket = 'sequoia-short-lived'
    mapping_list_key = 'mapping.json'
    profile = None

def log():
    return logging.getLogger(LOGGER_NAME)


def config_logger():
    logger = log()
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
        logger.setLevel(logging.INFO)

def read_file_from_s3(bucket, key):
    """Read a JSON mapping file from S3, preserving key order."""
    session = boto3.session.Session(profile_name=Config.profile)
    s3 = session.resource('s3')
    obj = s3.Object(bucket, key)
    file_contents = obj.get()['Body'].read().decode('utf-8')
    return json.loads(file_contents, object_pairs_hook=OrderedDict)


def read_local_file(mapping_file_name):
    """Read a JSON mapping file from the local filesystem, preserving key order."""
    with open(mapping_file_name, encoding='utf-8') as f:
        mapping_file_name_text = f.read()  # unicode, not bytes
    return json.loads(mapping_file_name_text, object_pairs_hook=OrderedDict)


def config_mapping_list_from_s3():
    Config.mapping_list = read_file_from_s3(Config.mapping_list_bucket,
                                            Config.mapping_list_key)

def json_serial(obj):
    """JSON serializer for objects not serializable by default json code"""
    if isinstance(obj, datetime):
        serial = obj.isoformat()
        return serial
    raise TypeError("Type not serializable")

class S3URL:
    """Simple wrapper for S3 URL.
    This class parses a S3 URL and provides accessors to each component.
    """
    S3URL_PATTERN = re.compile(r'(s3[n]?)://([^/]+)[/]?(.*)')
    PATH_SEP = "/"

    def __init__(self, bucket=None, path=None, uri=None):
        """Initialization, parse S3 URL"""
        if uri:
            try:
                self.proto, self.bucket, self.path = S3URL.S3URL_PATTERN.match(uri).groups()
                self.proto = self.proto.rstrip("n")  # normalize s3n => s3
            except Exception:
                raise RuntimeError('Invalid S3 URI: %s' % uri)
        else:
            self.proto = 's3'  # normalize s3n => s3
            self.bucket, self.path = bucket, path
        self.uri = S3URL.combine(self.proto, self.bucket, self.path)

    def __str__(self):
        """Return the original S3 URL"""
        return S3URL.combine(self.proto, self.bucket, self.path)

    def get_fixed_path(self):
        """Get the fixed part of the path without wildcard"""
        fixed = []
        for p in self.path.split(S3URL.PATH_SEP):
            if '*' in p or '?' in p:
                break
            fixed.append(p)
        return S3URL.PATH_SEP.join(fixed)

    @staticmethod
    def combine(proto, bucket, path):
        """Combine each component and generate a S3 url string,
        no path normalization here.
        The path should not start with slash.
        """
        return '%s://%s/%s' % (proto, bucket, path)

    @staticmethod
    def is_valid(uri):
        """Check if given uri is a valid S3 URL"""
        return S3URL.S3URL_PATTERN.match(uri) is not None

def s3_walk_and_tag(s3url, s3dir):
    """Recursively walk into all subdirectories"""
    result = []
    session = boto3.session.Session(profile_name=Config.profile)
    client = session.client('s3')
    paginator = client.get_paginator('list_objects')
    for page in paginator.paginate(
            Bucket=s3url.bucket,
            Prefix=s3dir,
            Delimiter=S3URL.PATH_SEP,
            PaginationConfig={'PageSize': 1000}):
        # Get subdirectories first.
        for obj in page.get('CommonPrefixes', []):
            result.extend(s3_walk_and_tag(s3url, obj['Prefix']))
        # Then get all items in this folder.
        for obj in page.get('Contents', []):
            obj_name = obj['Key']
            # Do we want to log all this?
            s3_object = {
                'Bucket': s3url.bucket,
                'Key': obj_name,
                'name': S3URL.combine(s3url.proto, s3url.bucket, obj_name),
                'is_dir': False,
                'size': obj['Size'],
                'last_modified': obj['LastModified']
            }
            update_object_tags_by_mapping(s3url.bucket, obj_name)
            result.append(s3_object)
    return result

def filter_key_collisions(list_of_dicts, key):
    """Merge dicts that share the same value for `key`; later entries win."""
    merged = OrderedDict()
    for item in list_of_dicts:
        merged.setdefault(item[key], OrderedDict()).update(item)
    return list(merged.values())

def update_object_tags_by_mapping(bucket, s3_key):
    log().info('Attempting to Tag s3://%s/%s', bucket, s3_key)
    session = boto3.session.Session(profile_name=Config.profile)
    client = session.client("s3")
    tag_set_keys = ["Key", "Value"]
    current_tags = client.get_object_tagging(Bucket=bucket, Key=s3_key).get('TagSet', [])
    mapped_tags = [{key: mapping[key] for key in tag_set_keys}
                   for mapping in Config.mapping_list
                   if fnmatch.fnmatch(s3_key, mapping["Filter"])]
    if mapped_tags and any(li not in current_tags for li in mapped_tags):
        to_be_put_tag_set = filter_key_collisions(current_tags + mapped_tags, 'Key')
        response = client.put_object_tagging(Bucket=bucket,
                                             Key=s3_key,
                                             Tagging={'TagSet': to_be_put_tag_set})
        json_added_tags = json.dumps(mapped_tags, default=json_serial)
        if response.get("ResponseMetadata", {}).get("HTTPStatusCode", 418) == 200:
            log().info('Successfully tagged s3://%s/%s with tags %s',
                       bucket,
                       s3_key,
                       json_added_tags)
        else:
            json_response = json.dumps(response, default=json_serial)
            log().error('Error Attempting to tag s3://%s/%s with tags %s with response %s',
                        bucket, s3_key, json_added_tags, json_response)

def tag_s3_objects(uris):
    for s3_url in uris:
        if S3URL.is_valid(str(s3_url)):
            update_object_tags_by_mapping(s3_url.bucket, s3_url.path)
        else:
            log().info('Invalid uri %s', s3_url)

def process_args(source=None):
    description = 'Recursively tag S3 objects based on mapping file'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        '-p', '--profile',
        dest='profile',
        action='store',
        type=str,
        default=Config.profile,
        help='AWS profile')
    parser.add_argument(
        '-f',
        '--file',
        dest='mapping_local_file',
        action='store',
        type=str,
        default=None,
        help='Local mapping file or None')
    parser.add_argument('input', type=str, nargs='+', help='URI to process')
    args = parser.parse_args(args=source)
    Config.profile = args.profile
    if args.mapping_local_file:
        Config.mapping_list = read_local_file(args.mapping_local_file)
    else:
        config_mapping_list_from_s3()
    return args

def main():
    config_logger()
    s3urls = [S3URL(uri=s3url) for s3url in process_args().input]
    result = []
    for s3url in s3urls:
        result.extend(s3_walk_and_tag(s3url, s3url.get_fixed_path()))
    log().info('Successfully tagged %s objects', len(result))


if __name__ == '__main__':
    main()

S3 Object tagger

Starting at a root S3 URL, recursively add tags to S3 objects, applying an arbitrary number of tags based on matching prefix patterns. Existing tags on objects are preserved unless the mapping specifies new values for the same tag keys.

Intended to be executed like s3_price_tag.py -f mapping.json "s3://myBucket/some/long/crazy/path". If -f is omitted, the mapping file is read from s3://sequoia-short-lived/mapping.json (the defaults in Config).

The mapping file can contain filter patterns like some/long/* or */long*, and any object whose key matches a filter will have that filter's tags applied.
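
The filters are standard fnmatch-style globs matched against the full object key (this is what update_object_tags_by_mapping does internally). A minimal sketch for sanity-checking a pattern, using made-up keys:

import fnmatch

# Hypothetical object keys, purely to illustrate how the glob filters match.
keys = ['some/long/crazy/path/file.csv', 'other/long/file.csv', 'short/file.csv']
for pattern in ('some/long/*', '*/long*'):
    print(pattern, [k for k in keys if fnmatch.fnmatch(k, pattern)])
# some/long/*  matches only 'some/long/crazy/path/file.csv'
# */long*      matches 'some/long/crazy/path/file.csv' and 'other/long/file.csv'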

If multiple matching filters carry identical tag keys, the last matching entry wins. For instance, in the example below, the tag sequoia:environment would be set to the value test for all matching objects:

[
  {
    "Filter": "AWSLogs/*",
    "Key": "sequoia:environment",
    "Value": "prod"
  },
  {
    "Filter": "AWSLogs/*",
    "Key": "sequoia:environment",
    "Value": "test"
  }
]
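
That last-wins behaviour comes from filter_key_collisions, which merges the tag dicts into an OrderedDict keyed on the tag Key, so later entries overwrite earlier ones. A standalone sketch of the effect, with the helper reproduced from the script:

from collections import OrderedDict

def filter_key_collisions(list_of_dicts, key):
    # Later dicts with the same value for `key` overwrite earlier ones.
    merged = OrderedDict()
    for item in list_of_dicts:
        merged.setdefault(item[key], OrderedDict()).update(item)
    return list(merged.values())

tags = [{"Key": "sequoia:environment", "Value": "prod"},
        {"Key": "sequoia:environment", "Value": "test"}]
print(filter_key_collisions(tags, "Key"))
# -> one entry: Key 'sequoia:environment', Value 'test' ('prod' is overwritten)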

Lifted lots from s4cmd.py.

An example mapping file (mapping.json):

[
  {
    "Filter": "test/*",
    "Value": "test",
    "Key": "sequoia:environment"
  },
  {
    "Filter": "prod/*",
    "Value": "prod",
    "Key": "sequoia:environment"
  },
  {
    "Filter": "AWSLogs/594813696195/elasticloadbalancing/us-east-1/2016/04/04/*",
    "Value": "test",
    "Key": "sequoia:environment"
  }
]
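
To see how existing tags are preserved, here is a rough sketch of the merge the script performs for a single object under the mapping above; the object key and the pre-existing team tag are made up for illustration:

import fnmatch
from collections import OrderedDict

mapping_list = [
    {"Filter": "test/*", "Value": "test", "Key": "sequoia:environment"},
    {"Filter": "prod/*", "Value": "prod", "Key": "sequoia:environment"},
]
s3_key = 'test/reports/2017/02/14/report.csv'         # hypothetical key
current_tags = [{"Key": "team", "Value": "sequoia"}]  # hypothetical existing tag

# Same selection the script does: keep only the Key/Value pairs whose Filter matches.
mapped_tags = [{k: m[k] for k in ("Key", "Value")}
               for m in mapping_list
               if fnmatch.fnmatch(s3_key, m["Filter"])]

# Merge keyed on tag Key: existing tags survive, matching mapping entries win.
merged = OrderedDict((t["Key"], t) for t in current_tags + mapped_tags)
print(list(merged.values()))
# -> keeps the 'team' tag and adds sequoia:environment = 'test'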