|
import logging |
|
import time |
|
from abc import ABC, abstractmethod |
|
from collections import OrderedDict |
|
from enum import Enum, auto |
|
from threading import RLock |
|
from threading import Thread as _BuiltinThread |
|
from typing import Any |
|
|
|
|
|
class ThreadTypes(Enum):
    """Identifiers for the kinds of thread this module defines.

    Values are assigned by ``auto()`` in declaration order, so only the
    member identity (not the numeric value) should be relied upon.
    """

    # Thread whose worker runs in the foreground.
    MAIN = auto()
    # Background thread for the interaction loop.
    INTERACTION = auto()
    # Background thread for the training loop.
    TRAINING = auto()
|
|
|
|
|
class SharedObjectPool:
    """Registry of named objects, grouped by the thread type that owns them.

    One (initially empty) namespace is created per ``ThreadTypes`` member.
    Every read and write is performed while holding the pool's ``RLock``.
    """

    # Outer key: owning thread type. Inner mapping: object name -> object.
    _objects: OrderedDict[ThreadTypes, OrderedDict[str, Any]]

    def __init__(self) -> None:
        """Create an empty namespace for every ``ThreadTypes`` member."""
        self._lock = RLock()
        self._objects = OrderedDict(
            (thread_type, OrderedDict()) for thread_type in ThreadTypes
        )

    def register(self, thread_type: ThreadTypes, name: str, obj: Any) -> None:
        """Store ``obj`` under ``name`` in the namespace of ``thread_type``.

        An existing entry with the same name is overwritten.

        Raises:
            KeyError: if ``thread_type`` has no namespace in the pool.
        """
        with self._lock:
            self._objects[thread_type][name] = obj

    def get(self, thread_type: ThreadTypes, name: str) -> Any:
        """Return the object registered as ``name`` by ``thread_type``.

        Raises:
            KeyError: if ``thread_type`` has no namespace in the pool or
                nothing has been registered under ``name`` in it.
        """
        with self._lock:
            return self._objects[thread_type][name]
|
|
|
|
|
class BaseThread(ABC):
    """Abstract base for a unit of work that runs on its own thread.

    Subclasses provide ``thread_type`` and ``worker``. The base class owns
    the underlying thread object, a per-class logger, and helpers for
    exchanging objects through the shared object pool.
    """

    def __init__(self, shared_object_pool: SharedObjectPool) -> None:
        """Set up the logger, the pool reference and the (unstarted) thread.

        Args:
            shared_object_pool: pool used by ``share_object`` and
                ``get_shared_object``.
        """
        logger_name = f"{__name__}.{self.__class__.__name__}"
        self.logger = logging.getLogger(logger_name)
        self._shared_object_pool = shared_object_pool
        # The thread is only created here; it does not run until start().
        self._worker_thread = _BuiltinThread(target=self._worker)

    @property
    @abstractmethod
    def thread_type(self) -> ThreadTypes:
        """Thread type of the concrete class; subclasses must override."""
        raise NotImplementedError

    @abstractmethod
    def worker(self) -> None:
        """Body of the thread; subclasses must override."""
        raise NotImplementedError

    def _worker(self) -> None:
        """Run ``worker`` and log, rather than propagate, any ``Exception``."""
        try:
            self.worker()
        except Exception as error:
            self.logger.exception(error)

    def start(self) -> None:
        """Launch the worker on the background thread."""
        self._worker_thread.start()

    def share_object(self, name: str, obj: Any) -> None:
        """Publish ``obj`` as ``name`` under this instance's thread type."""
        owner = self.thread_type
        self._shared_object_pool.register(owner, name, obj)

    def get_shared_object(self, shared_from: ThreadTypes, name: str) -> Any:
        """Return the object that ``shared_from`` published as ``name``."""
        pool = self._shared_object_pool
        return pool.get(shared_from, name)
|
|
|
|
|
class Context:
    """Minimal mutable holder for a single boolean flag."""

    # Class-level default. Assigning ``instance.value`` sets the flag on that
    # instance only and leaves this default untouched.
    value: bool = False
|
|
|
|
|
class MainThread(BaseThread):
    """Main thread.

    Owns the ``shutdown`` flag and publishes it to the shared object pool
    under the name ``"shutdown"``. Unlike the base class, ``start`` runs the
    worker on the calling thread instead of a background thread.
    """

    def __init__(self, shared_object_pool: SharedObjectPool) -> None:
        """Create the shutdown flag and share it through the pool.

        Args:
            shared_object_pool: pool in which the flag is registered.
        """
        super().__init__(shared_object_pool)
        self.shutdown = Context()
        self.share_object("shutdown", self.shutdown)

    @property
    def thread_type(self) -> ThreadTypes:
        """Always ``ThreadTypes.MAIN``."""
        return ThreadTypes.MAIN

    def start(self) -> None:
        """Run on main thread (blocks until the worker returns)."""
        self._worker()

    def worker(self) -> None:
        """Print a prompt once per second until interrupted.

        The shutdown flag is raised however the loop ends. Previously it was
        set only on ``KeyboardInterrupt``, so any other exception left the
        flag ``False`` and threads polling it would keep running forever.
        """
        try:
            while True:
                print("Main. Press Ctrl+C to quit.")
                time.sleep(1)
        except KeyboardInterrupt:
            # Ctrl+C is the expected way to stop; swallow it so the caller
            # sees a normal return. The flag itself is raised in ``finally``.
            pass
        finally:
            self.shutdown.value = True
|
|
|
|
|
class InteractionThread(BaseThread):
    """Background thread that prints a message every 0.5 s until shutdown."""

    @property
    def thread_type(self) -> ThreadTypes:
        """Always ``ThreadTypes.INTERACTION``."""
        return ThreadTypes.INTERACTION

    def worker(self) -> None:
        """Loop until the flag shared by the MAIN thread becomes true."""
        # The flag object is looked up once; only its ``value`` is re-read.
        stop_flag: Context = self.get_shared_object(ThreadTypes.MAIN, "shutdown")

        while True:
            if stop_flag.value:
                break
            print("Interaction.")
            time.sleep(0.5)
|
|
|
|
|
class TrainingThread(BaseThread):
    """Background thread that prints a message every 1.25 s until shutdown."""

    @property
    def thread_type(self) -> ThreadTypes:
        """Always ``ThreadTypes.TRAINING``."""
        return ThreadTypes.TRAINING

    def worker(self) -> None:
        """Loop until the flag shared by the MAIN thread becomes true."""
        # The flag object is looked up once; only its ``value`` is re-read.
        stop_flag: Context = self.get_shared_object(ThreadTypes.MAIN, "shutdown")

        while True:
            if stop_flag.value:
                break
            print("Training")
            time.sleep(1.25)
|
|
|
|
|
if __name__ == "__main__":
    # A single pool is shared by every thread object.
    sop = SharedObjectPool()

    # MainThread is built first: its constructor registers the "shutdown"
    # flag that the other workers look up when they start running.
    main = MainThread(sop)
    interaction = InteractionThread(sop)
    training = TrainingThread(sop)

    # Launch the background workers, then run the main loop on this thread.
    for background in (interaction, training):
        background.start()

    main.start()