# Originally shared as a GitHub gist by @Tobi-De
# (gist c059c878d6b0a51ce7ad207e7b6e4658), created November 28, 2023.
from typing import TYPE_CHECKING
import structlog
from .config import get_redis_url, get_client_tracking_enabled
from functools import wraps
import threading
import redis
import asyncio
if TYPE_CHECKING:
    # Only needed for the string annotations ("Broker", "Broker.listen") used
    # below; TYPE_CHECKING is False at runtime, so this import never executes.
    from .brokers import Broker

# Module-level structlog logger named "client_tracking".
logger = structlog.stdlib.get_logger("client_tracking")
class RedisCounterMap:
    """Map of channel name -> number of tracked clients, stored in a Redis hash.

    All counts live as fields of the single hash ``hash_key``, so they can be
    shared by every process that connects with the same Redis URL.
    """

    hash_key = "sse_relay_server:channels"

    # Decrement and delete-at-zero as ONE server-side step. Issuing HINCRBY and
    # HDEL as two separate round-trips leaves a window in which another process
    # can increment the field; the HDEL would then wipe out that live count.
    # Redis runs a Lua script atomically, which closes the window.
    _DECREMENT_SCRIPT = """
local count = redis.call('HINCRBY', KEYS[1], ARGV[1], -1)
if count <= 0 then
    redis.call('HDEL', KEYS[1], ARGV[1])
end
return count
"""

    def __init__(self, redis_url: str) -> None:
        """Create a client for ``redis_url`` and register the decrement script.

        Args:
            redis_url: connection URL passed to ``redis.Redis.from_url``.
        """
        self._redis = redis.Redis.from_url(redis_url)
        self._decrement = self._redis.register_script(self._DECREMENT_SCRIPT)

    def increment(self, channel: str) -> None:
        """Add one client to ``channel``, creating the field if needed."""
        self._redis.hincrby(self.hash_key, channel, 1)

    def decrement(self, channel: str) -> None:
        """Remove one client from ``channel``; delete the field at or below zero."""
        self._decrement(keys=[self.hash_key], args=[channel])

    def value(self) -> dict[str, int]:
        """Return a snapshot of all channel counts with ``str`` keys.

        Field names come back as ``bytes`` by default, or as ``str`` when the
        client was configured with ``decode_responses``; both are accepted.
        """
        return {
            (k.decode() if isinstance(k, bytes) else k): int(v)
            for k, v in self._redis.hgetall(self.hash_key).items()
        }

    def reset(self) -> None:
        """Delete the whole hash, dropping every channel count."""
        self._redis.delete(self.hash_key)
class DictCounterMap:
    """In-process map of channel name -> number of tracked clients.

    Every access to the underlying dict is serialized on ``lock``. Counts are
    held in this instance's own dict, so they are not shared between processes.
    """

    def __init__(self) -> None:
        self._map: dict[str, int] = {}
        # Guards every read and write of ``_map``.
        self.lock = threading.Lock()

    def increment(self, channel: str) -> None:
        """Add one client to ``channel``, creating the entry if needed."""
        with self.lock:
            self._map[channel] = self._map.get(channel, 0) + 1

    def decrement(self, channel: str) -> None:
        """Remove one client from ``channel``; drop the entry when it reaches zero.

        Decrementing an unknown channel is a no-op, so the map never stores
        zero or negative counts.
        """
        with self.lock:
            count = self._map.get(channel)
            if count is None:
                return
            if count <= 1:
                del self._map[channel]
            else:
                self._map[channel] = count - 1

    def value(self) -> dict[str, int]:
        """Return a snapshot copy of the current counts."""
        # Take the same lock as the writers so the snapshot is consistent;
        # this was the only method that read ``_map`` unguarded.
        with self.lock:
            return self._map.copy()

    def reset(self) -> None:
        """Forget all counts."""
        with self.lock:
            self._map = {}
def _get_counter_map():
    """Build the counter backend used by this module.

    Returns a ``RedisCounterMap`` when ``get_redis_url()`` yields a truthy URL,
    otherwise an in-process ``DictCounterMap``.
    """
    redis_url = get_redis_url()
    if not redis_url:
        return DictCounterMap()
    return RedisCounterMap(redis_url)
# Counter backend, chosen once at import time: Redis-backed when a Redis URL is
# configured, otherwise the in-process dict. Note it is built even when client
# tracking is disabled.
_counter = _get_counter_map()
# Read once at import time; count_clients, get_count_value and
# reset_count_value all consult this flag.
client_tracking_enabled = get_client_tracking_enabled()
def count_clients(func: "Broker.listen"):
    """Decorate a broker ``listen`` async generator to track live listeners.

    While the wrapped generator is being consumed, ``channel`` counts as one
    extra client in the module's counter map. When client tracking is disabled,
    ``func`` is returned unchanged.

    Args:
        func: async generator function invoked as ``func(instance, channel)``.

    Returns:
        An async generator function with the same call signature that re-yields
        every event produced by ``func``.
    """
    if not client_tracking_enabled:
        return func

    @wraps(func)
    async def wrapper(instance: "Broker", channel: str):
        logger.debug(f"Incrementing counter using {_counter.__class__.__name__}")
        # Increment outside the ``try`` so a failed increment is never
        # "undone" by the ``finally`` below.
        _counter.increment(channel)
        try:
            async for event in func(instance, channel):
                yield event
        finally:
            # Decrement on EVERY exit path: normal exhaustion, an exception
            # from ``func``, task cancellation (CancelledError) and
            # aclose()/garbage collection (GeneratorExit). Handling only
            # CancelledError leaked a count for all the other paths.
            _counter.decrement(channel)
            logger.debug(f"Decrementing counter using {_counter.__class__.__name__}")

    return wrapper
def get_count_value():
    """Return the current channel -> client-count mapping.

    When client tracking is disabled, an empty dict is returned instead of
    querying the counter backend.
    """
    return _counter.value() if client_tracking_enabled else {}
def reset_count_value():
    """Clear every tracked client count.

    With tracking disabled this returns an empty dict without touching the
    backend. Otherwise it returns whatever the backend's ``reset()`` returns,
    which is ``None`` for both counter classes in this module.
    """
    if client_tracking_enabled:
        return _counter.reset()
    return {}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment