Infrastructure for watching SNS topic
# orchestration.py
import json
import logging
from datetime import date, datetime, timedelta
from functools import partial
from hashlib import md5
from time import sleep

import boto3
import click
import schedule
from yaml import load

import datacube
from datacube import Datacube
from datacube.ui.click import pass_config, environment_option, config_option
from datacube_wms.product_ranges import add_range
from cubedash.generate import cli
from index_from_s3_bucket import add_dataset, get_s3_url

_LOG = logging.getLogger("datakube-orchestration")
_LOG.setLevel("INFO")

sqs = boto3.client('sqs')
s3 = boto3.resource('s3')

SQS_LONG_POLL_TIME_SECS = 20
DEFAULT_POLL_TIME_SECS = 60
DEFAULT_SOURCES_POLICY = "verify"
MAX_MESSAGES_BEFORE_EXTENT_CALCULATION = 10


def update_cubedash(product_names):
    click_ctx = click.get_current_context()
    # As we are invoking a cli command, intercept its call to exit
    try:
        click_ctx.invoke(cli, product_names=product_names)
    except SystemExit:
        pass


def archive_datasets(product, days, dc, enable_cubedash=False):
    def get_ids(datasets):
        for d in datasets:
            ds = index.datasets.get(d.id, include_sources=True)
            for source in ds.sources.values():
                yield source.id
            yield d.id

    index = dc.index
    past = datetime.now() - timedelta(days=days)
    query = datacube.api.query.Query(product=product, time=[date(1970, 1, 1), past])
    datasets = index.datasets.search_eager(**query.search_terms)
    if len(datasets) > 0:
        _LOG.info("Archiving datasets: %s", [d.id for d in datasets])
        index.datasets.archive(get_ids(datasets))
        add_range(dc, product)
        if enable_cubedash:
            # product is a product name string here
            update_cubedash([product])


def process_message(index, message, prefix, sources_policy=DEFAULT_SOURCES_POLICY):
    # Message body is a string; the S3 notification JSON is nested inside the SNS JSON
    inner = json.loads(message)
    s3_message = json.loads(inner["Message"])
    errors = dict()
    datasets = []
    skipped = 0

    if "Records" not in s3_message:
        errors["no_record"] = "Message did not contain S3 records"
        return datasets, skipped, errors

    for record in s3_message["Records"]:
        bucket_name = record["s3"]["bucket"]["name"]
        key = record["s3"]["object"]["key"]
        if len(prefix) == 0 or key.startswith(tuple(prefix)):
            try:
                errors[key] = None
                obj = s3.Object(bucket_name, key).get(ResponseCacheControl='no-cache')
                data = load(obj['Body'].read())
                # NRT data may not have a creation_dt, attempt insert if missing
                if "creation_dt" not in data:
                    try:
                        data["creation_dt"] = data["extent"]["center_dt"]
                    except KeyError:
                        pass
                uri = get_s3_url(bucket_name, key)
                # index into datacube
                dataset, errors[key] = add_dataset(data, uri, index, sources_policy)
                if errors[key] is None:
                    datasets.append(dataset)
            except Exception as e:
                errors[key] = e
        else:
            _LOG.debug("Skipped: %s as it does not match prefix filters", key)
            skipped = skipped + 1
    return datasets, skipped, errors


def delete_message(sqs, queue_url, message):
    receipt_handle = message["ReceiptHandle"]
    sqs.delete_message(
        QueueUrl=queue_url,
        ReceiptHandle=receipt_handle)
    _LOG.debug("Deleted Message %s", message.get("MessageId"))


def query_queue(sqs, queue_url, dc, prefix, poll_time=DEFAULT_POLL_TIME_SECS,
                sources_policy=DEFAULT_SOURCES_POLICY, enable_cubedash=False):
    index = dc.index
    messages_processed = 0
    products_to_update = []
    while True:
        response = sqs.receive_message(
            QueueUrl=queue_url,
            WaitTimeSeconds=SQS_LONG_POLL_TIME_SECS)
        if "Messages" not in response:
            # Queue is drained: update product ranges for anything indexed, then return
            if messages_processed > 0:
                _LOG.info("Processed: %d messages", messages_processed)
                messages_processed = 0
                for p in products_to_update:
                    add_range(dc, p)
                if enable_cubedash:
                    update_cubedash([p.name for p in products_to_update])
            return
        else:
            for message in response.get("Messages"):
                message_id = message.get("MessageId")
                body = message.get("Body")
                md5_of_body = message.get("MD5OfBody", "")
                md5_hash = md5()
                md5_hash.update(body.encode("utf-8"))
                # Process message if MD5 matches
                if md5_of_body == md5_hash.hexdigest():
                    _LOG.info("Processing message: %s", message_id)
                    messages_processed += 1
                    datasets, skipped, errors = process_message(index, body, prefix, sources_policy)
                    for d in datasets:
                        product = d.type
                        if product not in products_to_update:
                            products_to_update.append(product)
                    if not any(errors.values()):
                        _LOG.info("Successfully processed all datasets in %s, %d datasets were skipped",
                                  message.get("MessageId"), skipped)
                    else:
                        # Do not delete the message, so it can be retried
                        for key, error in errors.items():
                            _LOG.error("%s had error: %s", key, error)
                        continue
                else:
                    _LOG.warning("%s MD5 hashes did not match, discarding message: %s", message_id, body)
                delete_message(sqs, queue_url, message)


@click.command(help="Continuously poll the specified SQS queue and index new datasets")
@environment_option
@config_option
@pass_config
@click.option("--queue",
              "-q",
              default=None)
@click.option("--poll-time",
              default=DEFAULT_POLL_TIME_SECS)
@click.option('--sources_policy',
              default=DEFAULT_SOURCES_POLICY,
              help="verify, ensure, skip")
@click.option("--prefix",
              default=None,
              multiple=True)
@click.option("--archive",
              default=None,
              multiple=True,
              type=(str, int))
@click.option("--archive-check-time",
              default="01:00")
@click.option("--cubedash",
              is_flag=True,
              default=False)
def main(config, queue, poll_time, sources_policy, prefix, archive, archive_check_time, cubedash):
    dc = Datacube(config=config)
    if queue is not None:
        sqs = boto3.client('sqs')
        response = sqs.get_queue_url(QueueName=queue)
        queue_url = response.get('QueueUrl')
        query = partial(
            query_queue,
            sqs,
            queue_url,
            dc,
            prefix,
            poll_time=poll_time,
            sources_policy=sources_policy,
            enable_cubedash=cubedash)
        schedule.every(poll_time).seconds.do(query)
    for product, days in archive:
        do_archive = partial(
            archive_datasets,
            product,
            days,
            dc,
            enable_cubedash=cubedash)
        do_archive()
        schedule.every().day.at(archive_check_time).do(do_archive)
    while True:
        schedule.run_pending()
        sleep(1)


if __name__ == "__main__":
    main()

# sqs.tf

# Variables
variable "bucket" {
  type        = "string"
  description = "S3 bucket to add SNS notification to"
  default     = "dea-public-data"
}

variable "services" {
  type        = "list"
  description = "list of services that will require an SQS queue"
  default     = ["ows"]
}

variable "region" {
  default = "ap-southeast-2"
}

variable "name" {
  description = "Name for resources"
  type        = "string"
  default     = "dea-data"
}

variable "topic_arn" {
  type        = "string"
  description = "ARN of SNS topic to subscribe to"
  default     = "arn:aws:sns:ap-southeast-2:538673716275:DEANewData"
}

# Config
terraform {
  required_version = ">= 0.11.0"
}

provider "aws" {
  region = "${var.region}"
}

data "aws_caller_identity" "current" {}

# Resources
resource "aws_sns_topic_subscription" "sqs_subscriptions" {
  count     = "${length(var.services)}"
  topic_arn = "${var.topic_arn}"
  protocol  = "sqs"
  endpoint  = "${element(aws_sqs_queue.queues.*.arn, count.index)}"
}

resource "aws_sqs_queue" "queues" {
  count             = "${length(var.services)}"
  name              = "${element(var.services, count.index)}"
  kms_master_key_id = "${aws_kms_alias.sqs.arn}"
}

resource "aws_sqs_queue_policy" "queue_policy" {
  count     = "${length(var.services)}"
  queue_url = "${element(aws_sqs_queue.queues.*.id, count.index)}"

  policy = <<POLICY
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "MySQSPolicy001",
      "Effect": "Allow",
      "Principal": "*",
      "Action": "sqs:SendMessage",
      "Resource": "${element(aws_sqs_queue.queues.*.arn, count.index)}",
      "Condition": {
        "ArnEquals": {
          "aws:SourceArn": "${var.topic_arn}"
        }
      }
    }
  ]
}
POLICY
}

# ======================================
# Orchestration Role
resource "aws_iam_role" "orchestration" {
  name = "orchestration"

  assume_role_policy = <<EOF
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "",
      "Effect": "Allow",
      "Principal": {
        "Service": "ec2.amazonaws.com"
      },
      "Action": "sts:AssumeRole"
    },
    {
      "Sid": "",
      "Effect": "Allow",
      "Principal": {
        "AWS": "arn:aws:iam::${data.aws_caller_identity.current.account_id}:role/nodes.${var.name}"
      },
      "Action": "sts:AssumeRole"
    }
  ]
}
EOF
}
resource "aws_iam_role_policy" "orchestration" {
name = "orchestration"
role = "${aws_iam_role.orchestration.id}"
policy = <<EOF
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"sqs:ReceiveMessage",
"sqs:GetQueueUrl",
"sqs:DeleteMessage",
"sqs:GetQueueAttributes",
"sqs:ListQueues"
],
"Resource": "*"
},
{
"Effect": "Allow",
"Action": ["S3:GetObject"],
"Resource": [
"arn:aws:s3:::dea-public-data/*"
]
},
{
"Effect": "Allow",
"Action": ["kms:Decrypt"],
"Resource": [
"${aws_kms_key.sqs.arn}"
]
}
]
}
EOF
}

# ======================================
# SQS Encryption
resource "aws_kms_key" "sqs" {
  description             = "KMS Key for encrypting SQS Queue for ${var.bucket} notifications"
  deletion_window_in_days = 30

  policy = <<POLICY
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "Allow administration of the key",
      "Effect": "Allow",
      "Principal": { "AWS": "arn:aws:iam::${data.aws_caller_identity.current.account_id}:root" },
      "Action": [
        "kms:Create*",
        "kms:Describe*",
        "kms:Enable*",
        "kms:List*",
        "kms:Put*",
        "kms:Update*",
        "kms:Revoke*",
        "kms:Disable*",
        "kms:Get*",
        "kms:Delete*",
        "kms:ScheduleKeyDeletion",
        "kms:CancelKeyDeletion"
      ],
      "Resource": "*"
    },
    {
      "Effect": "Allow",
      "Principal": {
        "Service": "sns.amazonaws.com"
      },
      "Action": [
        "kms:GenerateDataKey",
        "kms:Decrypt"
      ],
      "Resource": "*"
    },
    {
      "Sid": "",
      "Effect": "Allow",
      "Principal": {
        "AWS": "arn:aws:iam::${data.aws_caller_identity.current.account_id}:root"
      },
      "Action": "kms:Decrypt",
      "Resource": "*"
    }
  ]
}
POLICY
}

resource "aws_kms_alias" "sqs" {
  name          = "alias/sqs-${var.bucket}"
  target_key_id = "${aws_kms_key.sqs.key_id}"
}

tom-butler commented Feb 21, 2019

Overview

sqs.tf is a Terraform script that creates a Simple Queue Service (SQS) queue in your AWS account and subscribes it to the Simple Notification Service (SNS) topic arn:aws:sns:ap-southeast-2:538673716275:DEANewData.
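
Because the queue is subscribed to an SNS topic, each SQS message body is an SNS envelope whose Message field contains the S3 event notification as a JSON string; that is why orchestration.py calls json.loads twice in process_message. A minimal sketch of that unwrapping, using a made-up payload (the key and the subset of fields shown are illustrative only):

import json

# Illustrative SQS body as delivered by an SNS-subscribed queue: the S3 event
# notification is JSON, wrapped again inside the SNS "Message" field.
sqs_body = json.dumps({
    "Type": "Notification",
    "TopicArn": "arn:aws:sns:ap-southeast-2:538673716275:DEANewData",
    "Message": json.dumps({
        "Records": [{
            "s3": {
                "bucket": {"name": "dea-public-data"},
                "object": {"key": "L2/sentinel-2-nrt/S2MSIARD/example/ARD-METADATA.yaml"},
            }
        }]
    }),
})

# The same double unwrapping process_message() performs
inner = json.loads(sqs_body)
s3_event = json.loads(inner["Message"])
for record in s3_event["Records"]:
    print(record["s3"]["bucket"]["name"], record["s3"]["object"]["key"])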

It also creates a role called orchestration that can be used as an instance profile on an EC2 instance.

orchestration.py is the script we run to poll the SQS queue and index the YAML files into the datacube; you'll need to adjust it to fit your indexing process. We also use it to remove old NRT data from our index via the --archive flag.
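
If you need to adapt or test the indexing step without running the poller, the core of process_message can be exercised on a single object. A minimal sketch, assuming index_from_s3_bucket and a configured datacube are available on the machine; the bucket and key are placeholders:

# Standalone indexing of one metadata document, mirroring what process_message does.
import boto3
from yaml import load
from datacube import Datacube
from index_from_s3_bucket import add_dataset, get_s3_url

s3 = boto3.resource("s3")
bucket_name = "dea-public-data"                               # placeholder bucket
key = "L2/sentinel-2-nrt/S2MSIARD/example/ARD-METADATA.yaml"  # placeholder key

obj = s3.Object(bucket_name, key).get(ResponseCacheControl="no-cache")
data = load(obj["Body"].read())

dc = Datacube()
dataset, error = add_dataset(data, get_s3_url(bucket_name, key), dc.index, "verify")
if error is not None:
    print("Failed to index:", error)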

We run the script with a 0.1 vCPU limit; its memory usage spikes to around 120 MiB.

Usage

  1. Install Terraform (https://www.terraform.io/downloads.html)
  2. Set up AWS access credentials in your CLI
  3. terraform init - this will initialize the Terraform environment
  4. terraform plan - this will show you the changes that will be made
  5. terraform apply -auto-approve - this will deploy the infrastructure
  6. Create an EC2 instance with the orchestration role attached (see the sketch after this list)
  7. Copy the Python script to the instance
  8. Run it like python3 orchestration.py --queue ows --poll-time 60 --prefix L2/sentinel-2-nrt/S2MSIARD --prefix WOfS/WOFLs/v2.1.5/combined --archive s2a_nrt_granule 90 --archive s2b_nrt_granule 90
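
Step 6 can also be scripted. Note that the Terraform above creates the orchestration role but not an instance profile, so one has to be created and the role attached to it before launching. A minimal boto3 sketch, assuming default networking; the AMI and instance type are placeholders:

# Hypothetical helper for step 6: wrap the orchestration role in an instance
# profile and launch an instance with it (a short wait may be needed for IAM
# propagation before the profile is usable).
import boto3

iam = boto3.client("iam")
ec2 = boto3.client("ec2")

# The Terraform creates the role; the instance profile is created here.
iam.create_instance_profile(InstanceProfileName="orchestration")
iam.add_role_to_instance_profile(InstanceProfileName="orchestration",
                                 RoleName="orchestration")

ec2.run_instances(
    ImageId="ami-xxxxxxxx",      # placeholder AMI
    InstanceType="t3.small",     # placeholder instance type
    MinCount=1,
    MaxCount=1,
    IamInstanceProfile={"Name": "orchestration"},
)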
