Skip to content

Instantly share code, notes, and snippets.

@rluts
Last active October 13, 2023 20:56
Show Gist options
  • Save rluts/22e05ed8f53f97bdd02eafdf38f3d60a to your computer and use it in GitHub Desktop.
Save rluts/22e05ed8f53f97bdd02eafdf38f3d60a to your computer and use it in GitHub Desktop.
Token authorization middleware for Django Channels 2
from channels.auth import AuthMiddlewareStack
from rest_framework.authtoken.models import Token
from django.contrib.auth.models import AnonymousUser
from django.db import close_old_connections
class TokenAuthMiddleware:
    """
    Token authorization middleware for Django Channels 2.

    Reads an ``Authorization: Token <key>`` header from the connection scope.
    When the key matches a DRF ``Token``, ``scope['user']`` is set to that
    token's user; an unknown key yields ``AnonymousUser``. No header, another
    scheme, or a malformed value leaves ``scope['user']`` unset, so the wrapped
    ``AuthMiddlewareStack`` can populate it instead.
    """

    def __init__(self, inner):
        self.inner = inner

    def __call__(self, scope):
        # Copy so the user we add does not leak into the caller's scope.
        scope = dict(scope)
        headers = dict(scope['headers'])
        auth_header = headers.get(b'authorization')
        if auth_header is not None:
            # latin-1 decodes any byte sequence, so a garbage header value
            # cannot raise UnicodeDecodeError here.
            parts = auth_header.decode('latin-1').split()
            # Only the exact two-field form "Token <key>" is handled; a bare
            # "Token" or extra fields used to crash the tuple unpack.
            if len(parts) == 2 and parts[0] == 'Token':
                # Drop stale / timed-out DB connections *before* querying.
                close_old_connections()
                try:
                    token = Token.objects.select_related('user').get(key=parts[1])
                    scope['user'] = token.user
                except Token.DoesNotExist:
                    scope['user'] = AnonymousUser()
        return self.inner(scope)


def TokenAuthMiddlewareStack(inner):
    """Wrap ``inner`` with token auth on top of Channels' ``AuthMiddlewareStack``."""
    return TokenAuthMiddleware(AuthMiddlewareStack(inner))
@goatwu1993
Copy link

Thanks for sharing. Here is a simplejwt version that works for me:

https://gist.github.com/goatwu1993/1105108e71b6a138168a2e9d160b357d

django==3.1.4
djangorestframework_simplejwt==4.6.0
channels==3.0.3

I also replaced the lambda with a `def` to satisfy flake8 (E731).

@nasir733
Copy link

nasir733 commented Mar 22, 2021

This is the code that worked for me for JWT auth in Channels 3:

"""General web socket middlewares
"""

from channels.db import database_sync_to_async
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from rest_framework_simplejwt.tokens import UntypedToken
from rest_framework_simplejwt.authentication import JWTTokenUserAuthentication
from rest_framework_simplejwt.state import User
from channels.middleware import BaseMiddleware
from channels.auth import AuthMiddlewareStack
from django.db import close_old_connections
from urllib.parse import parse_qs
from jwt import decode as jwt_decode
from django.conf import settings
@database_sync_to_async
def get_user(validated_token):
    """
    Resolve the user referenced by a decoded JWT payload.

    :param validated_token: mapping that should hold a ``"user_id"`` claim,
        e.g. the payload of a validated simplejwt token.
    :return: the matching user instance, or ``AnonymousUser`` when the claim
        is missing or no user with that id exists.
    """
    user_model = get_user_model()
    try:
        user_id = validated_token["user_id"]
    except KeyError:
        # A token without the expected claim cannot identify anyone.
        return AnonymousUser()
    try:
        return user_model.objects.get(id=user_id)
    except user_model.DoesNotExist:
        # Catch the exception of the model we actually queried.
        return AnonymousUser()



class JwtAuthMiddleware(BaseMiddleware):
    """
    Authenticate WebSocket connections with a simplejwt token passed as the
    ``token`` query-string parameter (``ws://host/path/?token=<jwt>``).

    * valid token   -> ``scope["user"]`` is the user named by its ``user_id``
    * invalid token -> ``scope["user"]`` is ``AnonymousUser``
    * no token      -> ``scope["user"]`` is left unset so inner middleware
      (e.g. ``AuthMiddlewareStack``) can populate it

    The inner application is always invoked, so consumers decide whether to
    accept or reject the connection. (Previously a missing token raised
    ``KeyError`` and an invalid one returned ``None`` without calling it.)
    """

    def __init__(self, inner):
        self.inner = inner

    async def __call__(self, scope, receive, send):
        # Work on a copy so the user we add does not leak upstream.
        scope = dict(scope)

        query = parse_qs(scope.get("query_string", b"").decode("utf8"))
        tokens = query.get("token")
        if tokens:
            try:
                # Validates the token and raises if it is invalid; decoding
                # is delegated to simplejwt instead of a hard-coded key and
                # algorithm.
                validated = UntypedToken(tokens[0])
            except (InvalidToken, TokenError):
                scope["user"] = AnonymousUser()
            else:
                # payload is the decoded claims dict, e.g.
                # {"token_type": "access", "exp": ..., "jti": ..., "user_id": 6}
                # get_user runs under database_sync_to_async, which handles
                # stale DB connections around the query (see Channels docs),
                # so no sync close_old_connections() call in the event loop.
                scope["user"] = await get_user(validated_token=validated.payload)
        return await super().__call__(scope, receive, send)


def JwtAuthMiddlewareStack(inner):
    """
    Build the WebSocket auth stack around ``inner``.

    ``JwtAuthMiddleware`` is the outermost layer and runs first; Channels'
    ``AuthMiddlewareStack`` sits between it and ``inner``.
    """
    session_stack = AuthMiddlewareStack(inner)
    return JwtAuthMiddleware(session_stack)


Cheers! I didn't remove the unused imports; I just found the solution and am sharing it with you.

@alex-pobeditel-2004
Copy link

If anyone else needs it, here is a version of @dmwyatt's gist that supports Channels 3.0 with rest_framework_simplejwt:
https://gist.github.com/alex-pobeditel-2004/5098bac720c4eeb79052b7234346f52d

@manupatel007
Copy link

For those stuck on the client-side code, or on how this middleware connects to the URL configuration, the following blog post can help: https://hashnode.com/post/using-django-drf-jwt-authentication-with-django-channels-cjzy5ffqs0013rus1yb9huxvl

@agusmakmun
Copy link

To support the `Authorization` header, a token in the query string, and session authentication as well:

from urllib import parse

from rest_framework.authtoken.models import Token
from channels.db import database_sync_to_async
from channels.auth import AuthMiddlewareStack


@database_sync_to_async
def get_user_from_headers_or_queries(scope):
    """
    Look up the DRF ``Token`` user for a Channels connection scope.

    The key comes from a well-formed ``Authorization: Token <key>`` header
    when present; otherwise (no header, another scheme, or a malformed value)
    from the ``token`` query-string parameter.

    :return: the token's ``User``, or ``None`` when no key was supplied or the
        key does not match any ``Token``.
    """
    # Default instead of logging through an undefined ``logger`` (NameError).
    headers = dict(scope.get("headers") or [])
    params = dict(parse.parse_qsl(scope.get("query_string", b"").decode("utf8")))

    token_key = None
    auth_header = headers.get(b"authorization")
    if auth_header is not None:
        # 1. get from authorization headers; anything that is not exactly
        #    "Token <key>" is ignored instead of raising ValueError.
        parts = auth_header.decode("latin-1").split()
        if len(parts) == 2 and parts[0] == "Token":  # nosec: B105 (just checking the token name)
            token_key = parts[1]
    if not token_key:
        # 2. get from token params
        token_key = params.get("token")

    if not token_key:
        return None
    try:
        # select_related fetches the user in the same query.
        return Token.objects.select_related("user").get(key=token_key).user
    except Token.DoesNotExist:
        return None  # caller keeps the session / anonymous user


class TokenAuthMiddleware:
    """
    ASGI middleware that puts the DRF-token user into ``scope["user"]``.

    When no user is resolved, the scope is passed on unchanged so the wrapped
    application (e.g. ``AuthMiddlewareStack``) can still set the user.
    """

    def __init__(self, app):
        # The wrapped ASGI application.
        self.app = app

    async def __call__(self, scope, receive, send):
        resolved_user = await get_user_from_headers_or_queries(scope)
        if resolved_user is None:
            return await self.app(scope, receive, send)
        scope["user"] = resolved_user
        return await self.app(scope, receive, send)


# Shortcut: token auth layered on top of Channels' AuthMiddlewareStack.
def TokenAuthMiddlewareStack(inner):
    """
    Wrap ``inner`` for WebSocket authentication.

    ``TokenAuthMiddleware`` is the outermost layer and runs first. When it
    does not resolve a token user, ``scope["user"]`` is left for the
    session-based ``AuthMiddlewareStack`` to populate.
    """
    session_stack = AuthMiddlewareStack(inner)
    return TokenAuthMiddleware(session_stack)

In the module referenced by `ASGI_APPLICATION` (typically `asgi.py`):

from channels.routing import ProtocolTypeRouter, URLRouter
from channels.security.websocket import AllowedHostsOriginValidator
from yourproject.utils.middleware import TokenAuthMiddlewareStack

# WebSocket traffic: origin check -> token/session auth -> URL routing.
# Replace ``...`` with your websocket URL patterns.
application = ProtocolTypeRouter(
    {
        'websocket': AllowedHostsOriginValidator(
            TokenAuthMiddlewareStack(URLRouter(...)),
        ),
    }
)

@ritiksoni00
Copy link

To support the `Authorization` header, a token in the query string, and session authentication as well:

from urllib import parse

from rest_framework.authtoken.models import Token
from channels.db import database_sync_to_async
from channels.auth import AuthMiddlewareStack


@database_sync_to_async
def get_user_from_headers_or_queries(scope):
    """
    Look up the DRF ``Token`` user for a Channels connection scope.

    The key comes from a well-formed ``Authorization: Token <key>`` header
    when present; otherwise (no header, another scheme, or a malformed value)
    from the ``token`` query-string parameter.

    :return: the token's ``User``, or ``None`` when no key was supplied or the
        key does not match any ``Token``.
    """
    # Default instead of logging through an undefined ``logger`` (NameError).
    headers = dict(scope.get("headers") or [])
    params = dict(parse.parse_qsl(scope.get("query_string", b"").decode("utf8")))

    token_key = None
    auth_header = headers.get(b"authorization")
    if auth_header is not None:
        # 1. get from authorization headers; anything that is not exactly
        #    "Token <key>" is ignored instead of raising ValueError.
        parts = auth_header.decode("latin-1").split()
        if len(parts) == 2 and parts[0] == "Token":  # nosec: B105 (just checking the token name)
            token_key = parts[1]
    if not token_key:
        # 2. get from token params
        token_key = params.get("token")

    if not token_key:
        return None
    try:
        # select_related fetches the user in the same query.
        return Token.objects.select_related("user").get(key=token_key).user
    except Token.DoesNotExist:
        return None  # caller keeps the session / anonymous user


class TokenAuthMiddleware:
    """
    ASGI middleware that puts the DRF-token user into ``scope["user"]``.

    When no user is resolved, the scope is passed on unchanged so the wrapped
    application (e.g. ``AuthMiddlewareStack``) can still set the user.
    """

    def __init__(self, app):
        # The wrapped ASGI application.
        self.app = app

    async def __call__(self, scope, receive, send):
        resolved_user = await get_user_from_headers_or_queries(scope)
        if resolved_user is None:
            return await self.app(scope, receive, send)
        scope["user"] = resolved_user
        return await self.app(scope, receive, send)


# Shortcut: token auth layered on top of Channels' AuthMiddlewareStack.
def TokenAuthMiddlewareStack(inner):
    """
    Wrap ``inner`` for WebSocket authentication.

    ``TokenAuthMiddleware`` is the outermost layer and runs first. When it
    does not resolve a token user, ``scope["user"]`` is left for the
    session-based ``AuthMiddlewareStack`` to populate.
    """
    session_stack = AuthMiddlewareStack(inner)
    return TokenAuthMiddleware(session_stack)

In the module referenced by `ASGI_APPLICATION` (typically `asgi.py`):

from channels.routing import ProtocolTypeRouter, URLRouter
from channels.security.websocket import AllowedHostsOriginValidator
from yourproject.utils.middleware import TokenAuthMiddlewareStack

# WebSocket traffic: origin check -> token/session auth -> URL routing.
# Replace ``...`` with your websocket URL patterns.
application = ProtocolTypeRouter(
    {
        'websocket': AllowedHostsOriginValidator(
            TokenAuthMiddlewareStack(URLRouter(...)),
        ),
    }
)

First, you need to change the order:

# Reordered variant: AuthMiddlewareStack is the outer layer and
# TokenAuthMiddleware sits inside it, directly around the application.
def TokenAuthMiddlewareStack(inner):
    """
    Wrap ``inner`` with ``AuthMiddlewareStack`` outermost and
    ``TokenAuthMiddleware`` innermost.

    Because the token layer runs after the session-based stack, a resolved
    token user overwrites whatever ``scope["user"]`` the outer stack set.
    """
    token_layer = TokenAuthMiddleware(inner)
    return AuthMiddlewareStack(token_layer)

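Second, getting the token from a query string is not a good idea: URLs, including their query strings, are commonly written to server and proxy access logs, which can leak the token.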

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment