Skip to content

Instantly share code, notes, and snippets.

Last active October 13, 2023 20:56
Show Gist options
  • Save rluts/22e05ed8f53f97bdd02eafdf38f3d60a to your computer and use it in GitHub Desktop.
Save rluts/22e05ed8f53f97bdd02eafdf38f3d60a to your computer and use it in GitHub Desktop.
Token authorization middleware for Django Channels 2
from channels.auth import AuthMiddlewareStack
from rest_framework.authtoken.models import Token
from django.contrib.auth.models import AnonymousUser
from django.db import close_old_connections
class TokenAuthMiddleware:
    """
    Token authorization middleware for Django Channels 2.

    Reads an ``Authorization: Token <key>`` header from the connection
    scope, looks the key up in DRF's ``Token`` table and passes the
    resolved user to the wrapped application as ``scope['user']``.
    A missing header, a different scheme, or a malformed value leaves
    the scope untouched so the inner application decides the user.
    """

    def __init__(self, inner):
        # The wrapped ASGI application (Channels 2 double-callable style).
        self.inner = inner

    def __call__(self, scope):
        headers = dict(scope['headers'])
        if b'authorization' in headers:
            # Header values are raw bytes; latin-1 decoding cannot fail, so a
            # client cannot abort the handshake with non-UTF-8 bytes.
            parts = headers[b'authorization'].decode('latin-1').split()
            # Accept only the exact two-part form "Token <key>"; tuple
            # unpacking of anything else would raise ValueError.
            if len(parts) == 2 and parts[0] == 'Token':
                try:
                    token = Token.objects.get(key=parts[1])
                except Token.DoesNotExist:
                    user = AnonymousUser()
                else:
                    user = token.user
                finally:
                    # This runs outside Django's request/response cycle, so
                    # clean up stale DB connections after the ORM access.
                    close_old_connections()
                # Hand the inner app a copy instead of mutating the caller's scope.
                return self.inner(dict(scope, user=user))
        return self.inner(scope)
def TokenAuthMiddlewareStack(inner):
    """
    Wrap *inner* in Channels' ``AuthMiddlewareStack`` with
    ``TokenAuthMiddleware`` as the outermost layer.

    Written as a ``def`` rather than an assigned lambda (PEP 8, flake8
    E731) so the factory has a real name in tracebacks.
    """
    return TokenAuthMiddleware(AuthMiddlewareStack(inner))
Copy link

@astrikov-d I was thinking that would be much easier too. But, as mentioned in a few comments above, it's not secure to send the auth/session token via query params, since they are visible in server logs (I guess). Well, there is a third option: using cookies. Take a look at this

Copy link

noobmaster19 commented Apr 24, 2020

Hello, sorry for the interruption. I have successfully implemented the cookie method raised here.
I'm pretty new to web development and wanted to know if this method is secure.

@astrikov-d I was thinking that would be much easier too. But, as mentioned in a few comments above, it's not secure to send the auth/session token via query params, since they are visible in server logs (I guess). Well, there is a third option: using cookies. Take a look at this

Copy link

@neowenshun If the page you are using it on is served over HTTPS (SSL/TLS) and the cookie has the HttpOnly flag set, it should be a secure implementation.

Copy link

dphans commented Sep 2, 2020

I still pass the token as a query param, authenticating with django-rest-framework-simplejwt:

from urllib import parse

from channels.auth import AuthMiddlewareStack
from channels.db import database_sync_to_async
from django.contrib.auth.models import AnonymousUser
# noinspection PyProtectedMember
from django.db import close_old_connections
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.state import User
from rest_framework_simplejwt.tokens import AccessToken

from Account.models import User

@database_sync_to_async
def get_user(**kwargs):
    """
    Look up a user by field lookups, off the event loop.

    Decorated with ``database_sync_to_async`` because it is awaited from
    async middleware code; Django (3.0+) refuses synchronous ORM calls
    made directly from a running event loop.

    :param kwargs: lookups passed straight to ``User.objects.get``.
    :return: the matching ``User``, or ``AnonymousUser()`` if none matches.
    """
    try:
        return User.objects.get(**kwargs)
    except User.DoesNotExist:
        return AnonymousUser()

class JwtAuthMiddleware:
    """
    Channels 2 (double-callable) middleware that authenticates a connection
    from a simplejwt access token passed in the query string.

    Calling it with a scope returns a per-connection
    ``JwtAuthMiddlewareInstance``, which does the actual work.
    """

    def __init__(self, inner):
        # Application being wrapped; each per-connection instance reads it back.
        self.inner = inner

    def __call__(self, scope):
        instance = JwtAuthMiddlewareInstance(scope=scope, middleware=self)
        return instance

class JwtAuthMiddlewareInstance:
    """
    Per-connection half of ``JwtAuthMiddleware`` (Channels 2 style).

    Resolves ``scope['user']`` from a ``?token=<access token>`` query
    parameter and then runs the wrapped application. A user that an outer
    layer already placed in the scope is kept if active. In every failure
    case the user falls back to ``AnonymousUser()``: no query string, no
    token, an undecodable or invalid token, or an unknown or inactive user.
    """

    def __init__(self, scope, middleware):
        self.middleware = middleware
        # Private copy so the user we add does not leak into the caller's scope.
        self.scope = dict(scope)
        self.inner = self.middleware.inner

    async def __call__(self, receive, send):
        user = self.scope.get('user')
        # Keep a user that an outer layer already authenticated.
        if not (user and user.is_active):
            user = await self._user_from_query_string()
        inner = self.inner(dict(self.scope, user=user))
        return await inner(receive, send)

    async def _user_from_query_string(self):
        """Return the user named by the ``token`` query parameter, else AnonymousUser."""
        from rest_framework_simplejwt.exceptions import TokenError

        # ASGI lets websocket scopes omit query_string (default b'').
        query_string = self.scope.get("query_string", b"")
        if not query_string:
            return AnonymousUser()

        try:
            query_dict = parse.parse_qs(query_string.decode('utf-8'))
        except UnicodeDecodeError:
            return AnonymousUser()

        # parse_qs maps each present key to a non-empty list; take the first value.
        raw_token = query_dict.get('token', [''])[0]
        if not raw_token:
            return AnonymousUser()

        try:
            token_decoded = AccessToken(raw_token)
        except TokenError:
            # Expired, tampered with, or otherwise invalid token.
            return AnonymousUser()

        return await self.get_user(validated_token=token_decoded)

    async def get_user(self, validated_token):
        """
        Resolve the user referenced by an already-validated access token.

        :param validated_token: decoded ``AccessToken``.
        :return: an active ``User``, otherwise ``AnonymousUser()``.
        """
        try:
            user_id = validated_token[api_settings.USER_ID_CLAIM]
        except KeyError:
            # The token carries no user-id claim.
            return AnonymousUser()

        try:
            # Module-level helper (not this method): performs the ORM lookup.
            user = await get_user(**{api_settings.USER_ID_FIELD: user_id})
        except Exception:
            # Deliberate best-effort: a malformed id (e.g. wrong type for the
            # pk field) degrades to an anonymous connection instead of crashing it.
            return AnonymousUser()

        if not user.is_active:
            return AnonymousUser()

        return user

def JwtAuthMiddlewareStack(inner):
    """
    Wrap *inner* in Channels' ``AuthMiddlewareStack`` with
    ``JwtAuthMiddleware`` as the outermost layer.

    Written as a ``def`` rather than an assigned lambda (PEP 8, flake8
    E731) so the factory has a real name in tracebacks.
    """
    return JwtAuthMiddleware(AuthMiddlewareStack(inner))

Copy link

Hello, thank you all for your awesome answers. Sadly, I am using django-allauth for social media authentication, and it only works with djangorestframework-jwt. I also want to use it with Django Channels, but I can't figure out how to create that custom authentication middleware. If anyone has faced the same issue, I would like to know how they handled it.

Copy link

dmwyatt commented Sep 28, 2020

I created a middleware that uses the websocket subprotocol. Rather than rolling my own authentication, I subclassed and customized JSONWebTokenAuthentication from the django-rest-framework-jwt package to get the JWT from the scope rather than from the HTTP request headers.

Copy link

I still pass the token as a query param, authenticating with django-rest-framework-simplejwt:

from urllib import parse

from channels.auth import AuthMiddlewareStack
from channels.db import database_sync_to_async
from django.contrib.auth.models import AnonymousUser
# noinspection PyProtectedMember
from django.db import close_old_connections
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.state import User
from rest_framework_simplejwt.tokens import AccessToken

from Account.models import User

@database_sync_to_async
def get_user(**kwargs):
    """
    Look up a user by field lookups, off the event loop.

    Decorated with ``database_sync_to_async`` because it is awaited from
    async middleware code; Django (3.0+) refuses synchronous ORM calls
    made directly from a running event loop.

    :param kwargs: lookups passed straight to ``User.objects.get``.
    :return: the matching ``User``, or ``AnonymousUser()`` if none matches.
    """
    try:
        return User.objects.get(**kwargs)
    except User.DoesNotExist:
        return AnonymousUser()

class JwtAuthMiddleware:
    """
    Channels 2 (double-callable) middleware that authenticates a connection
    from a simplejwt access token passed in the query string.

    Calling it with a scope returns a per-connection
    ``JwtAuthMiddlewareInstance``, which does the actual work.
    """

    def __init__(self, inner):
        # Application being wrapped; each per-connection instance reads it back.
        self.inner = inner

    def __call__(self, scope):
        instance = JwtAuthMiddlewareInstance(scope=scope, middleware=self)
        return instance

class JwtAuthMiddlewareInstance:
    """
    Per-connection half of ``JwtAuthMiddleware`` (Channels 2 style).

    Resolves ``scope['user']`` from a ``?token=<access token>`` query
    parameter and then runs the wrapped application. A user that an outer
    layer already placed in the scope is kept if active. In every failure
    case the user falls back to ``AnonymousUser()``: no query string, no
    token, an undecodable or invalid token, or an unknown or inactive user.
    """

    def __init__(self, scope, middleware):
        self.middleware = middleware
        # Private copy so the user we add does not leak into the caller's scope.
        self.scope = dict(scope)
        self.inner = self.middleware.inner

    async def __call__(self, receive, send):
        user = self.scope.get('user')
        # Keep a user that an outer layer already authenticated.
        if not (user and user.is_active):
            user = await self._user_from_query_string()
        inner = self.inner(dict(self.scope, user=user))
        return await inner(receive, send)

    async def _user_from_query_string(self):
        """Return the user named by the ``token`` query parameter, else AnonymousUser."""
        from rest_framework_simplejwt.exceptions import TokenError

        # ASGI lets websocket scopes omit query_string (default b'').
        query_string = self.scope.get("query_string", b"")
        if not query_string:
            return AnonymousUser()

        try:
            query_dict = parse.parse_qs(query_string.decode('utf-8'))
        except UnicodeDecodeError:
            return AnonymousUser()

        # parse_qs maps each present key to a non-empty list; take the first value.
        raw_token = query_dict.get('token', [''])[0]
        if not raw_token:
            return AnonymousUser()

        try:
            token_decoded = AccessToken(raw_token)
        except TokenError:
            # Expired, tampered with, or otherwise invalid token.
            return AnonymousUser()

        return await self.get_user(validated_token=token_decoded)

    async def get_user(self, validated_token):
        """
        Resolve the user referenced by an already-validated access token.

        :param validated_token: decoded ``AccessToken``.
        :return: an active ``User``, otherwise ``AnonymousUser()``.
        """
        try:
            user_id = validated_token[api_settings.USER_ID_CLAIM]
        except KeyError:
            # The token carries no user-id claim.
            return AnonymousUser()

        try:
            # Module-level helper (not this method): performs the ORM lookup.
            user = await get_user(**{api_settings.USER_ID_FIELD: user_id})
        except Exception:
            # Deliberate best-effort: a malformed id (e.g. wrong type for the
            # pk field) degrades to an anonymous connection instead of crashing it.
            return AnonymousUser()

        if not user.is_active:
            return AnonymousUser()

        return user

def JwtAuthMiddlewareStack(inner):
    """
    Wrap *inner* in Channels' ``AuthMiddlewareStack`` with
    ``JwtAuthMiddleware`` as the outermost layer.

    Written as a ``def`` rather than an assigned lambda (PEP 8, flake8
    E731) so the factory has a real name in tracebacks.
    """
    return JwtAuthMiddleware(AuthMiddlewareStack(inner))

I am completely new to django-channels. When I use this code, it returns an error:

    inner = self.inner(dict(self.scope, user=AnonymousUser()))
TypeError: __call__() missing 2 required positional arguments: 'receive' and 'send'

Can anyone point me in the right direction?

Copy link

Ever since I upgraded to Channels 3.0 this code snippet isn't working anymore. Can anyone help?

If you ended up here and are using Django 3.0, this won't work out of the box: the database access has to be moved into a separate function decorated with database_sync_to_async.

Here is the code snippet.

from channels.auth import AuthMiddlewareStack
from channels.db import database_sync_to_async
from django.contrib.auth.models import AnonymousUser

from rest_framework.authtoken.models import Token

@database_sync_to_async
def get_user(headers):
    """
    Resolve the user for an ``Authorization: Token <key>`` header.

    Runs in a worker thread via ``database_sync_to_async`` because it is
    awaited from async middleware code and the ORM is synchronous.

    :param headers: the scope's headers as a ``{bytes: bytes}`` dict that
        contains ``b'authorization'``.
    :return: the token's user, or ``AnonymousUser()`` when the header is
        malformed, uses another scheme, or names an unknown token. It never
        returns None, because the caller stores the result as ``scope['user']``.
    """
    # latin-1 never fails on arbitrary header bytes.
    parts = headers[b'authorization'].decode('latin-1').split()
    # Accept only the exact two-part form "Token <key>"; tuple unpacking of
    # anything else would raise ValueError.
    if len(parts) != 2 or parts[0] != 'Token':
        return AnonymousUser()
    try:
        token = Token.objects.get(key=parts[1])
    except Token.DoesNotExist:
        return AnonymousUser()
    # Accessing the related user may query the DB, so it stays in this thread.
    return token.user

class TokenAuthMiddleware:
    """
    Channels 2 style (double-callable) token middleware.

    Calling it with a scope returns a ``TokenAuthMiddlewareInstance`` bound
    to that scope.
    """

    def __init__(self, inner):
        # Wrapped application; each per-connection instance reads it from here.
        self.inner = inner

    def __call__(self, scope):
        per_connection = TokenAuthMiddlewareInstance(scope=scope, middleware=self)
        return per_connection

class TokenAuthMiddlewareInstance:
    """
    Per-connection half of ``TokenAuthMiddleware``.

    Yeah, this is black magic: ``TokenAuthMiddleware.__call__`` is synchronous
    and only receives the scope. The awaited token lookup is therefore
    deferred to this instance's async ``__call__``, which also receives
    ``receive``/``send``.
    """

    def __init__(self, scope, middleware):
        self.middleware = middleware
        # Private copy: setting 'user' below must not touch the caller's scope.
        self.scope = dict(scope)
        self.inner = self.middleware.inner

    async def __call__(self, receive, send):
        headers = dict(self.scope['headers'])
        if b'authorization' in headers:
            self.scope['user'] = await get_user(headers)
        inner = self.inner(self.scope)
        return await inner(receive, send)

def TokenAuthMiddlewareStack(inner):
    """
    Wrap *inner* in Channels' ``AuthMiddlewareStack`` with
    ``TokenAuthMiddleware`` as the outermost layer.

    Written as a ``def`` rather than an assigned lambda (PEP 8, flake8
    E731) so the factory has a real name in tracebacks.
    """
    return TokenAuthMiddleware(AuthMiddlewareStack(inner))

Credit: Stack Overflow

Copy link

AliRn76 commented Dec 24, 2020

If you are using Channels 3, you can use this snippet:
(Token authorization middleware for Django Channels 3)

Copy link

Thanks for sharing. Below is simplejwt code that works for me.


I also replaced the lambda with a def to satisfy flake8 (E731).

Copy link

nasir733 commented Mar 22, 2021

This is the code that worked for me for JWT auth in Channels 3:

"""General web socket middlewares

from channels.db import database_sync_to_async
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from rest_framework_simplejwt.exceptions import InvalidToken, TokenError
from rest_framework_simplejwt.tokens import UntypedToken
from rest_framework_simplejwt.authentication import JWTTokenUserAuthentication
from rest_framework_simplejwt.state import User
from channels.middleware import BaseMiddleware
from channels.auth import AuthMiddlewareStack
from django.db import close_old_connections
from urllib.parse import parse_qs
from jwt import decode as jwt_decode
from django.conf import settings
@database_sync_to_async
def get_user(validated_token):
    """
    Return the user whose primary key is the token's ``user_id`` claim.

    Wrapped in ``database_sync_to_async`` because it is awaited from async
    middleware code while the ORM is synchronous.

    :param validated_token: mapping of decoded JWT claims; must contain
        ``"user_id"`` (a missing claim raises KeyError).
    :return: the user, or ``AnonymousUser()`` if no such user exists.
    """
    user_model = get_user_model()
    try:
        return user_model.objects.get(id=validated_token["user_id"])
    except user_model.DoesNotExist:
        return AnonymousUser()

class JwtAuthMiddleware(BaseMiddleware):
    """
    Channels 3 middleware that authenticates a connection from a simplejwt
    token passed as ``?token=<jwt>`` in the query string.

    When the token is valid, ``scope["user"]`` is set to the token's user.
    When the token is missing or invalid, the scope is passed through
    unchanged, so the inner stack (e.g. ``AuthMiddlewareStack``) decides
    the user. Previously a missing token raised KeyError, and an invalid
    token returned without ever running the application.
    """

    def __init__(self, inner):
        super().__init__(inner)

    async def __call__(self, scope, receive, send):
        # Work on a copy so changes do not leak to outer middleware.
        scope = dict(scope)

        # No close_old_connections() here: that would touch the DB from the
        # event loop. ORM access goes through the awaited get_user helper,
        # which is expected to be wrapped in database_sync_to_async (that
        # wrapper cleans up connections itself).
        token = self._token_from_query_string(scope)
        if token is not None:
            try:
                # Validates signature and expiry using simplejwt's own
                # settings (signing key, algorithm) and keeps the decoded claims.
                validated = UntypedToken(token)
            except (InvalidToken, TokenError):
                validated = None
            if validated is not None:
                scope["user"] = await get_user(validated_token=validated.payload)

        return await super().__call__(scope, receive, send)

    @staticmethod
    def _token_from_query_string(scope):
        """Return the first ``token`` query parameter, or None if absent/empty."""
        raw = scope.get("query_string", b"").decode("utf8", errors="replace")
        values = parse_qs(raw).get("token")
        return values[0] if values else None

def JwtAuthMiddlewareStack(inner):
    """Wrap *inner* with ``AuthMiddlewareStack``, placing ``JwtAuthMiddleware`` outermost."""
    session_auth_app = AuthMiddlewareStack(inner)
    return JwtAuthMiddleware(session_auth_app)

Cheers! I didn't remove the unused imports; I just found the solution and am sharing it with you.

Copy link

If anyone else needs it, here is a version of @dmwyatt's gist that supports Channels 3.0 with rest_framework_simplejwt:

Copy link

For those who are stuck on writing the client-side code, or on how this middleware connects to the URL configuration, the following blog post can help:

Copy link

To support a token in the Authorization header or in the query string, as well as session authentication:

from urllib import parse

from rest_framework.authtoken.models import Token
from channels.db import database_sync_to_async
from channels.auth import AuthMiddlewareStack

@database_sync_to_async
def get_user_from_headers_or_queries(scope):
    """
    Return the ``User`` for a DRF token found in the connection scope.

    If an ``Authorization`` header is present, only its ``Token <key>`` form
    is used. Otherwise the key is taken from a ``?token=<key>`` query
    parameter. Runs in a worker thread (``database_sync_to_async``) because
    the caller awaits it and the ORM is synchronous.

    :param scope: ASGI connection scope.
    :return: the token's ``User``, or ``None`` when no usable token is found
        or the key matches no ``Token`` row (i.e. leave the user to the session).
    """
    headers = dict(scope.get("headers", []))
    query_string = scope.get("query_string", b"").decode("utf8", errors="replace")
    params = dict(parse.parse_qsl(query_string))

    if b"authorization" in headers:
        # 1. get from authorization headers ("Token <key>" only)
        parts = headers[b"authorization"].decode("latin-1").split()
        is_token_scheme = (
            len(parts) == 2
            and parts[0] == "Token"  # nosec: B105 (just checking the token name)
        )
        token_key = parts[1] if is_token_scheme else None
    else:
        # 2. get from token params
        token_key = params.get("token")

    if not token_key:
        return None
    try:
        return Token.objects.get(key=token_key).user
    except Token.DoesNotExist:
        return None  # AnonymousUser / session decides

class TokenAuthMiddleware:
    """
    ASGI 3 (single-callable) middleware that sets ``scope["user"]`` from a
    DRF token found in the headers or query string, when one resolves.
    """

    def __init__(self, app):
        # Store the ASGI application we were passed
        self.app = app

    async def __call__(self, scope, receive, send):
        user = await get_user_from_headers_or_queries(scope)
        if user is not None:
            # Copy rather than mutate the scope handed to us.
            scope = dict(scope, user=user)
        return await self.app(scope, receive, send)

# Handy shortcut for applying the session/auth stack and TokenAuthMiddleware at once
def TokenAuthMiddlewareStack(inner):
    """
    Middleware stack for WebSocket (ssh) connections authenticated either by
    session or by a DRF token from the headers / query string.
    """
    return TokenAuthMiddleware(AuthMiddlewareStack(inner))

from channels.routing import ProtocolTypeRouter, URLRouter
from channels.security.websocket import AllowedHostsOriginValidator
from yourproject.utils.middleware import TokenAuthMiddlewareStack

# NOTE(review): this snippet is truncated in the source; the
# AllowedHostsOriginValidator( call and the ProtocolTypeRouter({ mapping are
# never closed. Presumably the validator wraps
# TokenAuthMiddlewareStack(URLRouter(...)). TODO: restore from the original post.
application = ProtocolTypeRouter({
    'websocket': AllowedHostsOriginValidator(

Copy link

To support a token in the Authorization header or in the query string, as well as session authentication:

from urllib import parse

from rest_framework.authtoken.models import Token
from channels.db import database_sync_to_async
from channels.auth import AuthMiddlewareStack

@database_sync_to_async
def get_user_from_headers_or_queries(scope):
    """
    Return the ``User`` for a DRF token found in the connection scope.

    If an ``Authorization`` header is present, only its ``Token <key>`` form
    is used. Otherwise the key is taken from a ``?token=<key>`` query
    parameter. Runs in a worker thread (``database_sync_to_async``) because
    the caller awaits it and the ORM is synchronous.

    :param scope: ASGI connection scope.
    :return: the token's ``User``, or ``None`` when no usable token is found
        or the key matches no ``Token`` row (i.e. leave the user to the session).
    """
    headers = dict(scope.get("headers", []))
    query_string = scope.get("query_string", b"").decode("utf8", errors="replace")
    params = dict(parse.parse_qsl(query_string))

    if b"authorization" in headers:
        # 1. get from authorization headers ("Token <key>" only)
        parts = headers[b"authorization"].decode("latin-1").split()
        is_token_scheme = (
            len(parts) == 2
            and parts[0] == "Token"  # nosec: B105 (just checking the token name)
        )
        token_key = parts[1] if is_token_scheme else None
    else:
        # 2. get from token params
        token_key = params.get("token")

    if not token_key:
        return None
    try:
        return Token.objects.get(key=token_key).user
    except Token.DoesNotExist:
        return None  # AnonymousUser / session decides

class TokenAuthMiddleware:
    """
    ASGI 3 (single-callable) middleware that sets ``scope["user"]`` from a
    DRF token found in the headers or query string, when one resolves.
    """

    def __init__(self, app):
        # Store the ASGI application we were passed
        self.app = app

    async def __call__(self, scope, receive, send):
        user = await get_user_from_headers_or_queries(scope)
        if user is not None:
            # Copy rather than mutate the scope handed to us.
            scope = dict(scope, user=user)
        return await self.app(scope, receive, send)

# Handy shortcut for applying the session/auth stack and TokenAuthMiddleware at once
def TokenAuthMiddlewareStack(inner):
    """
    Middleware stack for WebSocket (ssh) connections authenticated either by
    session or by a DRF token from the headers / query string.
    """
    return TokenAuthMiddleware(AuthMiddlewareStack(inner))

from channels.routing import ProtocolTypeRouter, URLRouter
from channels.security.websocket import AllowedHostsOriginValidator
from yourproject.utils.middleware import TokenAuthMiddlewareStack

# NOTE(review): this snippet is truncated in the source; the
# AllowedHostsOriginValidator( call and the ProtocolTypeRouter({ mapping are
# never closed. Presumably the validator wraps
# TokenAuthMiddlewareStack(URLRouter(...)). TODO: restore from the original post.
application = ProtocolTypeRouter({
    'websocket': AllowedHostsOriginValidator(

First, you need to change the order:

# Handy shortcut for applying the session/auth stack and TokenAuthMiddleware at once
def TokenAuthMiddlewareStack(inner):
    """
    Middleware stack for WebSocket (ssh) connections authenticated either by
    session or by a DRF token from the headers / query string.

    ``TokenAuthMiddleware`` sits *inside* ``AuthMiddlewareStack``. It therefore
    sees the scope after the session layers have run, and a resolved token
    user replaces ``scope["user"]``.
    """
    return AuthMiddlewareStack(TokenAuthMiddleware(inner))  # order changed on purpose

Second, getting the token from a query string is not a good idea: query strings are often recorded in server and proxy logs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment