Skip to content

Instantly share code, notes, and snippets.

@dmwyatt
Last active August 7, 2021 16:03
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dmwyatt/5cf7e5102ed0a01b7d38aabf322e03b2 to your computer and use it in GitHub Desktop.
Save dmwyatt/5cf7e5102ed0a01b7d38aabf322e03b2 to your computer and use it in GitHub Desktop.
[django-channels subprotocol auth middleware] Retrieves the JWT from the websocket's subprotocol and looks up the corresponding user. #jwt #auth #django #websocket
import logging
from typing import Awaitable, Final, List, TYPE_CHECKING, TypedDict
from channels.auth import AuthMiddlewareStack
from channels.db import database_sync_to_async
from django.contrib.auth.models import AnonymousUser
from rest_framework.exceptions import AuthenticationFailed
from rest_framework_jwt.authentication import JSONWebTokenAuthentication
from rest_framework_jwt.blacklist.exceptions import MissingToken
if TYPE_CHECKING:
# If you're using a type checker, change this line to whatever your user model is.
from authentication.models import CustomUser
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Literal text at the start of the websocket subprotocol that carries the JWT;
# the token itself is the text that follows it.
TOKEN_STR_PREFIX: Final[str] = "Bearer"
class Scope(TypedDict, total=False):
    """Typed view of the part of a django-channels scope this module reads.

    ``total=False`` makes every declared key optional.
    """

    # Websocket subprotocol strings; one of them may carry the bearer JWT.
    subprotocols: List[str]
class QueryAuthMiddleware:
    """
    django-channels middleware that resolves the connecting user from a JWT
    carried in one of the websocket subprotocols.

    Calling the middleware with a scope yields a ``QueryAuthMiddlewareInstance``
    bound to that scope and to this middleware.
    """

    def __init__(self, inner):
        # The wrapped ASGI application; each per-scope instance reads it back.
        self.inner = inner

    def __call__(self, scope: Scope):
        """Build the per-scope handler for ``scope``."""
        instance = QueryAuthMiddlewareInstance(scope, self)
        return instance
class QueryAuthMiddlewareInstance:
    """
    Inner class that is instantiated once per scope.

    Keeps a private copy of the scope, tries to attach an authenticated user to
    it, and then passes the scope on to the wrapped ASGI application.
    """

    def __init__(self, scope: Scope, middleware):
        self.middleware = middleware
        # Shallow copy so writes to "user" / "auth_error" stay local to this instance.
        self.scope = dict(scope)
        self.inner = self.middleware.inner

    async def __call__(self, receive, send):
        """
        Authenticate the scope if needed, then delegate to the inner application.

        Authentication is attempted only when the scope has no user or the user
        is anonymous. On success, ``scope["user"]`` is set to the user returned
        by ``get_user_from_scope`` and any previous ``scope["auth_error"]`` is
        removed. On ``AuthenticationFailed`` or ``MissingTokenError``, the user
        becomes ``AnonymousUser()`` and the error text is stored in
        ``scope["auth_error"]``.
        """
        current_user = self.scope.get("user")
        if not current_user or current_user.is_anonymous:
            logger.debug("Attempting to authenticate user.")
            # NOTE(review): only these two exception types are handled here; any
            # other exception raised by the JWT backend propagates to the caller.
            # Confirm that is intended for the installed rest_framework_jwt version.
            try:
                user = await get_user_from_scope(self.scope)
            except (AuthenticationFailed, MissingTokenError) as e:
                message = str(e)
                self.scope["user"] = AnonymousUser()
                # Saves the error received during authentication into the scope so
                # that we can do something with it later if we want.
                self.scope["auth_error"] = message
                logger.info("Could not auth user: %s", message)
            else:
                self.scope["user"] = user
                # Drop an error left over from an earlier, failed attempt.
                self.scope.pop("auth_error", None)
        inner = self.inner(self.scope)
        return await inner(receive, send)
def JWTBearerProtocolAuthStack(inner):  # noqa: N802 - public name kept for callers
    """
    Wrap ``inner`` with JWT-subprotocol auth in front of ``AuthMiddlewareStack``.

    ``QueryAuthMiddleware`` is the outermost layer, so it sees each scope first
    and then hands its (possibly user-populated) copy to
    ``AuthMiddlewareStack(inner)``.

    :param inner: The ASGI application to protect.
    :return: The wrapped application.
    """
    return QueryAuthMiddleware(AuthMiddlewareStack(inner))
def get_bearer_subprotocol(scope: Scope):
    """
    Return the first subprotocol in ``scope`` that starts with ``TOKEN_STR_PREFIX``.

    Returns ``None`` when the scope has no ``subprotocols`` entry or none match.
    """
    offered = scope.get("subprotocols", [])
    return next(
        (entry for entry in offered if entry.startswith(TOKEN_STR_PREFIX)),
        None,
    )
class JWTAuth(JSONWebTokenAuthentication):
    """
    ``JSONWebTokenAuthentication`` that reads the token from a django-channels
    scope's websocket subprotocols instead of from an HTTP request.
    """

    @classmethod
    def get_token_from_request(cls, scope: Scope) -> str:
        """
        Abuse this method to get the token from a django-channels scope instead
        of an HTTP request.

        :param scope: Scope from django-channels middleware.
        :return: The text after ``TOKEN_STR_PREFIX`` in the bearer subprotocol.
        :raises MissingToken: If no subprotocol starts with ``TOKEN_STR_PREFIX``.
        """
        token_string = get_bearer_subprotocol(scope)
        if not token_string:
            raise MissingToken("No token provided.")
        # Strip only the leading prefix. ``split(TOKEN_STR_PREFIX)[1]`` would
        # truncate a token that contains the prefix text again later on
        # (base64url JWT segments can contain letter runs such as "Bearer").
        return token_string[len(TOKEN_STR_PREFIX):]
class MissingTokenError(Exception):
    """Raised by ``get_user_from_scope`` when JWT authentication yields no result."""
# The return annotation is a string on purpose: CustomUser is imported only under
# TYPE_CHECKING, so an unquoted reference would raise NameError when this module
# is imported (annotations are evaluated at definition time here).
@database_sync_to_async
def get_user_from_scope(scope) -> "CustomUser":
    """
    Authenticate the JWT found in ``scope`` and return the matching user.

    The body is synchronous. ``database_sync_to_async`` wraps it, so callers
    must ``await`` the result.

    :param scope: django-channels scope containing the bearer subprotocol.
    :return: The user returned by ``JWTAuth.authenticate``.
    :raises MissingTokenError: If ``JWTAuth.authenticate`` returns ``None``.
        Exceptions raised by ``JWTAuth.authenticate`` itself propagate unchanged.
    """
    authenticated = JWTAuth().authenticate(scope)
    if authenticated is None:
        raise MissingTokenError("Cannot find token in scope.")
    user, _token = authenticated
    logger.debug("Authenticated %s", user)
    return user
@alex-pobeditel-2004
Copy link

Thanks for the gist! I modified it slightly to support Channels 3.* and rest_framework_simplejwt - hope that it didn't break anything:
https://gist.github.com/alex-pobeditel-2004/5098bac720c4eeb79052b7234346f52d

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment