Skip to content

Instantly share code, notes, and snippets.

@rgpower
Created January 14, 2021 15:42
Show Gist options
  • Save rgpower/6f53d84f8db1a740fbf5caac5f78c75c to your computer and use it in GitHub Desktop.
Save rgpower/6f53d84f8db1a740fbf5caac5f78c75c to your computer and use it in GitHub Desktop.
Process N items with at most M concurrent threads, using asyncio and a queue in Python — for example, to backfill failed AWS Lambda executions.
from dotenv import load_dotenv
load_dotenv()
import asyncio
import datetime
import functools
import json
import logging
import os
import queue
import random
import sys
import threading
import traceback
from timeit import default_timer as timer
# Upper bound on concurrently running worker threads; override with the
# CXS_MAX_SIMULTANEOUS environment variable (defaults to 25).
MAX_SIMULTANEOUS = int(os.environ.get("CXS_MAX_SIMULTANEOUS", "25"))
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Throttle: main() acquires one slot per queued uid before spawning a worker
# thread; each worker releases its slot when it finishes.
max_simul = threading.BoundedSemaphore(MAX_SIMULTANEOUS)
async def main():
    """Backfill every uid listed on stdin (or in the file named by argv[1]).

    One worker thread is spawned per uid, but ``max_simul`` caps how many run
    at once.  Each worker runs the user-supplied ``handle_sns`` coroutine on
    its own private event loop, simulating a Lambda invocation.  Progress is
    logged once per second until the queue drains.
    """
    # NOTE(review): ``global`` is unnecessary here — max_simul is only read,
    # never rebound — but it is harmless.
    global max_simul
    backfill_filename = backfill_filename_arg()
    # Read uids either from stdin or from the given file, one uid per line.
    # NOTE(review): when reading stdin, the ``with`` block closes sys.stdin
    # on exit — presumably fine for this one-shot script.
    with sys.stdin if backfill_filename is None else open(backfill_filename, "r") as fh:
        file_uids = list(map(lambda l: l.rstrip("\n"), fh.readlines()))
    total_uids = len(file_uids)
    uids_complete = ThreadsafeCounter()
    num_left = ThreadsafeCounter(start_value=len(file_uids))
    tm_started = timer()
    # Sized to hold every uid so q.put() below can never block.
    q = queue.Queue(maxsize=total_uids)

    def backfill_user():
        """Worker body: take one uid off the queue and process it."""
        global max_simul
        nonlocal q
        uid = q.get()
        # Per-uid logger so log lines are attributable to the uid being
        # processed (shadows the module-level ``logger`` inside this worker).
        logger = logging.getLogger(uid)
        logger.setLevel(logging.INFO)
        # Each worker thread gets its own event loop for the async handler.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            context = new_lambda_context()
            sns = build_sns(uid)
            # you supply handle_sns func
            coro = handle_sns(sns, context)
            loop.run_until_complete(coro)
        except Exception as err:
            # Best-effort: log the full traceback but keep the backfill going.
            stk = "".join(traceback.TracebackException.from_exception(err).format())
            logger.error(f"error={err} stk={stk}")
        finally:
            # Bookkeeping runs even on failure so q.join() and the semaphore
            # can never deadlock on a crashed worker.
            uids_complete.increment()
            num_left.increment(-1)
            loop.close()
            q.task_done()
            max_simul.release()

    def log_progress():
        """Emit a one-line progress summary (called once per second)."""
        pct_done = 100 * uids_complete.value / total_uids
        elapsed = timer() - tm_started
        speed = uids_complete.value / elapsed
        # speed == 0 until the first uid completes; report -1 rather than
        # divide by zero.
        time_left_mins = (num_left.value / speed) / 60 if speed > 0 else -1
        # active_count() - 2 discounts the main thread and the progress timer
        # thread, leaving only worker threads.
        logger.warning(
            f"pct_done={pct_done:.1f}% speed={speed:.2f} uids/sec uids_completed={uids_complete.value} num_left={num_left.value} elapsed={round(elapsed)} remaining(mins)={time_left_mins:.2f} active.threads={threading.active_count() - 2}"
        )

    cancel_log_progress = call_repeatedly(1, log_progress)
    i = 0
    for uid in file_uids:
        # Block here until a worker slot frees up — this is what limits
        # concurrency to MAX_SIMULTANEOUS despite one thread per uid.
        max_simul.acquire(True)
        i = i + 1
        worker_name = f"QueueWorkerThread<id={i}>"
        t = threading.Thread(name=worker_name, target=backfill_user)
        # Enqueue before starting the thread so its q.get() never blocks.
        q.put(uid)
        t.start()
    # Wait for every enqueued uid to be task_done()'d, then stop the timer
    # and log a final summary.
    q.join()
    cancel_log_progress()
    log_progress()
def backfill_filename_arg():
    """Return the backfill-file path given as the first CLI argument, or None."""
    args = sys.argv[1:]
    return args[0] if args else None
def new_lambda_context():
    """Build a minimal stand-in for the AWS Lambda context object.

    Only ``aws_request_id`` is populated — a fresh random UUID string —
    which is all the handler needs here.
    """
    import types
    import uuid

    request_id = str(uuid.uuid4())
    return types.SimpleNamespace(aws_request_id=request_id)
def to_iso8601(dt: datetime.datetime) -> str:
    """Format *dt* as an ISO-8601-style timestamp, e.g. ``2021-01-14T15:42:00``.

    The UTC offset (``%z``) is appended without a colon (``+0000``) for aware
    datetimes and is empty for naive ones.

    The original parameter was named ``datetime``, shadowing the stdlib
    module of the same name inside the function body; renamed to ``dt``.
    """
    return dt.strftime("%Y-%m-%dT%H:%M:%S%z")
def build_sns(uid):
    """Construct a synthetic SNS-style envelope for *uid*.

    Mirrors the record shape the Lambda handler expects: a ``Timestamp``
    (current local time) plus a JSON-encoded ``Message`` payload.
    """
    payload = {
        "event": "activity",
        "uid": uid,
    }
    now = datetime.datetime.now()
    return {"Timestamp": to_iso8601(now), "Message": json.dumps(payload)}
class ThreadsafeCounter(object):
    """Integer counter whose updates are serialized by a lock.

    ``value`` is read directly (without the lock) by callers; only
    ``increment`` mutates it, and does so atomically.
    """

    def __init__(self, start_value=0):
        # Current tally, exposed as a plain attribute.
        self.value = start_value
        self._lock = threading.Lock()

    def increment(self, amount=1):
        """Add *amount* to the counter; negative amounts decrement."""
        with self._lock:
            self.value = self.value + amount
# https://stackoverflow.com/questions/22498038/improve-current-implementation-of-a-setinterval/22498708#22498708
def call_repeatedly(interval, func, *args):
    """Invoke ``func(*args)`` every *interval* seconds on a background thread.

    The first call happens *interval* seconds from now.  Returns a
    zero-argument canceller; calling it stops the timer.

    Fix: the timer thread is now a daemon.  Previously it was non-daemon, so
    if the caller crashed before invoking the canceller the timer thread kept
    the process alive forever.
    """
    stopped = threading.Event()

    def loop():
        # Event.wait doubles as the sleep: it returns False on timeout
        # (keep ticking) and True once the canceller sets the event.
        while not stopped.wait(interval):
            func(*args)

    threading.Thread(target=loop, name="call_repeatedly", daemon=True).start()
    return stopped.set
if __name__ == "__main__":
    # Script entry point: run the whole backfill to completion.
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment