How to configure AWS Batch job state updates so they go into an SQS queue

What this is good for

This is Terraform and Python code to help you track AWS Batch jobs without having to poll Batch's APIs.

It sets up this integration:

AWS Batch --(job state change)--> EventBridge --(message)--> SQS

You can do it in pure Terraform, or a mix of Python/Terraform.

The Terraform shows how to do this as a per-region queue across all workflows.

The Python shows how to do this as a workflow-specific queue, matching only job state updates from AWS Batch jobs that have a specific parameter (workflow_id) set.
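
For reference, each SQS message body is the EventBridge event serialized as JSON. A rough, abridged sketch of the shape as a Python dict (the values are made up; detail.jobId, detail.status, and detail.parameters are the fields the code below relies on):

# Abridged sketch of one parsed message body; values are illustrative only.
{
    "version": "0",
    "source": "aws.batch",
    "detail-type": "Batch Job State Change",
    "region": "us-west-2",
    "detail": {
        "jobName": "sleep-30",
        "jobId": "<id assigned by Batch>",
        "status": "RUNNING",
        "parameters": {"workflow_id": "<your encoded workflow id>"},
    },
}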

Setting up/tearing down queues per workflow

I set this up for Prefect, which uses UUIDs as workflow IDs. I didn't trust AWS name length limits, so I base62-encode the IDs to save characters.

create_subscription_for_batch_jobs creates the following things in AWS:

  • SQS Queue batch-job-states-{encoded_id}
  • EventBridge rule (matching job state events) batch-job-states-{encoded_id}
  • An EventBridge target that connects the two
  • An SQS Queue Policy that permits the EventBridge rule to push events to the queue

and clean_up_subscription cleans them up.

Usage:

import asyncio
import boto3
import uuid
from batch import create_subscription_for_batch_jobs, clean_up_subscription, submit_sleep_job, wait_for_jobstates

workflow_id = uuid.uuid4() # or get it from wherever
session = boto3.Session()

subscription = create_subscription_for_batch_jobs(workflow_id, session)

# Submit a job that carries this workflow_id, so the rule matches it
job_id = submit_sleep_job(workflow_id, 30, session)

# Right now, this blocks forever, so
# you'll have to figure out what you want to do with each
# job state message, and leave the function at some point.
asyncio.run(wait_for_jobstates(subscription, session))

# later...
clean_up_subscription(subscription, session)

batch.py

import asyncio
import json
from typing import Callable
import uuid
import boto3
from dataclasses import dataclass
from id_encoding import encode_uuid
from mypy_boto3_sqs.service_resource import Message as SqsMessage
from types import SimpleNamespace


@dataclass(frozen=True)
class QueueData:
    url: str
    arn: str


@dataclass(frozen=True)
class RuleData:
    rule_name: str
    target_ids: list[str]


@dataclass(frozen=True)
class Subscription:
    queue: QueueData
    rule: RuleData


def create_subscription_for_batch_jobs(
    workflow_id: uuid.UUID, aws: boto3.Session
) -> Subscription:
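    """Create an SQS queue plus an EventBridge rule/target that forwards
    Batch job state changes for this workflow_id into the queue."""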
    encoded_id = encode_uuid(workflow_id)
    eventbridge = aws.client("events")
    rule_name = f"batch-job-states-{encoded_id}"
    rule = eventbridge.put_rule(
        Name=rule_name,
        Description="match Batch Job States",
        EventPattern=json.dumps(
            {
                "source": ["aws.batch"],
                "detail-type": ["Batch Job State Change"],
                "detail": {
                    "parameters": {"workflow_id": [encoded_id]},
                },
            }
        ),
        Tags=[{"Key": "flow_run_id", "Value": encoded_id}],
    )
    queue_name = f"batch-job-states-{encoded_id}"
    queue = aws.resource("sqs").create_queue(
        QueueName=queue_name, tags={"flow_run_id": encoded_id}
    )
    queue_arn = queue.attributes["QueueArn"]
    policy = {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "SQSAccess",
                "Effect": "Allow",
                "Principal": {"Service": "events.amazonaws.com"},
                "Action": "sqs:SendMessage",
                "Resource": queue_arn,
                "Condition": {"ArnEquals": {"aws:SourceArn": rule["RuleArn"]}},
            }
        ],
    }
    queue.set_attributes(Attributes={"Policy": json.dumps(policy)})
    eventbridge.put_targets(
        Rule=rule_name,
        EventBusName="default",
        Targets=[
            {
                "Id": queue_name,
                "Arn": queue_arn,
            }
        ],
    )
    return Subscription(
        queue=QueueData(url=queue.url, arn=queue_arn),
        rule=RuleData(rule_name=rule_name, target_ids=[queue_name]),
    )


def clean_up_subscription(sub: Subscription, aws: boto3.Session) -> None:
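    """Remove the EventBridge target and rule, then delete the SQS queue."""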
    eventbridge = aws.client("events")
    eventbridge.remove_targets(Rule=sub.rule.rule_name, Ids=sub.rule.target_ids)
    eventbridge.delete_rule(Name=sub.rule.rule_name)
    sqs = aws.client("sqs")
    sqs.delete_queue(QueueUrl=sub.queue.url)


# Example code for running an AWS Batch job
# This just runs an alpine container that sleeps for some time
#
# Key code: the part that sets the container's workflow_id parameter
#  for matching in the eventbridge rule.
def submit_sleep_job(workflow_id: uuid.UUID, seconds: int, aws: boto3.Session) -> str:
    encoded_id = encode_uuid(workflow_id)
    batch_client = aws.client("batch")
    job = batch_client.submit_job(
        jobName=f"sleep-{seconds}",
        jobQueue="primary_queue_shared_us_west_2_dev_630831553409",
        jobDefinition="alpine",
        containerOverrides={
            "command": [
                "/bin/sh",
                "-c",
                f"echo 'sleeping..' ; sleep {seconds} ; echo 'awake'",
            ]
        },
        tags={"workflow_id": encoded_id},
        parameters={"workflow_id": encoded_id},
    )
    return job["jobId"]


def matches_job(job_id: str) -> Callable[[SqsMessage], bool]:
    def matcher(message: SqsMessage) -> bool:
        parsed_body = json.loads(message.body)
        return parsed_body["detail"]["jobId"] == job_id

    return matcher


async def wait_for_jobstates(sub: Subscription, aws: boto3.Session) -> None:
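    """Long-poll the subscription's queue forever, handling each job state message."""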
    queue = aws.resource("sqs").Queue(sub.queue.url)
    while True:
        await asyncio.sleep(1) # Let other code run between these long polls
        messages = queue.receive_messages(WaitTimeSeconds=20, MaxNumberOfMessages=10)
        for message in messages:
            job_state_message = json.loads(
                message.body, object_hook=lambda d: SimpleNamespace(**d)
            )
            workflow_id: str = job_state_message.detail.parameters.workflow_id
            job_state: str = job_state_message.detail.status

            # Do whatever you want with the job state here.
            # Check out
            #   https://docs.aws.amazon.com/batch/latest/userguide/job_states.html

            queue.delete_messages(
                Entries=[
                    {"Id": message.message_id, "ReceiptHandle": message.receipt_handle}
                ]
            )
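
matches_job above isn't wired into wait_for_jobstates; as a rough sketch of how you could combine it with submit_sleep_job to block until one particular job finishes, appended to batch.py (wait_for_terminal_state is a hypothetical helper, not part of the gist):

async def wait_for_terminal_state(sub: Subscription, job_id: str, aws: boto3.Session) -> str:
    # Poll the subscription's queue until this job reports SUCCEEDED or FAILED,
    # deleting every message along the way (it's a per-workflow queue anyway).
    is_ours = matches_job(job_id)
    queue = aws.resource("sqs").Queue(sub.queue.url)
    while True:
        await asyncio.sleep(1)
        for message in queue.receive_messages(WaitTimeSeconds=20, MaxNumberOfMessages=10):
            status = json.loads(message.body)["detail"]["status"] if is_ours(message) else None
            queue.delete_messages(
                Entries=[{"Id": message.message_id, "ReceiptHandle": message.receipt_handle}]
            )
            if status in ("SUCCEEDED", "FAILED"):
                return status

Called as, e.g., await wait_for_terminal_state(subscription, submit_sleep_job(workflow_id, 30, session), session).
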

id_encoding.py

import base62
import uuid


def encode_uuid(u: uuid.UUID) -> str:
    return base62.encodebytes(u.bytes)
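
If you ever need the original UUID back, the inverse is a one-liner (a sketch, assuming the pybase62 package that provides encodebytes also provides decodebytes):

def decode_uuid(s: str) -> uuid.UUID:
    # Pad back to 16 bytes in case leading zero bytes were dropped during encoding.
    return uuid.UUID(bytes=base62.decodebytes(s).rjust(16, b"\x00"))

For scale: a 16-byte UUID encodes to at most 22 base62 characters, versus 36 for the usual hyphenated form, which is where the length saving comes from.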

Terraform resources

This sets up a global queue for all AWS Batch job states.

provider "aws" {
  region = var.region
}

resource "aws_cloudwatch_event_rule" "all_batch_job_states" {
  name = "match-batch-job-states"
  description = "match Batch job states"
  event_pattern = jsonencode({
    "source" : ["aws.batch"]
    "detail-type": ["Batch Job State Change"],
  })
}

resource "aws_sqs_queue" "job_state_updates" {
  name = "job-state-updates"
}

resource "aws_cloudwatch_event_target" "job_states_target" {
  target_id = "sqs-for-job-states"
  arn = aws_sqs_queue.job_state_updates.arn
  rule = aws_cloudwatch_event_rule.all_batch_job_states.name
}

resource "aws_sqs_queue_policy" "allow_eventbridge_forwarding" {
  queue_url = aws_sqs_queue.job_state_updates.id
  policy = <<POLICY
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "SQSAccess",
      "Effect": "Allow",
      "Principal": {
        "Service": "events.amazonaws.com"
      },
      "Action": "sqs:SendMessage",
      "Resource": "${aws_sqs_queue.job_state_updates.arn}",
      "Condition": {
        "ArnEquals": {
          "aws:SourceArn": "${aws_cloudwatch_event_rule.all_batch_job_states.arn}"
        }
      }
    }
  ]
}
POLICY
}
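
With that applied, consuming from the shared queue is plain SQS long polling; a minimal Python sketch, assuming the queue name above and default credentials:

import json
import boto3

sqs = boto3.Session().client("sqs")
queue_url = sqs.get_queue_url(QueueName="job-state-updates")["QueueUrl"]

while True:
    response = sqs.receive_message(
        QueueUrl=queue_url, WaitTimeSeconds=20, MaxNumberOfMessages=10
    )
    for message in response.get("Messages", []):
        detail = json.loads(message["Body"])["detail"]
        print(detail["jobId"], detail["status"], detail.get("parameters", {}))
        # Delete each message once it's been handled
        sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=message["ReceiptHandle"])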

Making the terraform only match job states from specific workflows

You can restrict it to a specific workflow ID by adding more matching keys to the aws_cloudwatch_event_rule's event pattern, like so:

  resource "aws_cloudwatch_event_rule" "all_batch_job_states" {
    name = "match-batch-job-states"
    description = "match Batch job states"
    event_pattern = jsonencode({
      "source" : ["aws.batch"]
      "detail-type": ["Batch Job State Change"],
+     "parameters": {"workflow_id": [var.workflow_id]},
    })
  }

You'd probably also want to change the resource names (e.g. the SQS queue name) to include the workflow ID.
