Skip to content

Instantly share code, notes, and snippets.

@thewisenerd
Created May 5, 2023 16:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thewisenerd/5719ee49a62d14eeb2fc0be6407c6169 to your computer and use it in GitHub Desktop.
Save thewisenerd/5719ee49a62d14eeb2fc0be6407c6169 to your computer and use it in GitHub Desktop.
import threading
import typing
from dataclasses import dataclass, field
from queue import Queue
from typing import TypeVar, Generic
T = TypeVar('T')  # type of the input handed to a job
R = TypeVar('R')  # type of the result a job produces


@dataclass(eq=False)
class Job(Generic[T, R]):
    """A unit of work together with a one-shot slot for its outcome.

    A job is created with its ``input``; whoever executes it later calls
    exactly one of :meth:`complete_ok` or :meth:`complete_err`, which records
    the outcome and sets ``done``. Consumers block on :meth:`wait` or
    :meth:`get`.

    ``eq=False`` keeps the default identity-based ``__eq__``/``__hash__``.
    Field-wise equality would make instances unhashable, and it could never
    match two distinct jobs anyway because each one owns its own ``Event``.
    """

    # The payload handed to the runner.
    input: T
    # Set once the job has finished, successfully or not.
    done: threading.Event = field(init=False, default_factory=threading.Event)
    # False only after complete_err() has been called.
    success: bool = field(init=False, default=True)
    # Value recorded by complete_ok(); None until then.
    result: typing.Optional[R] = field(init=False, default=None)
    # Exception recorded by complete_err(); None until then.
    error: typing.Optional[Exception] = field(init=False, default=None)

    def complete_ok(self, result: R) -> None:
        """Record ``result`` as the outcome and release all waiters."""
        self.result = result
        self.done.set()

    def complete_err(self, err: Exception) -> None:
        """Record ``err`` as the failure and release all waiters."""
        self.success = False
        self.error = err
        self.done.set()

    def wait(self, timeout: typing.Optional[float] = None) -> bool:
        """Block until the job is finished or ``timeout`` seconds elapse.

        Args:
            timeout: Maximum time to wait in seconds; ``None`` waits forever.

        Returns:
            ``True`` if the job is finished, ``False`` if the wait timed out.
        """
        return self.done.wait(timeout)

    def get(self, timeout: typing.Optional[float] = None) -> typing.Optional[R]:
        """Return the job's result, waiting for completion if necessary.

        Args:
            timeout: Maximum time to wait in seconds; ``None`` waits forever.

        Returns:
            The value passed to :meth:`complete_ok`.

        Raises:
            TimeoutError: The job did not finish within ``timeout``.
            Exception: Whatever was passed to :meth:`complete_err`.
        """
        if not self.wait(timeout):
            raise TimeoutError(f"job did not complete within {timeout} seconds")
        if self.success:
            return self.result
        raise self.error
class WorkerPool(Generic[T, R]):
    """Fixed-size pool of daemon threads that feed queued inputs to ``runner``.

    Each submitted input is wrapped in a ``Job`` and placed on ``q``. A worker
    thread takes it, calls ``runner(job.input, pool)`` and records the return
    value or the raised exception on the job.

    The pool places ``None`` on ``q`` as a stop marker during shutdown, so a
    caller-supplied queue should be used by this pool only.
    """

    def __init__(self,
                 name: str,
                 threads: int,
                 runner: typing.Callable[[T, 'WorkerPool'], R],
                 q: typing.Optional[Queue[Job[T, R]]] = None):
        """Create the pool and start its worker threads immediately.

        Args:
            name: Prefix for worker thread names (``"<name>-<index>"``).
            threads: Number of worker threads to start.
            runner: Callable invoked as ``runner(input, pool)`` for every job.
            q: Queue used to hand jobs to the workers; a new unbounded
                ``Queue`` is created when omitted.
        """
        if q is None:
            q = Queue()
        self.name = name
        self.threads = threads
        self.q = q
        self.runner = runner
        self.kill_switch = threading.Event()
        self.worker_threads = []
        # Serialises submit() against the stop-marker push in shutdown(), so a
        # job can never be queued behind the markers and left unprocessed.
        self._lock = threading.Lock()

        def worker():
            while True:
                job = self.q.get()
                try:
                    if job is None:
                        # Stop marker pushed by shutdown(): leave the loop.
                        return
                    try:
                        outcome = self.runner(job.input, self)
                    except BaseException as exc:
                        # Catch BaseException too: otherwise the thread would
                        # die with the job still pending, blocking its waiters
                        # forever.
                        job.complete_err(exc)
                    else:
                        job.complete_ok(outcome)
                finally:
                    # Always balance get(), or q.join() in shutdown() hangs.
                    self.q.task_done()

        for idx in range(self.threads):
            thread = threading.Thread(
                target=worker,
                name=f'{self.name}-{idx}',
                daemon=True,
            )
            thread.start()
            self.worker_threads.append(thread)

    def submit(self, task: T) -> Job[T, R]:
        """Queue ``task`` for execution and return its ``Job`` handle.

        Raises:
            RuntimeError: The pool has already been shut down.
        """
        job: Job[T, R] = Job(task)
        with self._lock:
            if self.kill_switch.is_set():
                raise RuntimeError("cannot submit jobs for a shut down worker pool")
            self.q.put(job)
        return job

    def shutdown(self):
        """Wait for queued work to finish, then stop and join the workers.

        Calling it again after the pool is shut down does nothing further.
        """
        # Drain first: jobs submitted while we wait here are still accepted.
        self.q.join()
        with self._lock:
            if self.kill_switch.is_set():
                return
            self.kill_switch.set()
            # Workers block in q.get() and never poll the event, so wake each
            # one with its own stop marker.
            for _ in self.worker_threads:
                self.q.put(None)
        for thread in self.worker_threads:
            thread.join()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment