Last active October 16, 2023 08:25
Flask server-side session with redis
license: gpl-3.0
This gist contains a snippet implementing Flask server-side sessions backed by Redis.



The session key stored in the cookie is built from the session ID and an expiry timestamp, and it changes on every request.

  • The session ID (prefixed with `s:`) is the Redis key; it stays alive until the user logs out or it expires. On every request, the Redis TTL and the cookie's expiry are refreshed from the same lifetime, so server and client state stay synchronized.
  • The expiry timestamp in the session key of an incoming cookie is validated against the expiry derived from the Redis TTL (with a 10-second tolerance), so a client cannot easily keep a stale session key valid.
  • When the session key in a request is judged invalid, the related session is abandoned immediately: its Redis key (by session ID) is left to expire, and the response deletes the cookie.

Easy to aggregate or filter by users, integrating with Flask-Login

The user ID is prepended to the session ID (as `user_id.`) for convenience, even though the user ID is usually stored inside the session as well.

You can easily

  • Filter by user ID to collect and expire all of a user's sessions, forcing a logout — for example, when a user reports that their account has been hijacked.
  • Aggregate by distinct user id to count the users.

If you need to extend this snippet, one option is to insert an organization identifier as a prefix of the current session ID, which lets you aggregate user counts by organization.

Not implemented

Encrypting the session key in the cookie

The current implementation exposes the session ID in plain text inside the session key stored in the cookie. An attacker who obtains it can send a request with a tampered key to disrupt the session and force the user to log out.

  • In practice, this would happen only if a user leaks the cookie from their current login.
  • The session ID has a known format (a UUID4 hex string), but it is expensive to guess or to mount a collision attack against.
import json
import re
import time
from datetime import timedelta, datetime
from uuid import uuid4
from flask.sessions import SessionMixin, SessionInterface
from redis import ReadOnlyError
from redis import Redis
from werkzeug.datastructures import CallbackDict
def utctimestamp_by_second(utc_date_time):
    """Return the POSIX timestamp, in whole seconds, of a UTC datetime.

    The input's tzinfo is *replaced* with UTC (not converted), so callers
    should pass naive UTC values such as those from ``datetime.utcnow()``.
    Fractional seconds are truncated by ``int()``.
    """
    # ``time.timezone`` is an int (the local offset from UTC in seconds), so
    # the previous ``time.timezone.utc`` raised AttributeError. The UTC tzinfo
    # object lives in the ``datetime`` module.
    from datetime import timezone

    return int(utc_date_time.replace(tzinfo=timezone.utc).timestamp())
class RedisSession(CallbackDict, SessionMixin):
    """Session data held in a dict that flags itself as modified on change.

    Args:
        initial: Initial session contents, or ``None`` for an empty session.
        sid: Session ID; the interface derives the Redis key from it.
        new: ``True`` when the session was created for this request rather
            than loaded from Redis.
    """

    def __init__(self, initial=None, sid=None, new=False):
        def on_update(session):
            # Callback passed to CallbackDict; marks the session as changed
            # so the interface knows there is something to persist.
            session.modified = True

        CallbackDict.__init__(self, initial, on_update)
        # The original ``self.sid = sid = new`` stored the *new* flag as the
        # session ID and never set ``self.new``.
        self.sid = sid
        self.new = new
        # Start unmodified regardless of the initial contents.
        self.modified = False
class RedisSessionInterface(SessionInterface):
    """Flask session interface that stores session data in Redis.

    The cookie value has the form ``"{sid}.{expiry_timestamp}"``. The session
    itself is stored as JSON under the Redis key ``"s:{sid}"`` with a TTL.
    On each request, the cookie's expiry timestamp is checked against the
    Redis TTL. On each response, both the Redis TTL and the cookie expiry
    are refreshed.

    NOTE(review): ``SESSION_COOKIE_NAME`` and ``SESSION_EXPIRY_MINUTES`` are
    module-level names not defined in this snippet; they must be provided
    elsewhere in the module.
    """

    def __init__(self, redis=None):
        """Use *redis* as the client, or a default ``Redis()`` if omitted."""
        self.redis = redis or Redis()

    def open_session(self, app, request):
        """Load the session named by the request cookie, or start a new one.

        A new, empty session is returned when the cookie is missing, has no
        expiry timestamp, names a Redis key with no value, or carries an
        expiry that does not match the Redis TTL.
        """
        session_key = request.cookies.get(SESSION_COOKIE_NAME)
        if not session_key:
            return self._new_session()
        sid, expiry_timestamp = self._extract_sid_and_expiry_timestamp_from(session_key)
        if not expiry_timestamp:
            return self._new_session()
        redis_value, redis_key_ttl = self._get_redis_value_and_ttl_of(sid)
        if not redis_value:
            return self._new_session()
        if self._expiry_timestamp_not_match(expiry_timestamp, redis_key_ttl):
            return self._new_session()
        data = json.loads(redis_value.decode())
        return RedisSession(data, sid=sid)

    def save_session(self, app, session, response):
        """Persist *session* to Redis and reissue the session cookie.

        A session without a user id is not stored: its Redis key and the
        cookie are deleted instead. Otherwise:
        - the session ID gets the user id prefix;
        - the data is written with a fresh TTL;
        - the cookie is reissued with a matching expiry timestamp.
        """
        user_id = session.get('user_id')
        if user_id is None:
            # NOTE(review): newer Flask-Login releases appear to store the id
            # under '_user_id' instead of 'user_id'; confirm for the version
            # in use.
            user_id = session.get('_user_id')
        # An emptied session (e.g. after session.clear()) has no user id
        # either, so this single check also covers the "modified empty" case.
        if not user_id:
            self._clean_redis_and_cookie(app, response, session)
            # Without this return, the session that was just cleaned was
            # immediately written back to Redis with a fresh cookie.
            return

        expiry_duration = self._get_expiry_duration(app, session)
        expiry_date = datetime.utcnow() + expiry_duration
        session.sid = self._inject_user_id_in_sid(session.sid, user_id)
        self._write_wrapper(
            self.redis.setex,
            self._redis_key(session.sid),
            # Passed by keyword: redis-py's legacy 2.x ``Redis.setex`` took
            # (name, value, time), while current releases take
            # (name, time, value).
            time=int(expiry_duration.total_seconds()),
            value=json.dumps(dict(session)),
        )
        response.set_cookie(
            SESSION_COOKIE_NAME,
            self._create_session_key(session.sid, expiry_date),
            expires=expiry_date,
            httponly=True,
            domain=self.get_cookie_domain(app),
            path=self.get_cookie_path(app),
            secure=self.get_cookie_secure(app),
            samesite=self.get_cookie_samesite(app),
        )

    @staticmethod
    def _new_session():
        """Create an empty, not-yet-stored session with a random hex ID."""
        return RedisSession(sid=uuid4().hex, new=True)

    @staticmethod
    def _get_expiry_duration(app, session):
        """Return the session lifetime as a ``timedelta``.

        Permanent sessions use ``app.permanent_session_lifetime``; other
        sessions last ``SESSION_EXPIRY_MINUTES`` minutes.
        """
        if session.permanent:
            return app.permanent_session_lifetime
        return timedelta(minutes=SESSION_EXPIRY_MINUTES)

    @staticmethod
    def _redis_key(sid):
        """Return the Redis key under which session *sid* is stored."""
        return 's:{}'.format(sid)

    def _write_wrapper(self, write_method, *args, **kwargs):
        """Call a Redis write command, retrying up to 3 times on ReadOnlyError.

        Returns the command's result. Re-raises ``ReadOnlyError`` if the
        final attempt still fails, so a lost write is never silently ignored.
        """
        attempts = 3
        for attempt in range(attempts):
            try:
                return write_method(*args, **kwargs)
            except ReadOnlyError:
                if attempt == attempts - 1:
                    raise
                # A read-only error usually means the pooled connection points
                # at a replica (e.g. after a failover); drop the pooled
                # connections so the retry reconnects.
                self.redis.connection_pool.disconnect()

    def _get_redis_value_and_ttl_of(self, sid):
        """Fetch the stored value and remaining TTL of session *sid*.

        GET and TTL are queued on one pipeline and sent together. Returns a
        ``(value, ttl)`` tuple.
        """
        redis_key = self._redis_key(sid)
        pipeline = self.redis.pipeline()
        pipeline.get(redis_key)
        pipeline.ttl(redis_key)
        value, ttl = pipeline.execute()
        return value, ttl

    @staticmethod
    def _expiry_timestamp_not_match(expiry_timestamp, redis_key_ttl):
        """Return True if the cookie expiry disagrees with the Redis TTL.

        The expected expiry is ``now + ttl``. A difference of more than 10
        seconds counts as a mismatch.
        """
        try:
            datetime_from_ttl = datetime.utcnow() + timedelta(seconds=redis_key_ttl)
            timestamp_from_ttl = utctimestamp_by_second(datetime_from_ttl)
            return abs(int(expiry_timestamp) - timestamp_from_ttl) > 10
        except (ValueError, TypeError):
            # A non-numeric timestamp or a non-numeric TTL (e.g. None) cannot
            # be validated, so treat it as a mismatch.
            return True

    @staticmethod
    def _extract_sid_and_expiry_timestamp_from(session_key):
        """Split a cookie value into ``(sid, expiry_timestamp_str)``.

        The timestamp is the trailing run of digits after the last dot. If
        the value does not match that shape, ``(session_key, None)`` is
        returned.
        """
        matched = re.match(r"^(.+)\.(\d+)$", session_key)
        if not matched:
            return session_key, None
        return matched.group(1), matched.group(2)

    @staticmethod
    def _create_session_key(sid, expiry_date):
        """Build the cookie value ``"{sid}.{expiry_timestamp}"``."""
        return "{}.{}".format(sid, utctimestamp_by_second(expiry_date))

    @staticmethod
    def _inject_user_id_in_sid(sid, user_id):
        """Prefix *sid* with ``"{user_id}."`` unless it already has it."""
        prefix = "{}.".format(user_id)
        if not sid.startswith(prefix):
            sid = prefix + sid
        return sid

    def _clean_redis_and_cookie(self, app, response, session):
        """Delete the session's Redis key and remove the session cookie."""
        self._write_wrapper(self.redis.delete, self._redis_key(session.sid))
        response.delete_cookie(
            SESSION_COOKIE_NAME,
            domain=self.get_cookie_domain(app),
            path=self.get_cookie_path(app),
        )
def init_app(app):
    """Install the Redis-backed session interface on *app*.

    Connection settings are read from ``REDIS_HOST``, ``REDIS_PORT``,
    ``REDIS_DB`` and ``REDIS_PASSWORD`` in ``app.config``. Any missing key
    falls back to the client default (localhost, 6379, db 0, no password)
    instead of raising ``KeyError``.
    """
    config = app.config
    redis = Redis(
        host=config.get('REDIS_HOST', 'localhost'),
        port=config.get('REDIS_PORT', 6379),
        db=config.get('REDIS_DB', 0),
        password=config.get('REDIS_PASSWORD'),
    )
    app.session_interface = RedisSessionInterface(redis)
