Skip to content

Instantly share code, notes, and snippets.

@GRAYgoose124
Created July 5, 2023 01:24
Show Gist options
  • Save GRAYgoose124/4083e0203bffd16a7e014cd957e05b50 to your computer and use it in GitHub Desktop.
Save GRAYgoose124/4083e0203bffd16a7e014cd957e05b50 to your computer and use it in GitHub Desktop.
ThreadPool fun
from abc import ABCMeta
from dataclasses import dataclass, asdict
import enum
import threading
import json
import logging
from queue import Queue
import time
from typing import Literal
import typing
import uuid
import pydantic
# Module-level logger; the worker threads report action results and unknown
# requests through it.
logger = logging.getLogger(__name__)
class BaseActionEnum(str, enum.Enum):
    """Base class for action identifiers.

    Members of subclasses are both ``enum.Enum`` members and ``str``
    instances. The base class itself declares no members.
    """
@dataclass
class WorkRequest:
    """A single unit of work placed on a worker queue.

    Attributes:
        id: Identifier of the request.
        call: Key of the action to run; workers look it up in their action
            mapping.
        args: Positional arguments, unpacked into the action call.
    """

    id: str
    call: BaseActionEnum
    args: tuple
class EOQTerminateThread:
    """End-of-queue sentinel that tells a worker thread to stop.

    The class object itself (not an instance) is put on the queue and is
    recognised by identity.
    """
class QueueingWorkThread(threading.Thread, metaclass=ABCMeta):
    """Worker thread that consumes requests from a shared queue.

    Each dequeued request's ``call`` is looked up in the ``actions`` mapping
    and the matching callable is invoked with ``*request.args``. The thread
    stops when it dequeues the ``EOQTerminateThread`` sentinel class.

    Args:
        queue: Queue the worker reads requests from; unknown requests are
            put back onto the same queue.
        actions: Mapping from a request's ``call`` value to the callable
            that handles it.
    """

    # Back-off tuning (seconds) applied when unknown requests are seen.
    _MAX_DELAY = 20.0
    _DELAY_INCREASE = 0.33
    _DELAY_DECREASE = 0.025

    def __init__(self, queue: Queue, actions: dict):
        super().__init__()
        self._queue = queue
        self._actions = actions
        # Seconds slept after requeueing an unknown request. Grows with
        # every unknown request and decays with every handled one.
        self.delay = 0.0

    def run(self):
        """Process queue items until the termination sentinel is received."""
        while True:
            item = self._queue.get()
            if item is EOQTerminateThread:
                # Account for the sentinel too, so Queue.join() can return.
                self._queue.task_done()
                break

            if item.call not in self._actions:
                self._requeue_unknown(item)
                continue

            self._execute(self._actions[item.call], item)
            self.delay = max(0.0, self.delay - self._DELAY_DECREASE)

    def _execute(self, action, item):
        """Run one action and log either its result or its failure."""
        label = getattr(action, "__name__", repr(action))
        try:
            result = action(*item.args)
        except Exception:
            # A failing job must not kill the worker: log it and carry on.
            logger.exception("%s(%s) raised", label, item.args)
        else:
            logger.debug("%s(%s) return %s", label, item.args, result)
        finally:
            self._queue.task_done()

    def _requeue_unknown(self, item):
        """Put an unrecognised request back on the queue and back off."""
        # TODO: Rather than just putting it back on the queue, we should
        # probably request invalidating the request.
        self._queue.put(item)
        # Balance the get() only after the put(), so the queue's unfinished
        # counter never drops to zero while the item is still pending.
        self._queue.task_done()
        logger.error(
            "UKN item.call=%r | %s-%s",
            item.call,
            self.__class__.__name__,
            self.ident,
        )
        time.sleep(self.delay)
        self.delay = min(self._MAX_DELAY, self.delay + self._DELAY_INCREASE)

    def get_work(self) -> WorkRequest:
        """Block until an item is available on the queue and return it."""
        return self._queue.get()

    def submit_work(self, item: WorkRequest):
        """Put ``item`` on the queue this worker reads from."""
        self._queue.put(item)
def create_str_enum_type(name: str, values: list[str]) -> type[enum.Enum]:
    """Build a new ``str``-mixin enum class from a list of strings.

    Every entry of ``values`` is upper-cased and used as both the member
    name and the member value, so ``["red"]`` yields a member ``RED`` whose
    value is ``"RED"``. Because the value is upper-cased as well, a member
    compares equal to the upper-case string only, not to the original
    spelling.

    Args:
        name: Name given to the generated enum class.
        values: Strings to turn into members; order is preserved.

    Returns:
        The newly created enum class (not a member of it).
    """
    members = {upper: upper for upper in (value.upper() for value in values)}
    return enum.Enum(name, members, type=str)
class ActionSettingMeta(type):
    """Metaclass that flattens nested ``Actions`` / ``Settings`` declarations.

    While a class is being created:

    * a nested ``Actions`` class is replaced by an ``actions`` dict mapping
      each of its static method names to the underlying function;
    * every non-dunder attribute of a nested ``Settings`` class is copied
      onto the new class, and ``Settings`` itself is removed;
    * if the class declares ``Actions`` or ``WorkRequest``, an ``ActionEnum``
      and a ``WorkRequest`` subclass are generated from the action names.
    """

    def __new__(cls, name, bases, attrs):
        declares_actions = "Actions" in attrs
        if declares_actions:
            # Map each action staticmethod name to its plain function.
            holder = attrs.pop("Actions")
            attrs["actions"] = {
                key: member.__func__
                for key, member in holder.__dict__.items()
                if isinstance(member, staticmethod)
            }

        if "Settings" in attrs:
            # Promote each setting to an attribute of the new class.
            settings = attrs.pop("Settings")
            for key, value in settings.__dict__.items():
                if not key.startswith("__"):
                    attrs[key] = value

        # Regenerate the enum/request types whenever the action set may have
        # changed, so a subclass that only overrides ``Actions`` does not
        # keep the enum built from its base class's actions.
        if declares_actions or "WorkRequest" in attrs:
            actions = attrs.get("actions")
            if actions is None:
                # ``WorkRequest`` declared without ``Actions``: fall back to
                # the first base that provides ``actions`` instead of
                # failing with a KeyError.
                actions = next(
                    (base.actions for base in bases if hasattr(base, "actions")),
                    {},
                )
            names = [key for key in actions if not key.startswith("__")]
            action_enum = create_str_enum_type("ActionEnum", names)
            attrs["ActionEnum"] = action_enum
            # Any ``WorkRequest`` declared in the class body is replaced by
            # this generated subclass of the module-level ``WorkRequest``.
            # The entries below are plain class attributes holding the field
            # types; the constructor is inherited from the base class.
            attrs["WorkRequest"] = type(
                "WorkRequest",
                (WorkRequest,),
                {
                    "id": str,
                    "call": action_enum,
                    "args": tuple,
                },
            )
        return super().__new__(cls, name, bases, attrs)
class QueueingThreadPool(metaclass=ActionSettingMeta):
    """Pool of ``QueueingWorkThread`` workers fed from one shared queue.

    Subclasses customise the pool through nested ``Settings`` and ``Actions``
    classes, which ``ActionSettingMeta`` turns into plain class attributes
    (``max_jobs``, ``num_threads``, ``actions``).
    """

    class Settings:
        max_jobs = 100
        num_threads = 4

    class Actions:
        @staticmethod
        def action(arg: str, arg2: int):
            pass

    # NOTE: ActionSettingMeta replaces this declaration with a generated
    # subclass of the module-level ``WorkRequest``.
    @dataclass
    class WorkRequest(pydantic.BaseModel):
        call: "QueueingThreadPool.ActionEnum" = pydantic.Field(
            default="action", annotation=Literal["action"]
        )
        args: tuple = ("arg", 1)

    def __init__(self):
        self._todo_jobs = Queue()
        self._running_jobs = {}
        self._threads = [
            QueueingWorkThread(self._todo_jobs, self.actions)
            for _ in range(self.num_threads)
        ]

    def start(self):
        """Start every worker thread."""
        for thread in self._threads:
            thread.start()

    def submit_work(self, item: WorkRequest):
        """Queue ``item`` for execution by one of the workers.

        A request whose ``id`` is missing or ``None`` is given a fresh
        UUID hex string.

        Raises:
            ValueError: ``item.call`` is not one of this pool's actions.
                Rejecting it here keeps workers from requeueing it forever.
            RuntimeError: The queue already holds ``max_jobs`` items.
        """
        if item.call not in self.actions:
            raise ValueError(f"Unknown action {item.call!r}")
        if self._todo_jobs.qsize() >= self.max_jobs:
            raise RuntimeError("Too many jobs in queue")
        if getattr(item, "id", None) is None:
            item.id = uuid.uuid4().hex
        self._todo_jobs.put(item)

    def join(self):
        """Ask every worker to stop and wait for all of them to finish."""
        # One sentinel per worker: each thread exits on the first it sees.
        for _ in self._threads:
            self._todo_jobs.put(EOQTerminateThread)
        for thread in self._threads:
            thread.join()
class MyQTP(QueueingThreadPool):
    """Example pool exposing a few thin wrappers around builtins as actions."""

    class Settings:
        max_jobs = 100
        num_threads = 4

    class Actions:
        @staticmethod
        def pprint(*args):
            """Print the given arguments."""
            print(*args)

        @staticmethod
        def plen(*args):
            """Return ``len(*args)``."""
            return len(*args)

        @staticmethod
        def pstr(*args):
            """Return ``str(*args)``."""
            return str(*args)

        @staticmethod
        def psum(*args):
            """Return ``sum(*args)``."""
            return sum(*args)
def main():
    """Run the demo: start a ``MyQTP`` pool and feed it jobs once a second.

    Each round submits four valid requests followed by one whose action name
    does not exist; any exception raised for the latter is printed. The loop
    runs until interrupted, and the pool is always shut down on the way out
    so the worker threads do not keep the process alive.
    """
    logging.basicConfig(level=logging.DEBUG)
    qtp = MyQTP()
    qtp.start()
    print(f"{qtp.WorkRequest.__dataclass_fields__=}")

    demo_jobs = (
        ("pprint", ("Hello",)),
        ("plen", ("Hello",)),
        ("pstr", (123,)),
        ("psum", ([1, 2, 3],)),
    )
    try:
        while True:
            for call, args in demo_jobs:
                # WorkRequest takes (id, call, args). Passing ``None`` as the
                # id lets submit_work() assign a generated one.
                qtp.submit_work(qtp.WorkRequest(None, call, args))
            try:
                qtp.submit_work(
                    qtp.WorkRequest(None, "notavalidrequest", ([1, 2, 3, 4],))
                )
            except Exception as e:
                print(e)
            time.sleep(1)
    finally:
        # Enqueue the stop sentinels and wait for the workers to finish.
        qtp.join()
if __name__ == "__main__":
    # The demo loop never ends on its own; Ctrl-C is the expected way to stop
    # it, so the interrupt is swallowed instead of printing a traceback.
    try:
        main()
    except KeyboardInterrupt:
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment