Skip to content

Instantly share code, notes, and snippets.

@lucassimon
Created September 29, 2019 21:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lucassimon/8f357d1812c61fab7a97119876ce82f2 to your computer and use it in GitHub Desktop.
Save lucassimon/8f357d1812c61fab7a97119876ce82f2 to your computer and use it in GitHub Desktop.
import hashlib
import memcache
import traceback
from flask import request
from functools import wraps
from wakatime_website import app
from werkzeug.contrib.cache import MemcachedCache
# python-memcached's Client takes the server list as a required first
# argument; calling it with no arguments raises TypeError at import time.
# NOTE(review): the address below is memcached's conventional local default;
# replace it with the deployment's real server list.
mc = memcache.Client(['127.0.0.1:11211'])
# Cache facade used by the `cached` decorator; wraps the raw client above.
cache = MemcachedCache(mc)
def cached(fn=None, unique_per_user=True, minutes=30):
    """Cache a Flask route/view in memcached.

    Usable bare (``@cached``) or with arguments (``@cached(minutes=5)``).
    Only GET requests are cached; any other method calls the view directly.
    The cache key is an MD5 digest built from the user id, the request
    method and ``request.full_path``.

    Args:
        fn: The view when used as a bare decorator, otherwise ``None``.
        unique_per_user: When true (the default) an authenticated user's id
            is part of the cache key, so every user gets a separate entry.
            When false the user is left out and all users share one entry
            per URL.
        minutes: Lifetime of a cached response, in minutes.

    Returns:
        The wrapped view, or a decorator when ``fn`` is ``None``.

    Raises:
        Exception: If ``minutes`` is not an integer.
    """
    if not isinstance(minutes, int):
        raise Exception('Minutes must be an integer number.')

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            if request.method != 'GET':
                return func(*args, **kwargs)

            user_id = None
            if unique_per_user and app.current_user.is_authenticated:
                user_id = app.current_user.id

            # A unicode literal is used here because the `u()` helper the
            # original code called is neither defined nor imported.
            key = u'{user}-{method}-{path}'.format(
                user=user_id,
                method=request.method,
                path=request.full_path,
            )
            digest = hashlib.md5(key.encode('utf8')).hexdigest()
            cache_key = '{prefix}-{hashed}'.format(
                prefix='flask-request',
                hashed=digest,
            )

            # Cache failures are logged and treated as a miss so that a cache
            # outage never takes the view down with it.
            try:
                resp = cache.get(cache_key)
            except Exception:
                app.logger.error(traceback.format_exc())
                resp = None
            # Compare against None so a cached falsy response (e.g. an empty
            # body) still counts as a hit.
            if resp is not None:
                return resp

            resp = func(*args, **kwargs)
            try:
                cache.set(cache_key, resp, timeout=minutes * 60)
            except Exception:
                app.logger.error(traceback.format_exc())
            return resp
        return inner

    return wrapper(fn) if fn else wrapper
# Rate limit API requests by IP or current user
import redis
import traceback
from flask import abort, request
from functools import wraps
from wakatime_website import app
# Module-level Redis client used by increment_counter() and get_count() for
# the rate-limit counters. Only decode_responses is overridden; every other
# connection setting is left at the redis-py default.
r = redis.Redis(decode_responses=True)
def rate_limited(fn=None, limit=20, methods=None, ip=True, user=True, minutes=1):
    """Limit requests to the decorated endpoint to `limit` per `minutes`.

    Usable bare (``@rate_limited``) or with arguments. Each request bumps a
    counter per client IP and, for authenticated users, per user id; once a
    counter exceeds ``limit`` the request is aborted with HTTP 429.

    Args:
        fn: The view when used as a bare decorator, otherwise ``None``.
        limit: Maximum number of counted requests per window.
        methods: HTTP methods to count. Empty or ``None`` counts every method.
        ip: Count requests per remote address.
        user: Count requests per authenticated user.
        minutes: Expiry passed to the counters, in minutes.

    Returns:
        The wrapped view, or a decorator when ``fn`` is ``None``.

    Raises:
        Exception: If ``limit`` is not an integer or is smaller than one.
    """
    if not isinstance(limit, int):
        raise Exception('Limit must be an integer number.')
    if limit < 1:
        raise Exception('Limit must be greater than zero.')
    # `None` replaces the former mutable `[]` default; both mean "all methods".
    if methods is None:
        methods = []

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            def count_hit(scope):
                # Record this request for the scope, then reject it if the
                # scope's counter is now over the limit.
                increment_counter(type=scope, for_methods=methods,
                                  minutes=minutes)
                if get_count(type=scope, for_methods=methods) > limit:
                    abort(429)

            if not methods or request.method in methods:
                if ip:
                    count_hit('ip')
                if user and app.current_user.is_authenticated:
                    count_hit('user')
            return func(*args, **kwargs)
        return inner

    return wrapper(fn) if fn else wrapper
def get_counter_key(type=None, for_only_this_route=True, for_methods=None):
    """Build the Redis key of a rate-limit counter for the current request.

    Args:
        type: ``'ip'`` to key on ``request.remote_addr`` or ``'user'`` to key
            on the current user's id (``None`` when not authenticated).
        for_only_this_route: Append ``request.endpoint`` to the key so each
            route has its own counter.
        for_methods: List of HTTP methods included in the key; any non-list
            value is treated as an empty list.

    Returns:
        The key, formatted as ``{type}-{methods}-{key}{route}``.

    Raises:
        Exception: If ``type`` is neither ``'ip'`` nor ``'user'``.
    """
    if not isinstance(for_methods, list):
        for_methods = []

    if type == 'ip':
        key = request.remote_addr
    elif type == 'user':
        key = app.current_user.id if app.current_user.is_authenticated else None
    else:
        raise Exception('Unknown rate limit type: {0}'.format(type))

    route = request.endpoint if for_only_this_route else ''

    # A unicode literal is used here because the `u()` helper the original
    # code called is neither defined nor imported.
    return u'{type}-{methods}-{key}{route}'.format(
        type=type,
        key=key,
        methods=','.join(for_methods),
        route=route,
    )
def increment_counter(type=None, for_only_this_route=True, for_methods=None,
                      minutes=1):
    """Increment the rate-limit counter for the current request.

    Redis errors are logged and swallowed so that a Redis outage does not
    break the request.

    Args:
        type: ``'ip'`` or ``'user'``; selects which counter to bump.
        for_only_this_route: Passed through to ``get_counter_key``.
        for_methods: Passed through to ``get_counter_key``.
        minutes: Length of the counting window, in minutes.

    Raises:
        Exception: If ``type`` is neither ``'ip'`` nor ``'user'``.
    """
    if type not in ('ip', 'user'):
        raise Exception('Type must be ip or user.')
    key = get_counter_key(type=type, for_only_this_route=for_only_this_route,
                          for_methods=for_methods)
    try:
        hits = r.incr(key)
        # Start the window only when the counter is created (or has somehow
        # lost its expiry). Refreshing the expiry on every hit would keep the
        # key alive for as long as requests keep arriving, so a slow but
        # steady client would eventually exceed the limit and never recover.
        if hits == 1 or r.ttl(key) < 0:
            r.expire(key, time=60 * minutes)
    except Exception:
        app.logger.error(traceback.format_exc())
def get_count(type=None, for_only_this_route=True, for_methods=None):
    """Return the current value of a rate-limit counter.

    Args:
        type: ``'ip'`` or ``'user'``; selects which counter to read.
        for_only_this_route: Passed through to ``get_counter_key``.
        for_methods: Passed through to ``get_counter_key``.

    Returns:
        The counter as an int; ``0`` when the key is missing or when reading
        or parsing it fails (the failure is logged).

    Raises:
        Exception: Propagated from ``get_counter_key`` for an unknown ``type``.
    """
    key = get_counter_key(type=type, for_only_this_route=for_only_this_route,
                          for_methods=for_methods)
    try:
        return int(r.get(key) or 0)
    except Exception:
        # Fail open: an unreadable counter must not block the request.
        app.logger.error(traceback.format_exc())
        return 0
# Prevent brute forcing secrets or tokens
import redis
import traceback
from flask import request
from functools import wraps
from wakatime_website import app
from werkzeug.exceptions import NotFound
# Module-level Redis client holding the `bruteforce-...` counters read and
# written by the `protected` decorator. Only decode_responses is overridden;
# every other connection setting is left at the redis-py default.
r = redis.Redis(decode_responses=True)
def protected(fn=None, limit=10, minutes=60):
    """Ban an IP after it requests a protected resource too many times.

    A counter per endpoint and remote address is bumped each time the view
    raises ``NotFound`` or returns a tuple whose second item is ``404``.
    Once more than ``limit`` such misses are recorded, further requests get
    a plain 404 without calling the view. Every miss and every blocked
    request restarts the counter's expiry. This prevents enumerating secrets
    or tokens in urls or query arguments.

    Args:
        fn: The view when used as a bare decorator, otherwise ``None``.
        limit: Number of recorded misses tolerated before blocking.
        minutes: Expiry of the counter, in minutes.

    Returns:
        The wrapped view, or a decorator when ``fn`` is ``None``.

    Raises:
        Exception: If ``limit`` or ``minutes`` is not an integer.
    """
    if not isinstance(limit, int):
        raise Exception('Limit must be an integer number.')
    if not isinstance(minutes, int):
        raise Exception('Minutes must be an integer number.')

    ban_seconds = 60 * minutes

    def record_miss(key):
        # Best effort: bump the counter and restart its expiry. A Redis
        # failure is logged rather than raised so it cannot break the request.
        try:
            r.incr(key)
            r.expire(key, time=ban_seconds)
        except Exception:
            app.logger.error(traceback.format_exc())

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            # A unicode literal is used here because the `u()` helper the
            # original code called is neither defined nor imported.
            key = u'bruteforce-{}-{}'.format(request.endpoint, request.remote_addr)

            try:
                count = int(r.get(key) or 0)
            except Exception:
                # Fail open: if the counter cannot be read, serve the request.
                app.logger.error(traceback.format_exc())
                count = 0

            if count > limit:
                # Recording the hit can no longer fall through to the view if
                # Redis fails, so a blocked request always stays blocked.
                record_miss(key)
                app.logger.info('Request blocked by protected decorator.')
                return '404', 404

            try:
                result = func(*args, **kwargs)
            except NotFound:
                record_miss(key)
                raise

            if isinstance(result, tuple) and len(result) > 1 and result[1] == 404:
                record_miss(key)
            return result
        return inner

    return wrapper(fn) if fn else wrapper
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment