Skip to content

Instantly share code, notes, and snippets.

@alex-pobeditel-2004
Last active May 24, 2022 13:47
Show Gist options
  • Save alex-pobeditel-2004/5098bac720c4eeb79052b7234346f52d to your computer and use it in GitHub Desktop.
Save alex-pobeditel-2004/5098bac720c4eeb79052b7234346f52d to your computer and use it in GitHub Desktop.
JWT Auth middleware for Django Channels 3.0 and rest_framework_simplejwt - update of @dmwyatt gist
"""
Original gist: https://gist.github.com/dmwyatt/5cf7e5102ed0a01b7d38aabf322e03b2
"""
import logging
from typing import Awaitable, Final, List, TYPE_CHECKING, TypedDict
from channels.auth import AuthMiddlewareStack
from channels.db import database_sync_to_async
from django.contrib.auth.models import AnonymousUser
from rest_framework.exceptions import AuthenticationFailed
from rest_framework_simplejwt.authentication import JWTAuthentication
if TYPE_CHECKING:
# If you're using a type checker, change this line to whatever your user model is.
from authentication.models import CustomUser
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Prefix that marks a websocket subprotocol entry as the one carrying the token
# (matched in get_bearer_subprotocol, removed in JWTAuth.get_token_from_request).
TOKEN_STR_PREFIX: Final = "Bearer"
# Typed view of the django-channels connection scope. Only the key this module
# annotates is declared here; total=False makes that key optional.
Scope = TypedDict("Scope", {"subprotocols": List[str]}, total=False)
class QueryAuthMiddleware:
    """
    Middleware for django-channels that attaches a JWT-authenticated user to the
    connection scope.

    This class only stores the wrapped application; the per-connection work is
    done by ``QueryAuthMiddlewareInstance``, which authenticates through
    ``get_user_from_scope``.
    """

    def __init__(self, inner):
        # Store the ASGI application we were passed
        self.inner = inner

    def __call__(self, scope: Scope, receive=None, send=None):
        """
        Create the per-connection middleware instance for ``scope``.

        Two calling conventions are supported:

        * ``middleware(scope)`` returns the instance, which the caller later
          invokes with ``(receive, send)`` (double-callable style);
        * ``middleware(scope, receive, send)`` runs the instance immediately and
          returns its awaitable (single-callable style, which is how Channels 3
          routers call a wrapped application).

        :param scope: Scope from django-channels.
        :param receive: ASGI receive callable, or ``None`` for the two-step form.
        :param send: ASGI send callable, or ``None`` for the two-step form.
        :return: The instance (two-step form) or the awaitable of running it.
        """
        instance = QueryAuthMiddlewareInstance(scope, self)
        if receive is None and send is None:
            # Legacy two-step call: the caller supplies receive/send later.
            return instance
        return instance(receive, send)
class QueryAuthMiddlewareInstance:
    """
    Per-scope half of the middleware: one object is created for every scope.
    """

    def __init__(self, scope: Scope, middleware):
        self.scope = scope
        self.middleware = middleware
        self.inner = middleware.inner

    async def __call__(self, receive, send):
        """
        Authenticate the connection if needed, then hand over to the inner app.

        Authentication is attempted only when the scope has no truthy ``user``
        or when that user reports ``is_anonymous``.
        """
        scope = self.scope
        current_user = scope.get("user")
        if not current_user or current_user.is_anonymous:
            await self._authenticate(scope)
        return await self.inner(scope, receive, send)

    async def _authenticate(self, scope):
        """Try to resolve a user for ``scope`` and record the outcome in it."""
        logger.debug("Attempting to authenticate user.")
        try:
            user = await get_user_from_scope(scope)
        except (AuthenticationFailed, MissingTokenError) as exc:
            reason = str(exc)
            scope["user"] = AnonymousUser()
            # Keep the failure reason in the scope so that code further down
            # the stack can inspect it if it wants to.
            scope["auth_error"] = reason
            logger.info("Could not auth user: %s", reason)
        else:
            scope["user"] = user
            # A successful attempt clears any error left by an earlier one.
            scope.pop("auth_error", None)
def JWTBearerProtocolAuthStack(inner):
    """
    Wrap ``inner`` in ``AuthMiddlewareStack`` and then in ``QueryAuthMiddleware``.

    ``QueryAuthMiddleware`` is the outermost layer, so it handles the scope
    before ``AuthMiddlewareStack`` does.

    :param inner: ASGI application to wrap.
    :return: The wrapped ASGI application.
    """
    return QueryAuthMiddleware(AuthMiddlewareStack(inner))
def get_bearer_subprotocol(scope: Scope):
    """
    Find the subprotocol entry that carries the token.

    :param scope: Scope from django-channels middleware.
    :return: The first item of ``scope["subprotocols"]`` that starts with
        ``TOKEN_STR_PREFIX``, or ``None`` when no item matches or the key is
        absent.
    """
    matching = (
        entry
        for entry in scope.get("subprotocols", [])
        if entry.startswith(TOKEN_STR_PREFIX)
    )
    return next(matching, None)
class JWTAuth(JWTAuthentication):
    """
    ``JWTAuthentication`` subclass that can also read a token from the
    subprotocols of a django-channels scope.
    """

    @classmethod
    def get_token_from_request(cls, scope: Scope) -> str:
        """
        Abuse this method to get token from django-channels scope instead of an http
        request.

        :param scope: Scope from django-channels middleware.
        :return: Everything that follows ``TOKEN_STR_PREFIX`` in the first
            subprotocol starting with that prefix.
        :raises ValueError: If no subprotocol starts with ``TOKEN_STR_PREFIX``.
        """
        token_string = get_bearer_subprotocol(scope)
        if not token_string:
            raise ValueError("No token provided.")
        # The prefix is known to sit at position 0, so slice it off. Splitting
        # on the prefix would cut on every occurrence and truncate a token that
        # happens to contain the prefix text again.
        return token_string[len(TOKEN_STR_PREFIX):]
class MissingTokenError(Exception):
    """Raised when no token can be found for the connection being authenticated."""
class MetaRequest:
    """
    Minimal request stand-in for rest_framework_simplejwt: exposes the headers of
    a scope through a Django-style ``META`` dictionary.
    """

    def __init__(self, scope: dict):
        """
        Build ``META`` from ``scope["headers"]`` (logic copied from
        django.core.handlers.asgi).

        :param scope: Scope whose ``headers`` entry, if present, is an iterable
            of ``(name, value)`` byte-string pairs.
        """
        # These two headers keep their bare names; all others get an HTTP_ prefix.
        unprefixed = {
            "content-length": "CONTENT_LENGTH",
            "content-type": "CONTENT_TYPE",
        }
        meta = {}
        for raw_name, raw_value in scope.get("headers", []):
            header = raw_name.decode("latin1")
            if header in unprefixed:
                key = unprefixed[header]
            else:
                key = "HTTP_%s" % header.upper().replace("-", "_")
            # HTTP/2 allows only ASCII in headers, but latin1 is decoded just
            # in case.
            text = raw_value.decode("latin1")
            # Repeated headers are merged into one comma-separated value.
            meta[key] = ",".join((meta[key], text)) if key in meta else text
        self.META = meta
@database_sync_to_async
def get_user_from_scope(scope) -> "CustomUser":
    """
    Authenticate the connection described by ``scope`` and return its user.

    The headers of the scope are exposed through a ``MetaRequest`` and passed to
    ``JWTAuth.authenticate``. Because of the ``database_sync_to_async``
    decorator, callers must ``await`` this function.

    :param scope: Scope from django-channels middleware.
    :return: The authenticated user.
    :raises MissingTokenError: If ``authenticate`` returns ``None``.
    """
    # The return annotation is a string because the user model is imported only
    # under TYPE_CHECKING and must not be evaluated when the module is loaded.
    auth = JWTAuth()
    # Fiddling META for rest_framework_simplejwt:
    meta_request = MetaRequest(scope)
    # NOTE(review): authenticate() presumably raises AuthenticationFailed for an
    # invalid token (the middleware catches it) - confirm against simplejwt.
    authenticated = auth.authenticate(meta_request)
    if authenticated is None:
        raise MissingTokenError("Cannot find token in scope.")
    user, _token = authenticated
    logger.debug("Authenticated %s", user)
    return user
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment