Skip to content

Instantly share code, notes, and snippets.

@gustabot42
Last active June 9, 2024 05:30
Show Gist options
  • Save gustabot42/9301b81b2eafd4de065bc5b3c3b72f94 to your computer and use it in GitHub Desktop.
Save gustabot42/9301b81b2eafd4de065bc5b3c3b72f94 to your computer and use it in GitHub Desktop.
from __future__ import annotations
import asyncio
import time
from collections import deque
from dataclasses import dataclass
from dataclasses import field
from typing import ClassVar
from typing import Generic
from typing import TypeVar
from pydantic import BaseModel
from result import Err
from result import Ok
from hermes.utils import coalesce
from hermes.utils.asyncio import EventResult
from hermes.utils.asyncio import delay_awaitable
from hermes.utils.logging import logging
# Type parameter for the objects an ActionBulk accepts; bounded to pydantic models.
ActionType = TypeVar("ActionType", bound=BaseModel)
@dataclass
class Buffer:
    """Single-use accumulator for one batch of values.

    Values are appended until ``expend`` is called. ``expend`` hands the
    collected values to the caller and permanently exhausts the buffer.
    After that, ``append`` silently ignores new values and ``len()`` reports 0.
    Every value appended to this buffer shares the same ``event_result``.

    Note: because ``__len__`` is defined, an empty or exhausted buffer is
    falsy in a boolean context.
    """

    # Pending values; ``None`` once the buffer has been expended.
    _values: deque | None = field(default_factory=deque)
    # Completion handle shared by every value appended to this buffer.
    event_result: EventResult = field(default_factory=EventResult, init=False)

    def __len__(self) -> int:
        """Return the number of pending values (0 once exhausted)."""
        return len(self._values) if self._values is not None else 0

    def append(self, value) -> None:
        """Add *value* to the batch; silently ignored once exhausted."""
        if self._values is not None:
            self._values.append(value)

    def expend(self) -> deque | None:
        """Take ownership of the pending values and exhaust the buffer.

        Returns:
            The internal deque of values on the first call, ``None`` on
            every later call.
        """
        if self._values is None:
            return None
        values = self._values
        self._values = None
        return values

    def is_exhausted(self) -> bool:
        """Return True once ``expend`` has been called."""
        return self._values is None
@dataclass
class ActionBulk(Generic[ActionType]):
    """Batch ``register``-ed objects and process them with ``perform_action``.

    Subclasses set the class-level constants and implement ``perform_action``.
    Each registered object is appended to the current ``Buffer``. A new buffer
    is scheduled to flush ``ACTION_DELAY_SEC`` after it is created, and it is
    flushed immediately once it holds ``BUFFER_SIZE`` objects.

    All callers registered into the same buffer receive that buffer's shared
    ``EventResult``. It is set to ``Ok(None)`` when ``perform_action`` succeeds
    and to ``Err(str(exc))`` when it raises.
    """

    # A buffer holding this many objects is flushed immediately.
    BUFFER_SIZE: ClassVar[int]
    # Default delay before a new buffer is flushed; also the threshold above
    # which a slow ``perform_action`` run is logged.
    ACTION_DELAY_SEC: ClassVar[float]
    # Declared for subclasses; not read anywhere in this class.
    AWAIT_TIMEOUT_SEC: ClassVar[float]

    _buffer: Buffer | None = None
    # Strong references to in-flight tasks. The event loop only keeps weak
    # references, so without this set a task could be garbage-collected
    # mid-execution.
    _background_tasks: set = field(default_factory=set)

    def _create_task(self, buffer: Buffer, delay: float | None = None) -> None:
        """Schedule ``self.action(buffer)`` to run after *delay* seconds.

        ``coalesce`` presumably substitutes ``ACTION_DELAY_SEC`` when *delay*
        is ``None``. ``delay_awaitable`` is assumed to sleep for *delay* and
        then await ``self.action(buffer)``; confirm against hermes.utils.
        """
        delay = coalesce(delay, self.ACTION_DELAY_SEC)
        task = asyncio.create_task(delay_awaitable(delay, self.action, buffer))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def register(self, obj: ActionType) -> EventResult:
        """Add *obj* to the current batch and return the batch's result handle.

        A fresh buffer, with a delayed flush, is started when there is no
        current buffer or the previous one has already been expended. As soon
        as the buffer reaches ``BUFFER_SIZE`` it is flushed immediately and
        detached, so the next registration starts a new batch.
        """
        if self._buffer is None or self._buffer.is_exhausted():
            self._buffer = Buffer()
            self._create_task(self._buffer)
        buffer = self._buffer
        # No await between the exhaustion check and this append, so the value
        # cannot land in a buffer that was expended in the meantime.
        buffer.append(obj)
        if len(buffer) >= self.BUFFER_SIZE:
            # Full: flush now rather than waiting for the delayed task. That
            # task will later find the buffer exhausted and return early.
            self._create_task(buffer, delay=0)
            self._buffer = None
        return buffer.event_result

    async def action(self, buffer: Buffer) -> None:
        """Process *buffer* once and publish the outcome on its event result.

        A buffer that was already expended (e.g. by an earlier immediate
        flush) is skipped. An ``Exception`` from ``perform_action`` is logged
        and published as ``Err(str(exc))``; success is published as
        ``Ok(None)``. If the task is cancelled, no result is published and the
        cancellation propagates. Runs slower than ``ACTION_DELAY_SEC`` are
        logged.
        """
        values = buffer.expend()
        if values is None:
            return
        # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
        tic = time.monotonic()
        try:
            await self.perform_action(values)
        except Exception as e:
            error_msg = str(e)
            msg = f"{self.__class__} action error: {error_msg}"
            logging.error(msg)
            buffer.event_result.set(Err(error_msg))
        else:
            # Only a completed run is reported as success (never an error or
            # a cancellation).
            buffer.event_result.set(Ok(None))
        finally:
            elapsed = time.monotonic() - tic
            if elapsed > self.ACTION_DELAY_SEC:
                msg = f"{self.__class__} action delay: {elapsed}s"
                logging.error(msg)

    async def wait(self) -> None:
        """Wait for the background tasks scheduled so far.

        This takes a snapshot of the task set; tasks created while waiting
        are not awaited.
        """
        await asyncio.gather(*self._background_tasks)

    @staticmethod
    async def perform_action(objs: deque[ActionType]) -> None:
        """Process one batch of registered objects; subclasses must override.

        Ensure the action is idempotent to handle multiple invocations for the same objects,
        as it can't be cancelled and may be repeated due to latency issues.
        """
        raise NotImplementedError
import asyncio
import pytest
from pydantic import BaseModel
from result import Err
from result import Ok
from hermes.db.action_bulk import ActionBulk
class ActionModel(BaseModel):
    """Minimal pydantic payload registered with the test bulks below."""

    value: int
class PassBulk(ActionBulk[ActionModel]):
    """Bulk whose action always succeeds without doing any work."""

    ACTION_DELAY_SEC = 0.1
    AWAIT_TIMEOUT_SEC = 0.2
    BUFFER_SIZE = 2

    @staticmethod
    async def perform_action(objs: list[ActionModel]) -> None:
        """Accept the batch and do nothing."""
        return None
class ErrorBulk(ActionBulk[ActionModel]):
    """Bulk whose action always fails with a ValueError."""

    ACTION_DELAY_SEC = 0.1
    AWAIT_TIMEOUT_SEC = 0.2
    BUFFER_SIZE = 2

    # NOTE: a `# ruff: noqa: <code>` line suppresses the rule for the whole
    # file regardless of where it appears; it is not scoped to this class.
    # ruff: noqa: ARG004
    @staticmethod
    async def perform_action(objs: list[ActionModel]) -> None:
        """Fail unconditionally so the error path can be asserted."""
        error_text = "error message"
        raise ValueError(error_text)
class TimeoutBulk(ActionBulk[ActionModel]):
    """Bulk whose action outlasts the waiting timeout used in the tests."""

    ACTION_DELAY_SEC = 0.1
    AWAIT_TIMEOUT_SEC = 0.2
    BUFFER_SIZE = 2

    @staticmethod
    async def perform_action(objs: list[ActionModel]) -> None:
        """Sleep for 0.5 s, longer than the 0.3 s ``wait_for`` timeout in the test."""
        await asyncio.sleep(0.5)
@pytest.mark.asyncio()
async def test_pass_register_and_wait_for():
    """A successful action resolves the shared result to Ok(None)."""
    pass_bulk = PassBulk()
    result_event = await pass_bulk.register(ActionModel(value=1))
    # register does not yield to the loop, so the delayed flush has not run yet.
    assert result_event.is_set() is False
    outcome = await result_event.wait_for()
    assert outcome == Ok(None)
    assert result_event.is_set() is True
@pytest.mark.asyncio()
async def test_error_register_and_wait_for():
    """A failing action resolves the shared result to Err(<exception text>)."""
    error_bulk = ErrorBulk()
    result_event = await error_bulk.register(ActionModel(value=1))
    # register does not yield to the loop, so the delayed flush has not run yet.
    assert result_event.is_set() is False
    outcome = await result_event.wait_for()
    assert outcome == Err("error message")
    assert result_event.is_set() is True
@pytest.mark.asyncio()
async def test_timeout_register_and_wait_for():
    """wait_for gives up with Err("TimeoutError") while the action still runs."""
    bulk = TimeoutBulk()
    event_result = await bulk.register(ActionModel(value=1))
    assert event_result.is_set() is False
    assert await event_result.wait_for(0.3) == Err("TimeoutError")
    assert event_result.is_set() is False
    # The 0.5 s action is still in flight; drain it so the background task
    # does not outlive this test's event loop ("Task was destroyed but it is
    # pending!"). Once it finishes, the result is published.
    await bulk.wait()
    assert event_result.is_set() is True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment