Skip to content

Instantly share code, notes, and snippets.

@horodchukanton
Last active September 24, 2022 11:13
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save horodchukanton/969f554ff82ce84e48c891599ea787f1 to your computer and use it in GitHub Desktop.
Save horodchukanton/969f554ff82ce84e48c891599ea787f1 to your computer and use it in GitHub Desktop.
FastAPI Background tasks queue
import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta, datetime
from typing import Any, Dict
class SimpleStorageEntry:
    """One value kept by SimpleStorage, tagged with an integer id and a creation time.

    Ids are drawn from ``_counter``, a class-level counter shared by every
    instance, so successive entries get 1, 2, 3, ...

    NOTE(review): the increment and the read of ``_counter`` below are two
    separate steps; if entries are ever created from several threads at once,
    two entries could receive the same id. Confirm callers are effectively
    single-threaded, or guard this with a lock.
    """

    _counter: int = 0  # last id handed out, shared across all instances
    id: int  # key under which SimpleStorage files this entry
    timestamp: datetime  # naive local creation time; the TTL is measured from here
    ttl: int  # lifetime in seconds
    data: Any  # the stored payload

    def __init__(self, data: Any, ttl: int = 86400):
        # Claim the next id from the shared counter.
        SimpleStorageEntry._counter += 1
        self.id = SimpleStorageEntry._counter

        # Payload and lifetime first, then the instant the lifetime starts at.
        self.data = data
        self.ttl = ttl
        self.timestamp = datetime.now()
class SimpleStorage:
    """In-memory key/value store whose entries expire after a per-entry TTL.

    Keys are the integer ids assigned to each SimpleStorageEntry. Expired
    entries are evicted lazily, at the start of every ``get`` and ``update``;
    nothing sweeps the map in the background.
    """

    # Quoted so the annotation is not evaluated when the class body runs.
    _map: "Dict[int, SimpleStorageEntry]"

    def __init__(self):
        self._map = {}

    def set(self, data: Any, ttl: int = 60 * 60 * 24) -> int:
        """Store ``data`` for ``ttl`` seconds (default one day) and return its new key."""
        entry = SimpleStorageEntry(data=data, ttl=ttl)
        self._map[entry.id] = entry
        return entry.id

    def get(self, task_token: int) -> Any:
        """Return the data stored under ``task_token``.

        Raises:
            MissingEntryRequestedException: the key was never stored or has expired.
        """
        self._check_ttls()
        try:
            entry = self._map[task_token]
        except KeyError as ke:
            raise MissingEntryRequestedException from ke
        return entry.data

    def update(self, task_token: int, new_data: Any) -> None:
        """Replace the data under ``task_token``.

        The entry's original timestamp and TTL are kept, so updating does not
        extend its lifetime.

        Raises:
            MissingEntryRequestedException: the key was never stored or has expired.
        """
        self._check_ttls()
        try:
            entry = self._map[task_token]
        except KeyError as ke:
            raise MissingEntryRequestedException from ke
        entry.data = new_data

    def _check_ttls(self) -> None:
        """Evict every entry whose ``timestamp + ttl`` lies in the past."""
        now = datetime.now()
        # Iterate over a snapshot of the items so the dict is never mutated
        # while it is being iterated, and never re-indexed by key afterwards.
        for key, entry in list(self._map.items()):
            if entry.timestamp + timedelta(seconds=entry.ttl) < now:
                # pop with a default instead of ``del``: another thread (e.g. a
                # background task calling update()) may already have evicted the
                # same key, and that must not surface as a bare KeyError.
                self._map.pop(key, None)
class BackgroundExecutor:
    """Run callables off the caller's thread on a ThreadPoolExecutor.

    With the default ``max_workers=1`` there is a single worker thread, so
    tasks run one at a time in submission order.
    """

    executor = None

    def __init__(self, max_workers: int = 1):
        # The pool is kept for the lifetime of this object; see shutdown().
        # pylint: disable=consider-using-with
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def add_task(self, fn, *args, **kwargs):
        """Schedule ``fn(*args, **kwargs)`` and return its concurrent.futures.Future.

        The Future is returned (instead of being dropped) so callers can wait
        for the result or inspect an exception that escaped ``fn``; an
        unobserved Future swallows such exceptions silently.
        """
        return self.executor.submit(fn, *args, **kwargs)

    def shutdown(self, wait: bool = True) -> None:
        """Stop accepting tasks and release the worker thread(s).

        When ``wait`` is true, block until already-submitted tasks finish.
        """
        self.executor.shutdown(wait=wait)
class TaskFailedException(Exception):
    """Placeholder result recorded for a background task that raised.

    The exception the task raised is kept in ``caused_by`` so that it can be
    re-raised later to whoever asks for the task's result.
    """

    caused_by: Exception  # the exception originally raised by the task

    def __init__(self, caused_by, *args, **kwargs):
        self.caused_by = caused_by
        super().__init__(*args, **kwargs)
class NoOperationForTokenException(Exception):
    """Error type for a task token with no associated operation.

    NOTE(review): nothing in the surrounding code raises this; presumably it
    was meant for unknown task tokens. Confirm intended use before relying on it.
    """
class RequestResultsForInProgressOperation(Exception):
    """Raised when a task's result is requested before the task has finished."""
class MissingEntryRequestedException(Exception):
    """Raised by SimpleStorage for a key that is absent: never stored, or already expired and evicted."""
class SimpleTasksQueue:
    """Run callables in the background and fetch their results later by token.

    ``add`` returns an integer token immediately and hands the callable to a
    BackgroundExecutor. The token's slot in SimpleStorage holds a "pending"
    marker, then a "running" marker, and finally either the callable's return
    value or a TaskFailedException wrapping what it raised.
    """

    kv_storage: SimpleStorage = None
    bg_tasks: BackgroundExecutor = None

    def __init__(self):
        # Markers stored in place of a result while a task is unfinished.
        # They are recognised by identity (see _is_unfinished), so a task that
        # happens to return an equal string still counts as finished.
        self._pending_value = str(id(self)) + '_pending'
        self._running_value = str(id(self)) + '_running'
        self.kv_storage = SimpleStorage()
        self.bg_tasks = BackgroundExecutor()

    def add(self, func, *args, **kwargs):
        """Schedule ``func(*args, **kwargs)`` in the background and return its token."""
        task_token = self.kv_storage.set(self._pending_value)
        wait_task = self._wrap_task(task_token, func, *args, **kwargs)
        self.bg_tasks.add_task(wait_task)
        return task_token

    def get(self, task_token) -> Any:
        """Return the result of the finished task identified by ``task_token``.

        Raises:
            RequestResultsForInProgressOperation: the task is still pending or running.
            MissingEntryRequestedException: unknown or expired token (from SimpleStorage).
            Exception: whatever the task itself raised, chained from its
                TaskFailedException.
        """
        # Read the slot once: checking readiness and then reading again costs
        # two lookups and lets TTL eviction slip in between them.
        result = self.kv_storage.get(task_token)
        if self._is_unfinished(result):
            raise RequestResultsForInProgressOperation()
        if isinstance(result, TaskFailedException):
            raise result.caused_by from result
        return result

    def is_ready(self, task_token):
        """Return True once the task has finished, successfully or not.

        Raises:
            MissingEntryRequestedException: unknown or expired token (from SimpleStorage).
        """
        return not self._is_unfinished(self.kv_storage.get(task_token))

    def _is_unfinished(self, value) -> bool:
        """Return True if ``value`` is this queue's pending or running marker."""
        # Identity rather than ``value in [...]``: list membership falls back
        # to ``==``, which raises for results whose ``==`` is elementwise
        # (e.g. pandas DataFrames) because their truth value is ambiguous.
        return value is self._pending_value or value is self._running_value

    def _wrap_task(self, task_token, fn, *args, **kwargs):
        """Build the zero-argument callable the executor actually runs.

        It marks the slot as running, calls ``fn(*args, **kwargs)``, and stores
        either the return value or a TaskFailedException wrapping the raised
        exception. If the token has expired by then, the final update() raises
        MissingEntryRequestedException inside the worker; nothing here handles it.
        """
        def update_result_when_finished():
            try:
                logging.debug("Task %i started", task_token)
                self.kv_storage.update(task_token, self._running_value)
                result = fn(*args, **kwargs)
                logging.debug("Task %i finished", task_token)
            except Exception as e:  # pylint: disable=broad-except
                # Deliberately broad: any failure must be recorded so that
                # get() can re-raise it to the caller.
                logging.exception("Task %i failed", task_token, exc_info=e)
                result = TaskFailedException(caused_by=e)
            logging.debug("Saving result for %i", task_token)
            self.kv_storage.update(task_token, result)
        return update_result_when_finished
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment