@whatnick
Created August 10, 2020 00:33
Use datacube-core, datacube-ows, datacube-explorer in one large script
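"""
Orchestration script tying datacube-core, datacube-ows and datacube-explorer together.

It long-polls an SQS queue for S3 object notifications, indexes the referenced YAML
dataset documents into an Open Data Cube index, updates datacube-ows extent/range
tables for the affected products, optionally regenerates datacube-explorer (cubedash)
product summaries, and archives datasets older than a configurable number of days on
a daily schedule.
"""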
import json
import logging
import os
from datetime import date, datetime, timedelta
from functools import partial
from hashlib import md5
from pathlib import PurePath
from time import sleep
import boto3
import click
import datacube
import datacube_ows
import requests
import schedule
import sentry_sdk
from cubedash.generate import cli
from datacube import Datacube
from datacube.ui.click import pass_config, environment_option, config_option
from datacube_ows.product_ranges import create_range_entry, get_crses, add_ranges
# from datacube_ows.update_ranges import refresh_views
from index_from_s3_bucket import add_dataset, get_s3_url
from sentry_sdk.integrations.logging import LoggingIntegration
from yaml import load
# Create a custom logger
logging.basicConfig(format='%(asctime)s - %(message)s', level=os.getenv('ORCHESTRATION_LOG_LEVEL', 'INFO'))
logger = logging.getLogger(__name__)
# FIXME: Does this need to be defined globally
s3 = boto3.resource('s3')
SQS_LONG_POLL_TIME_SECS = 20
DEFAULT_POLL_TIME_SECS = 60
DEFAULT_SOURCES_POLICY = "verify"
MAX_MESSAGES_BEFORE_EXTENT_CALCULATION = 10
SQS_QUEUE_NAME = os.getenv("SQS_QUEUE_NAME")
SENTRY_DSN = os.getenv("SENTRY_DSN")
WMS_CONFIG_URL = os.getenv("WMS_CONFIG_URL")
def init_sentry():
    if SENTRY_DSN is not None:
        # All of this is already happening by default!
        sentry_logging = LoggingIntegration(
            level=logging.INFO,        # Capture info and above as breadcrumbs
            event_level=logging.ERROR  # Send errors as events
        )
        sentry_sdk.init(
            dsn=SENTRY_DSN,
            integrations=[sentry_logging]
        )
def init_wms_config():
    if WMS_CONFIG_URL is not None:
        # Get online WMS config and place it where OWS expects it
        logger.info("Getting OWS config from github")
        r = requests.get(WMS_CONFIG_URL, allow_redirects=True)
        path = datacube_ows.__path__
        cfg = path[0] + '/ows_cfg.py'
        with open(cfg, 'wb') as f:
            f.write(r.content)
        if os.path.exists(cfg):
            logger.info("Successfully obtained config")
def update_cubedash(product_names):
    click_ctx = click.get_current_context()
    # As we are invoking a cli command, intercept the call to exit
    try:
        click_ctx.invoke(cli, product_names=product_names)
    except SystemExit:
        pass
def archive_datasets(product, days, dc, enable_cubedash=False):
    def get_ids(datasets):
        # Yield source dataset ids before the derived dataset id itself
        for d in datasets:
            ds = index.datasets.get(d.id, include_sources=True)
            for source in ds.sources.values():
                yield source.id
            yield d.id

    index = dc.index
    past = datetime.now() - timedelta(days=days)
    query = datacube.api.query.Query(
        product=product, time=[date(1970, 1, 1), past])
    datasets = index.datasets.search_eager(**query.search_terms)
    if len(datasets) > 0:
        logger.info("Archiving datasets: %s", [d.id for d in datasets])
        index.datasets.archive(get_ids(datasets))
        create_range_entry(dc, product, get_crses())
        if enable_cubedash:
            update_cubedash([product.name])
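# process_message() below expects the SQS body to be an SNS-style envelope whose
# "Message" field is a JSON-encoded S3 event notification, roughly of this shape
# (illustrative sketch only; real notifications carry many more fields):
#
#   {
#     "Message": "{\"Records\": [{\"s3\": {\"bucket\": {\"name\": \"my-bucket\"},
#                                          \"object\": {\"key\": \"path/to/dataset.yaml\"}}}]}"
#   }
#
# Each key matching the --prefix filters is fetched from S3, parsed as YAML and
# indexed via add_dataset().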
def process_message(index, message, prefix, sources_policy=DEFAULT_SOURCES_POLICY):
    # message body is a string, need to parse out json a few times
    inner = json.loads(message)
    s3_message = json.loads(inner["Message"])
    errors = dict()
    datasets = []
    skipped = 0
    if "Records" not in s3_message:
        errors["no_record"] = "Message did not contain S3 records"
        return datasets, skipped, errors
    for record in s3_message["Records"]:
        bucket_name = record["s3"]["bucket"]["name"]
        key = record["s3"]["object"]["key"]
        if prefix is None or len(prefix) == 0 or any([PurePath(key).match(p) for p in prefix]):
            try:
                errors[key] = None
                obj = s3.Object(bucket_name, key).get(
                    ResponseCacheControl='no-cache')
                data = load(obj['Body'].read())
                # NRT data may not have a creation_dt, attempt insert if missing
                if "creation_dt" not in data:
                    try:
                        data["creation_dt"] = data["extent"]["center_dt"]
                    except KeyError:
                        pass
                uri = get_s3_url(bucket_name, key)
                dataset, errors[key] = add_dataset(
                    data, uri, index, sources_policy)
                if errors[key] is None:
                    datasets.append(dataset)
            except Exception as e:
                errors[key] = e
        else:
            logger.debug("Skipped: %s as it does not match prefix filters", key)
            skipped = skipped + 1
    return datasets, skipped, errors
def delete_message(sqs, queue_url, message):
    receipt_handle = message["ReceiptHandle"]
    sqs.delete_message(
        QueueUrl=queue_url,
        ReceiptHandle=receipt_handle)
    logger.debug("Deleted Message %s", message.get("MessageId"))
def query_queue(sqs, queue_url, dc, prefix, poll_time=DEFAULT_POLL_TIME_SECS,
                sources_policy=DEFAULT_SOURCES_POLICY, enable_cubedash=False):
    index = dc.index
    messages_processed = 0
    products_to_update = []
    while True:
        response = sqs.receive_message(
            QueueUrl=queue_url,
            WaitTimeSeconds=SQS_LONG_POLL_TIME_SECS)
        if "Messages" not in response:
            # Queue drained; update extents for all products touched in this burst
            if messages_processed > 0:
                logger.info("Processed: %d messages", messages_processed)
                messages_processed = 0
                for p in products_to_update:
                    create_range_entry(dc, p, get_crses())
                if enable_cubedash:
                    update_cubedash([p.name for p in products_to_update])
            return
        else:
            for message in response.get("Messages"):
                message_id = message.get("MessageId")
                body = message.get("Body")
                md5_of_body = message.get("MD5OfBody", "")
                md5_hash = md5()
                md5_hash.update(body.encode("utf-8"))
                # Process message if MD5 matches
                if md5_of_body == md5_hash.hexdigest():
                    logger.info("Processing message: %s", message_id)
                    messages_processed += 1
                    datasets, skipped, errors = process_message(
                        index, body, prefix, sources_policy)
                    for d in datasets:
                        product = d.type
                        if product not in products_to_update:
                            products_to_update.append(product)
                    if not any(errors.values()):
                        logger.info("Successfully processed %d datasets in %s, %d datasets were skipped",
                                    len(datasets), message.get("MessageId"), skipped)
                    else:
                        # Do not delete message, leave it on the queue for redelivery
                        for key, error in errors.items():
                            logger.error("%s had error: %s", key, error)
                        continue
                else:
                    logger.warning(
                        "%s MD5 hashes did not match, discarding message: %s", message_id, body)
                delete_message(sqs, queue_url, message)
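# Example invocation (illustrative only; the filename "orchestrate.py" and the option
# values are assumptions, not part of this gist):
#
#   python orchestrate.py --queue my-indexing-queue --prefix "*.yaml" \
#       --archive s2a_nrt_granule 90 --archive-check-time 01:00 --cubedash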
@click.command(help="Python script to continuously poll SQS queue that is specified")
@environment_option
@config_option
@pass_config
@click.option("--queue",
              "-q",
              default=SQS_QUEUE_NAME)
@click.option("--poll-time",
              default=DEFAULT_POLL_TIME_SECS)
@click.option('--sources_policy',
              default=DEFAULT_SOURCES_POLICY,
              help="verify, ensure, skip")
@click.option("--prefix",
              default=None,
              multiple=True)
@click.option("--archive",
              default=None,
              multiple=True,
              type=(str, int))
@click.option("--archive-check-time",
              default="01:00")
@click.option("--multiproduct",
              default=None,
              multiple=True,
              type=str,
              help="Provide wms multi product name(s) for multiproduct ranges")
@click.option("--cubedash",
              is_flag=True,
              default=False)
@click.option("--views-blocking",
              is_flag=True,
              default=False,
              help="Controls whether the materialised view refresh is blocking or concurrent. Defaults to concurrent")
def main(config,
         queue,
         poll_time,
         sources_policy,
         prefix,
         archive,
         archive_check_time,
         multiproduct,
         cubedash,
         views_blocking):
    init_sentry()
    init_wms_config()
    dc = Datacube(config=config)
    if queue is not None:
        sqs = boto3.client('sqs')
        response = sqs.get_queue_url(QueueName=queue)
        queue_url = response.get('QueueUrl')
        query = partial(
            query_queue,
            sqs,
            queue_url,
            dc,
            prefix,
            poll_time=poll_time,
            sources_policy=sources_policy,
            enable_cubedash=cubedash)
        schedule.every(poll_time).seconds.do(query)
    for product, days in archive:
        do_archive = partial(
            archive_datasets,
            product,
            days,
            dc,
            enable_cubedash=cubedash)
        do_archive()
        schedule.every().day.at(archive_check_time).do(do_archive)
    # This needs to be maintained to align with current OWS API
    if multiproduct:  # click yields an empty tuple when no --multiproduct is passed
        try:
            do_update_ranges_mp = partial(
                add_ranges,  # def add_ranges(dc, product_names, summary=False, merge_only=False):
                dc,
                # multiproduct,
                merge_only=True
            )

            def run_all_mp():
                try:
                    logger.info("Updating extents for multi-products")
                    list(map(do_update_ranges_mp, multiproduct))
                except Exception as e:
                    error_str = f"Failed Multi products {multiproduct} with exception {e}"
                    logger.error(error_str)

            logger.info("Configure schedule for multi-products with extents")
            schedule.every(3).hours.do(run_all_mp)
        except Exception as e:
            error_str = f"Failed Multi product {multiproduct} with exception {e}"
            logger.error(error_str)
    while True:
        schedule.run_pending()
        sleep(1)
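# Note: jobs registered with schedule.every(...) only execute inside
# schedule.run_pending(), so the while/sleep(1) loop at the end of main() is what
# drives both the SQS polling and the daily archive checks.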
if __name__ == "__main__":
    main()