Skip to content

Instantly share code, notes, and snippets.

@dhepper
Created April 6, 2017 12:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dhepper/ae9cd31db35cdb9408d28434942f9b57 to your computer and use it in GitHub Desktop.
Save dhepper/ae9cd31db35cdb9408d28434942f9b57 to your computer and use it in GitHub Desktop.
import functools
from urllib.parse import urlparse

from django.conf import settings
from django.http.request import validate_host
from channels.exceptions import DenyConnection
class BaseOriginValidator(object):
    """Decorator that refuses a connection unless its ``Origin`` header validates.

    Wrap a consumer callable with a subclass of this class. On every call the
    ``origin`` header is looked up in ``message.content['headers']``, decoded,
    and handed to :meth:`validate_origin`. If the header is missing, cannot be
    decoded, or is rejected, ``DenyConnection`` is raised and the wrapped
    callable is not invoked.

    Subclasses must override :meth:`validate_origin`.
    """

    def __init__(self, func):
        # Copy __name__, __doc__, __module__, ... from the wrapped consumer so
        # the decorated object still identifies as the original function.
        functools.update_wrapper(self, func)
        # Assigned *after* update_wrapper: when func is itself a validator, its
        # __dict__ (including its own 'func') is copied onto self, and that must
        # not replace the reference to the inner validator.
        self.func = func

    def __call__(self, message, *args, **kwargs):
        """Validate the message's origin, then delegate to the wrapped callable.

        Raises:
            DenyConnection: if the origin is missing, undecodable, or rejected
                by :meth:`validate_origin`.
        """
        origin = self.get_origin(message)
        if not self.validate_origin(message, origin):
            raise DenyConnection
        return self.func(message, *args, **kwargs)

    def get_header(self, message, name):
        """Return the tail of the first header entry whose first element is *name*.

        ``message.content['headers']`` is iterated as a sequence of
        ``[name, value, ...]`` entries; for the first entry with
        ``entry[0] == name`` the slice ``entry[1:]`` is returned. Empty entries
        are skipped.

        Raises:
            KeyError: if no entry matches *name*, or ``message.content`` has no
                ``'headers'`` key.
        """
        headers = message.content['headers']
        for header in headers:
            try:
                if header[0] == name:
                    return header[1:]
            except IndexError:
                # Empty header entry: nothing to compare against.
                continue
        raise KeyError('No header named "{}"'.format(name))

    def get_origin(self, message):
        """Return the ``Origin`` header of *message* decoded as ASCII text.

        Raises:
            DenyConnection: if the header is absent, has no value, or is not
                valid ASCII.
        """
        try:
            header = self.get_header(message, b'origin')[0]
        except (IndexError, KeyError):
            raise DenyConnection from None
        try:
            return header.decode('ascii')
        except UnicodeDecodeError:
            # An origin is serialized as ASCII; anything else is a malformed
            # request and is refused instead of crashing the consumer.
            raise DenyConnection from None

    def validate_origin(self, message, origin):
        """Return a truthy value if *origin* is acceptable for *message*.

        Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('You must overwrite this method.')
class AllowedHostsOnlyOriginValidator(BaseOriginValidator):
    """Accept only origins whose host is permitted by ``settings.ALLOWED_HOSTS``.

    The origin's host is checked with Django's ``validate_host``. When
    ``settings.DEBUG`` is on and ``ALLOWED_HOSTS`` is empty, the loopback
    names ``localhost``, ``127.0.0.1`` and ``[::1]`` are allowed instead.
    """

    def validate_origin(self, message, origin):
        """Return True if *origin*'s host matches the allowed hosts, else False.

        Origins without a host (e.g. ``"null"``) or that cannot be parsed are
        rejected.
        """
        allowed_hosts = settings.ALLOWED_HOSTS
        if settings.DEBUG and not allowed_hosts:
            allowed_hosts = ['localhost', '127.0.0.1', '[::1]']
        host = self._origin_host(origin)
        if not host:
            return False
        return validate_host(host, allowed_hosts)

    @staticmethod
    def _origin_host(origin):
        """Return *origin*'s host in the form used by ALLOWED_HOSTS, or None."""
        try:
            hostname = urlparse(origin).hostname
        except ValueError:
            # Malformed netloc, e.g. an unbalanced IPv6 literal "http://[::1".
            return None
        if not hostname:
            return None
        if ':' in hostname:
            # urlparse strips the brackets from IPv6 literals, but allowed-host
            # entries keep them (e.g. '[::1]'), so put them back.
            return '[{}]'.format(hostname)
        # Django drops a single trailing dot from request hosts; do the same.
        return hostname[:-1] if hostname.endswith('.') else hostname
# Lowercase alias so the validator reads like a decorator when applied to a
# consumer function: ``@allowed_hosts_only``.
allowed_hosts_only = AllowedHostsOnlyOriginValidator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment