Created
December 11, 2025 14:16
-
-
Save jackmcguire1/697c33c9dc1a470868df1ac765a21c37 to your computer and use it in GitHub Desktop.
Batch Message Publisher with AWS Durable Lambda executions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import json | |
| from datetime import datetime, timedelta | |
| from typing import List, Dict, Any | |
| from itertools import batched | |
| from aws_durable_execution_sdk_python import DurableContext, durable_execution | |
| from aws_durable_execution_sdk_python.config import MapConfig | |
| import boto3 | |
| from pymongo import MongoClient | |
| from botocore.exceptions import ClientError | |
| # Initialize clients | |
| def get_mongo_client(): | |
| """Get MongoDB client instance.""" | |
| mongo_host = os.environ.get('MONGO_HOST') | |
| if not mongo_host: | |
| raise ValueError("MONGO_HOST environment variable not set") | |
| return MongoClient(mongo_host) | |
| def get_sqs_client(): | |
| """Get SQS client instance.""" | |
| return boto3.client('sqs', region_name='us-east-1') | |
| def fetch_eligible_users(mongo_host: str) -> List[Dict[str, Any]]: | |
| """ | |
| Fetch users from MongoDB who have the required scopes and active tokens. | |
| Returns list of users with their channel details. | |
| """ | |
| client = MongoClient(mongo_host) | |
| db = client['some-database'] | |
| # Calculate cutoff time (5 hours ago) | |
| cutoff_time = datetime.utcnow() - timedelta(hours=5) | |
| # Convert to ISO string for comparison with string field in MongoDB | |
| cutoff_time_str = cutoff_time.strftime('%Y-%m-%dT%H:%M:%SZ') | |
| # Required scopes | |
| required_scopes = [ | |
| "moderator:manage:announcements" | |
| ] | |
| # Query user_tokens collection | |
| user_tokens = db['user_tokens'].find({ | |
| 'scopes': {'$all': required_scopes}, | |
| 'last_refresh': {'$gte': cutoff_time_str} | |
| }) | |
| eligible_users = [] | |
| for token_doc in user_tokens: | |
| user_id = token_doc.get('_id') | |
| if not user_id: | |
| continue | |
| # Get user details from users collection | |
| user_doc = db['users'].find_one({'_id': user_id}) | |
| if user_doc: | |
| eligible_users.append({ | |
| 'user_id': user_id, | |
| 'channel_id': user_id, | |
| 'username': user_doc.get('username'), | |
| }) | |
| client.close() | |
| return eligible_users | |
| def publish_sqs_messages( | |
| users: List[Dict[str, Any]], | |
| message: str, | |
| colour: str, | |
| queue_url: str | |
| ) -> Dict[str, int]: | |
| """ | |
| Publish announcement messages to SQS using batch sending (up to 10 per batch). | |
| Returns dict with success and failure counts. | |
| """ | |
| sqs = get_sqs_client() | |
| success_count = 0 | |
| failure_count = 0 | |
| # SQS batch limit is 10 messages per request | |
| for sqs_batch in batched(users, 10): | |
| entries = [] | |
| for idx, user in enumerate(sqs_batch): | |
| announcement = { | |
| 'Message': message, | |
| 'Colour': colour, | |
| 'ChannelID': user['channel_id'], | |
| 'Username': user['username'] | |
| } | |
| entries.append({ | |
| 'Id': str(idx), | |
| 'MessageBody': json.dumps(announcement) | |
| }) | |
| try: | |
| response = sqs.send_message_batch( | |
| QueueUrl=queue_url, | |
| Entries=entries | |
| ) | |
| # Count successful sends | |
| success_count += len(response.get('Successful', [])) | |
| # Handle failures | |
| failed = response.get('Failed', []) | |
| failure_count += len(failed) | |
| for failure in failed: | |
| user_idx = int(failure['Id']) | |
| user = list(sqs_batch)[user_idx] | |
| print(f"Failed to send message for user {user['username']}: {failure.get('Message', 'Unknown error')}") | |
| except ClientError as e: | |
| # If entire batch fails, count all as failures | |
| batch_size = len(entries) | |
| failure_count += batch_size | |
| print(f"Failed to send batch of {batch_size} messages: {e}") | |
| return { | |
| 'successCount': success_count, | |
| 'failureCount': failure_count | |
| } | |
| def process_single_batch(ctx: DurableContext, batch: Dict, index: int) -> Dict[str, Any]: | |
| """ | |
| Process a single batch in the durable execution context. | |
| The context.map() provides durability, so we don't need nested steps. | |
| """ | |
| batch_index = batch['batchIndex'] | |
| users = batch['users'] | |
| message = batch['message'] | |
| colour = batch['colour'] | |
| queue_url = batch['queue_url'] | |
| # Call directly - context.map provides the durability | |
| result = publish_sqs_messages(users, message, colour, queue_url) | |
| return { | |
| 'batchIndex': batch_index, | |
| 'userCount': len(users), | |
| 'successCount': result['successCount'], | |
| 'failureCount': result['failureCount'] | |
| } | |
| @durable_execution | |
| def lambda_handler(event, context: DurableContext): | |
| """ | |
| Durable Lambda handler for publishing messages to all eligible Twitch channels. | |
| Expected event structure: | |
| { | |
| "message": "hello world!", | |
| "colour": "#9146FF", # Optional, defaults to Twitch purple | |
| "batchSize": 100 # Optional, defaults to 100 | |
| } | |
| """ | |
| # Extract parameters from event | |
| message = event.get('message') | |
| if not message: | |
| raise ValueError("'message' field is required in event payload") | |
| colour = event.get('colour', 'primary') # Default to Twitch purple | |
| batch_size = event.get('batchSize', 100) | |
| # Get environment variables | |
| mongo_host = os.environ.get('MONGO_HOST') | |
| queue_url = os.environ.get('SQS_QUEUE_URL') | |
| if not mongo_host: | |
| raise ValueError("MONGO_HOST environment variable not set") | |
| if not queue_url: | |
| raise ValueError("SQS_QUEUE_URL environment variable not set") | |
| # Step 1: Fetch all eligible users from MongoDB | |
| eligible_users = context.step( | |
| lambda _: fetch_eligible_users(mongo_host), | |
| name='fetch-eligible-users' | |
| ) | |
| print(f"Found {len(eligible_users)} eligible users") | |
| if not eligible_users: | |
| return { | |
| 'message': 'No eligible users found', | |
| 'totalUsers': 0, | |
| 'batchesProcessed': 0, | |
| 'successCount': 0, | |
| 'failureCount': 0 | |
| } | |
| # Step 2: Split users into batches using itertools.batched | |
| batches = [ | |
| { | |
| 'batchIndex': idx, | |
| 'users': list(batch) | |
| } | |
| for idx, batch in enumerate(batched(eligible_users, batch_size)) | |
| ] | |
| print(f"Split into {len(batches)} batches of up to {batch_size} users each") | |
| # Step 3: Process batches sequentially | |
| # Since context.map() is having serialization issues, process sequentially | |
| all_results = [] | |
| for batch in batches: | |
| batch_index = batch['batchIndex'] | |
| users = batch['users'] | |
| def process_this_batch(_): | |
| return publish_sqs_messages(users, message, colour, queue_url) | |
| result = context.step( | |
| process_this_batch, | |
| name=f'publish-batch-{batch_index}' | |
| ) | |
| all_results.append({ | |
| 'batchIndex': batch_index, | |
| 'userCount': len(users), | |
| 'successCount': result['successCount'], | |
| 'failureCount': result['failureCount'] | |
| }) | |
| # Step 4: Aggregate results | |
| def aggregate_results(_): | |
| total_success = sum(r['successCount'] for r in all_results) | |
| total_failure = sum(r['failureCount'] for r in all_results) | |
| return { | |
| 'message': message, | |
| 'colour': colour, | |
| 'totalUsers': len(eligible_users), | |
| 'batchesProcessed': len(batches), | |
| 'successCount': total_success, | |
| 'failureCount': total_failure, | |
| 'completedAt': datetime.utcnow().isoformat() | |
| } | |
| summary = context.step(aggregate_results, name='aggregate-results') | |
| print(f"Execution completed: {json.dumps(summary, indent=2)}") | |
| return summary |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment