Skip to content

Instantly share code, notes, and snippets.

@gabbhack
Last active October 9, 2023 17:55
Show Gist options
  • Save gabbhack/0609c9813d3287fad3d3f07c4514a9eb to your computer and use it in GitHub Desktop.
Save gabbhack/0609c9813d3287fad3d3f07c4514a9eb to your computer and use it in GitHub Desktop.
import asyncio
from aiogram import Bot, Dispatcher, executor, types
from aiogram.contrib.fsm_storage.redis import RedisStorage2
from aiogram.dispatcher import DEFAULT_RATE_LIMIT
from aiogram.dispatcher.handler import CancelHandler, current_handler
from aiogram.dispatcher.middlewares import BaseMiddleware
from aiogram.utils.exceptions import Throttled
# Placeholder: replace with the real bot token before running.
TOKEN = 'BOT TOKEN HERE'
# One event loop is obtained here and handed to both the Bot below and
# executor.start_polling() in the __main__ guard.
loop = asyncio.get_event_loop()
# In this example Redis storage is used
# NOTE(review): db=5 presumably selects Redis logical database 5, with the
# other connection settings left at the library defaults - confirm.
storage = RedisStorage2(db=5)
bot = Bot(token=TOKEN, loop=loop)
# The dispatcher is given the storage; ThrottlingMiddleware relies on
# Dispatcher.throttle() / Dispatcher.check_key() of this dispatcher.
dp = Dispatcher(bot, storage=storage)
def rate_limit(limit: float, key=None):
    """
    Decorator for configuring rate limit and key in different functions.

    The values are stored as attributes on the decorated function
    (``throttling_rate_limit`` and, optionally, ``throttling_key``), where
    ``ThrottlingMiddleware`` looks them up with ``getattr``.

    :param limit: rate limit for the handler, in seconds (whole or fractional)
    :param key: throttling key; when omitted or falsy no key attribute is set
        and the middleware derives one from the handler name
    :return: a decorator that returns the very same function object it was
        given, with the attributes attached
    """
    def decorator(func):
        func.throttling_rate_limit = limit
        if key:
            func.throttling_key = key
        return func

    return decorator
class ThrottlingMiddleware(BaseMiddleware):
    """
    Simple anti-flood middleware.

    Before a message or callback query is processed, the middleware calls
    ``Dispatcher.throttle`` with a key and rate taken from the current
    handler. When that call raises ``Throttled``, the user is notified and
    the handler is cancelled by raising ``CancelHandler``.

    :param limit: rate limit used when the handler has no
        ``throttling_rate_limit`` attribute
    :param key_prefix: prefix of the automatically generated throttling keys
    """

    def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
        self.rate_limit = limit
        self.prefix = key_prefix
        super().__init__()

    def _limit_for(self, handler):
        """Return the rate limit configured on *handler*, or the middleware default."""
        if handler:
            return getattr(handler, 'throttling_rate_limit', self.rate_limit)
        return self.rate_limit

    def _key_for(self, handler):
        """Return the throttling key configured on *handler*, or a generated one."""
        if handler:
            return getattr(handler, 'throttling_key', f"{self.prefix}_{handler.__name__}")
        # Without a current handler, messages and callback queries both use
        # this same fallback key.
        return f"{self.prefix}_message"

    async def on_process_message(self, message: types.Message, data: dict):
        """
        This handler is called when dispatcher receives a message

        :param message: incoming message
        :param data: middleware data dict (not used here)
        :raises CancelHandler: when the throttle limit is exceeded
        """
        # Get current handler
        handler = current_handler.get()
        # Get dispatcher from context
        dispatcher = Dispatcher.get_current()
        # If handler was configured, get rate limit and key from handler
        limit = self._limit_for(handler)
        key = self._key_for(handler)
        # Use Dispatcher.throttle method.
        try:
            await dispatcher.throttle(key, rate=limit)
        except Throttled as t:
            # Execute action
            await self.message_throttled(message, t)
            # Cancel current handler
            raise CancelHandler()

    async def on_process_callback_query(self, query: types.CallbackQuery, data: dict):
        """
        Same as ``on_process_message``, but for callback queries.

        :param query: incoming callback query
        :param data: middleware data dict (not used here)
        :raises CancelHandler: when the throttle limit is exceeded
        """
        handler = current_handler.get()
        dispatcher = Dispatcher.get_current()
        limit = self._limit_for(handler)
        key = self._key_for(handler)
        try:
            await dispatcher.throttle(key, rate=limit)
        except Throttled as t:
            # Execute action
            await self.callback_query_throttled(query, t)
            # Cancel current handler
            raise CancelHandler()

    async def message_throttled(self, message: types.Message, throttled: Throttled):
        """
        Notify user only on first exceed and notify about unlocking only on last exceed

        :param message: the message that hit the limit
        :param throttled: the exception raised by ``Dispatcher.throttle``
        """
        handler = current_handler.get()
        dispatcher = Dispatcher.get_current()
        key = self._key_for(handler)

        # Calculate how much time is left till the block ends
        delta = throttled.rate - throttled.delta

        # Prevent flooding: warn only while the exceed counter is small
        if throttled.exceeded_count <= 2:
            await message.reply('Too many requests! ')

        # Wait for the remaining block time before checking again.
        await asyncio.sleep(delta)

        # Check lock status
        thr = await dispatcher.check_key(key)

        # If current message is not last with current key - do not send message
        if thr.exceeded_count == throttled.exceeded_count:
            await message.reply('Unlocked.')

    async def callback_query_throttled(self, query: types.CallbackQuery, throttled: Throttled):
        """
        Show a "too many requests" alert while the exceed counter is small.

        Unlike ``message_throttled`` there is no sleep and no "Unlocked."
        follow-up here.

        :param query: the callback query that hit the limit
        :param throttled: the exception raised by ``Dispatcher.throttle``
        """
        # Prevent flooding
        if throttled.exceeded_count <= 2:
            await query.answer('Too many requests! ', show_alert=True)
@dp.message_handler(commands=['start'])
@rate_limit(5, 'start')  # optional: per-handler throttling settings; without it the middleware defaults apply
async def cmd_test(message: types.Message):
    """Reply to the ``start`` command; throttled under the 'start' key with a limit of 5."""
    # You can use this command every 5 seconds
    confirmation = 'Test passed! You can use this command every 5 seconds.'
    await message.reply(confirmation)
if __name__ == '__main__':
    # Register the anti-flood middleware (default limit and key prefix)
    # on the dispatcher before any update is processed.
    throttling = ThrottlingMiddleware()
    dp.middleware.setup(throttling)

    # Run the bot with long-polling on the module-level event loop.
    executor.start_polling(dp, loop=loop)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment