Skip to content

Instantly share code, notes, and snippets.

@jackmcguire1
Created December 11, 2025 14:16
Show Gist options
  • Select an option

  • Save jackmcguire1/697c33c9dc1a470868df1ac765a21c37 to your computer and use it in GitHub Desktop.

Select an option

Save jackmcguire1/697c33c9dc1a470868df1ac765a21c37 to your computer and use it in GitHub Desktop.
Batch Message Publisher with AWS Durable Lambda executions
import json
import os
from datetime import datetime, timedelta, timezone
from itertools import batched
from typing import Any, Dict, List

import boto3
from botocore.exceptions import ClientError
from pymongo import MongoClient

from aws_durable_execution_sdk_python import DurableContext, durable_execution
from aws_durable_execution_sdk_python.config import MapConfig
# Initialize clients
def get_mongo_client():
    """Build a MongoDB client from the MONGO_HOST environment variable.

    Returns:
        MongoClient: a client connected to the host named by MONGO_HOST.

    Raises:
        ValueError: if the MONGO_HOST environment variable is unset or empty.
    """
    host = os.environ.get('MONGO_HOST')
    if not host:
        raise ValueError("MONGO_HOST environment variable not set")
    return MongoClient(host)
def get_sqs_client(region_name: str = 'us-east-1'):
    """Return a boto3 SQS client.

    Args:
        region_name: AWS region for the client. Defaults to ``us-east-1``,
            which preserves the previous hard-coded behavior; callers can
            now override it instead of being pinned to one region.

    Returns:
        A boto3 SQS client bound to *region_name*.
    """
    return boto3.client('sqs', region_name=region_name)
def fetch_eligible_users(mongo_host: str) -> List[Dict[str, Any]]:
    """Fetch users from MongoDB who have the required scopes and active tokens.

    A token is considered active when its ``last_refresh`` (stored as an ISO
    string) is within the last 5 hours.

    Args:
        mongo_host: MongoDB connection string / host.

    Returns:
        List of dicts with ``user_id``, ``channel_id`` and ``username`` keys.
        (``channel_id`` mirrors ``user_id`` — the token document's ``_id``.)
    """
    client = MongoClient(mongo_host)
    # try/finally fixes a resource leak: previously the client was never
    # closed if any query below raised.
    try:
        db = client['some-database']

        # Cutoff time (5 hours ago), timezone-aware replacement for the
        # deprecated datetime.utcnow(); the format string emits no UTC
        # offset, so the produced string is unchanged.
        cutoff_time = datetime.now(timezone.utc) - timedelta(hours=5)
        cutoff_time_str = cutoff_time.strftime('%Y-%m-%dT%H:%M:%SZ')

        required_scopes = [
            "moderator:manage:announcements"
        ]

        # Tokens that carry every required scope and were refreshed recently.
        # last_refresh is compared as a string; this works only because both
        # sides use the same fixed-width ISO-8601 format.
        user_tokens = db['user_tokens'].find({
            'scopes': {'$all': required_scopes},
            'last_refresh': {'$gte': cutoff_time_str},
        })

        eligible_users = []
        for token_doc in user_tokens:
            user_id = token_doc.get('_id')
            if not user_id:
                continue
            # Join against the users collection for display details.
            user_doc = db['users'].find_one({'_id': user_id})
            if user_doc:
                eligible_users.append({
                    'user_id': user_id,
                    'channel_id': user_id,
                    'username': user_doc.get('username'),
                })
        return eligible_users
    finally:
        client.close()
def publish_sqs_messages(
    users: List[Dict[str, Any]],
    message: str,
    colour: str,
    queue_url: str
) -> Dict[str, int]:
    """
    Publish announcement messages to SQS using batch sending (up to 10 per batch).
    Returns dict with success and failure counts.
    """
    sqs = get_sqs_client()
    sent_ok = 0
    sent_failed = 0

    # SQS enforces a hard limit of 10 messages per SendMessageBatch request.
    for chunk in batched(users, 10):
        entries = [
            {
                'Id': str(position),
                'MessageBody': json.dumps({
                    'Message': message,
                    'Colour': colour,
                    'ChannelID': member['channel_id'],
                    'Username': member['username']
                })
            }
            for position, member in enumerate(chunk)
        ]
        try:
            response = sqs.send_message_batch(
                QueueUrl=queue_url,
                Entries=entries
            )
        except ClientError as e:
            # The whole request failed — every entry counts as a failure.
            batch_size = len(entries)
            sent_failed += batch_size
            print(f"Failed to send batch of {batch_size} messages: {e}")
            continue

        sent_ok += len(response.get('Successful', []))

        # Per-entry failures come back in the same response; map each one
        # back to its user via the entry Id (its index in this chunk).
        failed = response.get('Failed', [])
        sent_failed += len(failed)
        for failure in failed:
            user = chunk[int(failure['Id'])]
            print(f"Failed to send message for user {user['username']}: {failure.get('Message', 'Unknown error')}")

    return {
        'successCount': sent_ok,
        'failureCount': sent_failed
    }
def process_single_batch(ctx: DurableContext, batch: Dict, index: int) -> Dict[str, Any]:
    """
    Process a single batch in the durable execution context.
    The context.map() provides durability, so we don't need nested steps.
    """
    members = batch['users']

    # Invoke directly — durability is supplied by context.map, not by
    # wrapping this call in another step.
    outcome = publish_sqs_messages(
        members,
        batch['message'],
        batch['colour'],
        batch['queue_url'],
    )

    return {
        'batchIndex': batch['batchIndex'],
        'userCount': len(members),
        'successCount': outcome['successCount'],
        'failureCount': outcome['failureCount']
    }
@durable_execution
def lambda_handler(event, context: DurableContext):
    """
    Durable Lambda handler for publishing messages to all eligible Twitch channels.

    Expected event structure:
    {
        "message": "hello world!",   # required
        "colour": "#9146FF",         # optional; defaults to 'primary'
        "batchSize": 100             # optional; users per durable step, defaults to 100
    }

    Returns:
        Summary dict with totalUsers, batchesProcessed, successCount,
        failureCount and completedAt.

    Raises:
        ValueError: if 'message' is missing, or the MONGO_HOST /
            SQS_QUEUE_URL environment variables are not set.
    """
    # Extract parameters from event
    message = event.get('message')
    if not message:
        raise ValueError("'message' field is required in event payload")
    # NOTE(review): the default is the literal keyword 'primary', although the
    # old comment claimed "Twitch purple" (#9146FF) — confirm which value the
    # downstream consumer actually expects.
    colour = event.get('colour', 'primary')
    batch_size = event.get('batchSize', 100)

    # Get environment variables
    mongo_host = os.environ.get('MONGO_HOST')
    queue_url = os.environ.get('SQS_QUEUE_URL')
    if not mongo_host:
        raise ValueError("MONGO_HOST environment variable not set")
    if not queue_url:
        raise ValueError("SQS_QUEUE_URL environment variable not set")

    # Step 1: Fetch all eligible users from MongoDB (durable checkpoint).
    eligible_users = context.step(
        lambda _: fetch_eligible_users(mongo_host),
        name='fetch-eligible-users'
    )
    print(f"Found {len(eligible_users)} eligible users")
    if not eligible_users:
        return {
            'message': 'No eligible users found',
            'totalUsers': 0,
            'batchesProcessed': 0,
            'successCount': 0,
            'failureCount': 0
        }

    # Step 2: Split users into batches using itertools.batched.
    batches = [
        {
            'batchIndex': idx,
            'users': list(batch)
        }
        for idx, batch in enumerate(batched(eligible_users, batch_size))
    ]
    print(f"Split into {len(batches)} batches of up to {batch_size} users each")

    # Step 3: Process batches sequentially as individual durable steps.
    # (context.map() was having serialization issues with this payload.)
    all_results = []
    for batch in batches:
        batch_index = batch['batchIndex']
        users = batch['users']

        # Bug fix: bind the loop variable as a default argument instead of
        # closing over it. A plain closure over `users` is late-bound — if the
        # durable runtime retains or replays the step callable after the loop
        # advances, every step would see the *last* batch.
        def process_this_batch(_, _users=users):
            return publish_sqs_messages(_users, message, colour, queue_url)

        result = context.step(
            process_this_batch,
            name=f'publish-batch-{batch_index}'
        )
        all_results.append({
            'batchIndex': batch_index,
            'userCount': len(users),
            'successCount': result['successCount'],
            'failureCount': result['failureCount']
        })

    # Step 4: Aggregate per-batch results into the final summary.
    def aggregate_results(_):
        total_success = sum(r['successCount'] for r in all_results)
        total_failure = sum(r['failureCount'] for r in all_results)
        return {
            'message': message,
            'colour': colour,
            'totalUsers': len(eligible_users),
            'batchesProcessed': len(batches),
            'successCount': total_success,
            'failureCount': total_failure,
            # Timezone-aware replacement for deprecated datetime.utcnow();
            # the timestamp now carries an explicit +00:00 offset.
            'completedAt': datetime.now(timezone.utc).isoformat()
        }
    summary = context.step(aggregate_results, name='aggregate-results')
    print(f"Execution completed: {json.dumps(summary, indent=2)}")
    return summary
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment