Skip to content

Instantly share code, notes, and snippets.

@heitorlessa
Last active March 12, 2024 10:28
Show Gist options
  • Save heitorlessa/4aad06c39a1d520ff8c42adc72b0bcd5 to your computer and use it in GitHub Desktop.
Save heitorlessa/4aad06c39a1d520ff8c42adc72b0bcd5 to your computer and use it in GitHub Desktop.
Lambda Dummy Extension POC
import signal
import sys
import time
from aws_lambda_powertools import Logger, Metrics, Tracer
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver
from aws_lambda_powertools.logging import correlation_paths
import os
import json
import urllib3
import threading
# Powertools utilities used by the handler below, all tagged with the "APP" service name.
logger = Logger(service="APP")
tracer = Tracer(service="APP")
metrics = Metrics(namespace="MyApp", service="APP")
# Router that dispatches API Gateway events (via app.resolve) to the @app.<method> routes.
app = ApiGatewayResolver()
# Shared urllib3 pool used for the Extensions API calls and the SIGTERM cleanup request.
http = urllib3.PoolManager()
#### DUMMY EXTENSION SECTION
# Lambda Extensions API endpoints, derived from the Runtime API host exposed via the environment.
EXTENSION_HOST = os.getenv("AWS_LAMBDA_RUNTIME_API")
EXTENSION_ENDPOINT = f"http://{EXTENSION_HOST}/2020-01-01/extension"
EXTENSION_NEXT_EVENT_URL = EXTENSION_ENDPOINT + "/event/next"
EXTENSION_REGISTER_URL = EXTENSION_ENDPOINT + "/register"

# Header carrying the identifier returned by registration; required on /event/next calls.
EXTENSION_ID_HEADER = "Lambda-Extension-Identifier"
# Headers sent when registering the extension under the name "dummy".
EXTENSION_REGISTRATION_HEADERS = {
    "Lambda-Extension-Name": "dummy",
    "Content-Type": "application/json",
}

# Endpoint notified by the SIGTERM handler on shutdown.
CLEANUP_ENDPOINT = "https://jfhth8lff7.execute-api.eu-west-1.amazonaws.com/Prod/shutdown"
def register_dummy_extension():
    """Register a no-op Lambda extension so the runtime process receives SIGTERM.

    Registers with the Extensions API for no events, then issues a single
    ``/event/next`` call on a background thread, which activates the extension.

    Raises:
        RuntimeError: if the registration response carries no extension identifier.
    """

    def activate_extension(extension_id: str):
        ## a single /next call activates the extension
        http.request(
            method="GET",
            url=EXTENSION_NEXT_EVENT_URL,
            headers={EXTENSION_ID_HEADER: extension_id},
        )
        logger.info("Extension activated", extension_id=extension_id)

    logger.info("Registering the extension")
    registration = http.request(
        method="POST",
        url=EXTENSION_REGISTER_URL,
        headers=EXTENSION_REGISTRATION_HEADERS,
        body=json.dumps({"events": []}),
    )

    # Read the identifier BEFORE logging it (it was previously referenced before assignment).
    extension_id = registration.headers.get(EXTENSION_ID_HEADER)
    if not extension_id:
        raise RuntimeError(
            f"Extension registration failed (HTTP {registration.status}): {registration.data!r}"
        )
    logger.info("Activating extension", extension_id=extension_id, registration=registration.data)

    # daemon=True: the /next call is expected to block (no events subscribed). A
    # non-daemon thread would keep the interpreter alive after sys.exit() in the
    # SIGTERM handler, defeating the early exit.
    threading.Thread(target=activate_extension, args=[extension_id], daemon=True).start()  ## a must have
def register_cleanup_handler():
    """Install a SIGTERM handler that notifies CLEANUP_ENDPOINT and exits the process."""

    def cleanup(signum, frame):
        # `signum` (not `signal`) so the handler does not shadow the signal module.
        logger.info("Received SIGTERM; shutting down...")
        try:
            http.request(method="GET", url=CLEANUP_ENDPOINT)
        except Exception:
            # Best effort: a failed notification must not prevent the exit below.
            logger.exception("Cleanup request failed")
        sys.exit(0)

    signal.signal(signal.SIGTERM, cleanup)
# Run at module import time: register the no-op extension, then install the SIGTERM handler.
register_dummy_extension()
register_cleanup_handler()
#### DUMMY EXTENSION SECTION
@app.get("/hello")
@tracer.capture_method
def hello():
    """Handle ``GET /hello``.

    Deliberately sleeps for 5 seconds to force a timeout so the SIGTERM handler
    runs (presumably the function timeout is configured below 5 s — confirm).
    """
    time.sleep(5)  # force timeout to receive SIGTERM
    return {"message": "hello stranger!"}
@tracer.capture_lambda_handler
@logger.inject_lambda_context(
    correlation_id_path=correlation_paths.API_GATEWAY_REST, log_event=True
)
@metrics.log_metrics(capture_cold_start_metric=True)
def lambda_handler(event, context):
    """Lambda entry point: resolve the API Gateway REST event to its route.

    Any exception is logged with its traceback and re-raised so the invocation fails.
    """
    try:
        return app.resolve(event, context)
    except Exception as e:
        logger.exception(e)
        raise
@heitorlessa
Copy link
Author

heitorlessa commented Aug 21, 2023

Sample using the Batch Processing utility to demonstrate the POC with a long-running SQS processing example.

import os
import boto3
import signal
import threading
import requests
import sys

from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities.batch import (
    BatchProcessor,
    EventType,
    process_partial_response,
)
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext

# Batch processor for SQS events; its success_messages list feeds the SIGTERM cleanup below.
processor = BatchProcessor(event_type=EventType.SQS)
tracer = Tracer()
logger = Logger()
# SQS client used by the SIGTERM handler to delete already-processed messages.
sqs = boto3.client("sqs")

# Queue name (appended to the log keys in lambda_handler) and queue URL (used for deletions).
# Both fall back to the placeholder "no-queue-defined" when the variable is unset.
MESSAGE_QUEUE = os.environ.get("MESSAGE_QUEUE", "no-queue-defined")
MESSAGE_QUEUE_URL = os.environ.get("MESSAGE_QUEUE_URL", "no-queue-defined")


@tracer.capture_method
def process_todo(record: SQSRecord):
    """Process one SQS record whose body is a JSON todo, e.g. ``{"title": "..."}``.

    Sets the logger correlation id to the record's message id, then logs the title.

    Returns:
        True once the todo has been logged.

    Raises:
        KeyError: if the JSON body has no ``"title"`` key (presumably the batch
            processor then reports the record as failed — confirm).
    """
    logger.set_correlation_id(record.message_id)

    # json_body is the parsed JSON object, not a string.
    todo: dict = record.json_body
    todo_title = todo["title"]
    logger.info("Processing todo", todo=todo_title)
    return True


@logger.inject_lambda_context
@tracer.capture_lambda_handler
def lambda_handler(event, context: LambdaContext):
    """Entry point: run every SQS record in ``event`` through ``process_todo``.

    The queue name is appended to the logger keys first. Returns whatever
    ``process_partial_response`` produces for this batch.
    """
    logger.append_keys(message_queue=MESSAGE_QUEUE)
    batch_response = process_partial_response(
        event=event,
        context=context,
        processor=processor,
        record_handler=process_todo,
    )
    return batch_response


#### NO-OP TIMEOUT EXTENSION SECTION
#
# Registers a dummy (no-op) extension so that the function receives SIGTERM.
# Upon timeout, the SIGTERM handler deletes the messages that were processed successfully.
#
# An alternative to deleting the messages here is to emit an event to EventBridge and delete
# them asynchronously (confirming they were deleted).
#
# Word of advice: a more robust solution is to use the Idempotency feature so that duplicate
# messages are handled appropriately.
#


def build_messages_for_deletion(messages=None):
    """Build the ``Entries`` parameter for ``sqs.delete_message_batch``.

    Args:
        messages: records to convert. Each must support ``msg["messageId"]`` and
            ``msg["receiptHandle"]``. Defaults to ``processor.success_messages``,
            i.e. the records processed successfully so far.

    Returns:
        A list of ``{"Id": ..., "ReceiptHandle": ...}`` dicts, one per record.
        ``delete_message_batch`` accepts at most 10 entries per call, so callers
        must chunk this list.
    """
    if messages is None:
        messages = processor.success_messages
    return [
        {"Id": msg["messageId"], "ReceiptHandle": msg["receiptHandle"]}
        for msg in messages
    ]


def register_dummy_extension():
    """Register a no-op Lambda extension so the runtime process receives SIGTERM.

    Registers with the Extensions API for no events, then makes a single
    ``/event/next`` call on a background daemon thread to activate it.

    Raises:
        RuntimeError: if the registration response carries no extension identifier.
    """
    EXTENSION_HOST = os.getenv("AWS_LAMBDA_RUNTIME_API")
    EXTENSION_ENDPOINT = f"http://{EXTENSION_HOST}/2020-01-01/extension"
    EXTENSION_NEXT_EVENT_URL = f"{EXTENSION_ENDPOINT}/event/next"
    EXTENSION_REGISTER_URL = f"{EXTENSION_ENDPOINT}/register"
    EXTENSION_ID_HEADER = "Lambda-Extension-Identifier"
    EXTENSION_REGISTRATION_HEADERS = {
        "Lambda-Extension-Name": "dummy",
        "Content-Type": "application/json",
    }

    def activate_extension(extension_id: str):
        ## a single /next call activates the extension
        requests.get(
            url=EXTENSION_NEXT_EVENT_URL,
            headers={EXTENSION_ID_HEADER: extension_id},
        )
        logger.info("Extension activated", extension_id=extension_id)

    logger.info("Registering the extension")
    registration = requests.post(
        url=EXTENSION_REGISTER_URL,
        headers=EXTENSION_REGISTRATION_HEADERS,
        json={"events": []},
    )

    extension_id = registration.headers.get(EXTENSION_ID_HEADER)
    if not extension_id:
        raise RuntimeError(
            f"Extension registration failed (HTTP {registration.status_code}): {registration.text}"
        )
    logger.info("Activating extension", extension_id=extension_id)

    # daemon=True: the /next call is expected to block (no events subscribed). A
    # non-daemon thread would keep the interpreter alive after sys.exit() in the
    # SIGTERM handler, defeating its "stop billing early" purpose.
    threading.Thread(target=activate_extension, args=[extension_id], daemon=True).start()


def register_cleanup_handler():
    """Install a SIGTERM handler that deletes processed SQS messages, then exits.

    The handler deletes the entries from ``build_messages_for_deletion()`` from
    MESSAGE_QUEUE_URL in chunks that ``delete_message_batch`` accepts. Deletion is
    best effort: failures are logged and the process still exits.
    """
    # SQS DeleteMessageBatch rejects empty requests and requests with more than 10 entries.
    max_batch_size = 10

    def cleanup(signum, frame):
        # `signum` (not `signal`) so the handler does not shadow the signal module.
        messages_to_delete = build_messages_for_deletion()

        logger.info(
            "Received SIGTERM; shutting down. Deleting processed messages",
            messages=messages_to_delete,
        )

        # range() yields nothing for an empty list, so no empty batch is ever sent.
        for start in range(0, len(messages_to_delete), max_batch_size):
            batch = messages_to_delete[start:start + max_batch_size]
            try:
                ret = sqs.delete_message_batch(QueueUrl=MESSAGE_QUEUE_URL, Entries=batch)
            except Exception:
                logger.exception("Failed to delete messages from the queue", messages=batch)
                continue
            logger.info("Deleted messages from the queue", response=ret)

        sys.exit(0) # prevent additional billing if finished earlier than 500ms

    signal.signal(signal.SIGTERM, cleanup)


# Run at module import time: register the no-op extension, then install the SIGTERM handler.
register_dummy_extension()
register_cleanup_handler()


#### NO-OP TIMEOUT EXTENSION SECTION

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment