Skip to content

Instantly share code, notes, and snippets.

@loRes228
Last active January 3, 2024 19:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save loRes228/c644092a170e0068d50ec4334247aa95 to your computer and use it in GitHub Desktop.
Save loRes228/c644092a170e0068d50ec4334247aa95 to your computer and use it in GitHub Desktop.
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from aiogram import BaseMiddleware
from cachetools import TTLCache

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable
    from typing import Any

    from aiogram import Dispatcher
    from aiogram.types import TelegramObject, User

    from src.config import ThrottlingConfig
class ThrottledError(Exception):
    """Signal that a user has hit the throttling limit.

    Raised by ``Throttler.warn`` the first time it is called for a user
    whose record has not been marked as warned yet.
    """
@dataclass(kw_only=True, slots=True)
class ThrottlingData:
rate: int = 0
warned: bool = False
class Throttler:
    """Count events per user and decide when a user is throttled.

    Records are stored in a ``TTLCache`` keyed by user id, so a record
    disappears once the cache's time-to-live for it has elapsed.

    Every public method is safe to call on its own: a record that was
    never created, or that has already expired, is replaced by a fresh
    ``ThrottlingData`` instead of surfacing as ``KeyError``.
    """

    cache: TTLCache[int, ThrottlingData]
    max_rate: int

    __slots__ = ("cache", "max_rate")

    def __init__(self, *, time_period: float, max_rate: int) -> None:
        """Create a throttler.

        Args:
            time_period: Time-to-live handed to the ``TTLCache`` for each
                record (the middleware in this file passes seconds).
            max_rate: Number of counted events at which a user is
                reported as throttled.
        """
        self.cache = TTLCache(maxsize=10_000, ttl=time_period)
        self.max_rate = max_rate

    def _entry(self, user_id: int) -> ThrottlingData:
        """Return the user's record, creating a fresh one if it is missing."""
        try:
            return self.cache[user_id]
        except KeyError:
            # Never set up, or expired since the last call: start over.
            data = self.cache[user_id] = ThrottlingData()
            return data

    def setup(self, *, user_id: int) -> None:
        """Ensure a record exists for ``user_id`` without modifying an existing one."""
        self._entry(user_id)

    def add_rate(self, *, user_id: int) -> None:
        """Count one more event for ``user_id``."""
        self._entry(user_id).rate += 1

    def is_throttled(self, *, user_id: int) -> bool:
        """Return True when the user's counter has reached ``max_rate``."""
        return self._entry(user_id).rate >= self.max_rate

    def warn(self, *, user_id: int) -> None:
        """Extend the user's throttled period and warn them once.

        Raises:
            ThrottledError: The first time this is called for a record
                that has not been warned yet. Later calls return ``None``.
        """
        data = self._entry(user_id)
        # Writing the same record back re-inserts it into the cache, which
        # resets the throttled period for a user who keeps sending events.
        self.cache[user_id] = data
        if data.warned:
            return
        data.warned = True
        raise ThrottledError
class ThrottlingMiddleware(BaseMiddleware):
    """Middleware that rate-limits incoming events per user via ``Throttler``."""

    throttler: Throttler

    __slots__ = ("throttler",)

    def __init__(self, *, config: ThrottlingConfig) -> None:
        """Build the underlying throttler from ``config``.

        Reads ``config.time_period`` (converted with ``total_seconds()``)
        and ``config.max_rate``.
        """
        self.throttler = Throttler(
            time_period=config.time_period.total_seconds(),
            max_rate=config.max_rate,
        )

    def setup(self, *, dispatcher: Dispatcher) -> None:
        """Register this instance as an outer middleware on ``dispatcher.update``."""
        dispatcher.update.outer_middleware.register(self)

    async def __call__(
        self,
        handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
        event: TelegramObject,
        data: dict[str, Any],
    ) -> Any:
        """Pass the event to ``handler`` unless its user is throttled.

        Returns:
            The handler's result, or ``None`` when the event is dropped.

        Raises:
            ThrottledError: Propagated from ``Throttler.warn`` the first
                time a throttled user is seen.
        """
        event_user: User | None = data.get("event_from_user")
        if not event_user:
            # NOTE(review): events without a user are dropped here rather
            # than passed on to the handler — confirm this is intended.
            return None

        user_id = event_user.id
        self.throttler.setup(user_id=user_id)
        if self.throttler.is_throttled(user_id=user_id):
            # warn() either raises ThrottledError or returns None, so the
            # handler is never reached for a throttled user.
            return self.throttler.warn(user_id=user_id)

        self.throttler.add_rate(user_id=user_id)
        return await handler(event, data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment