Skip to content

Instantly share code, notes, and snippets.

@simonw
Created December 22, 2009 07:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save simonw/261572 to your computer and use it in GitHub Desktop.
Save simonw/261572 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
# Copy of http://dpaste.com/136418/
# Based on http://github.com/simonw/django-signed
# All cookie related code is untested.
import base64
import hmac
import struct
import time
from django.conf import settings
from django.core.exceptions import SuspiciousOperation
from django.utils.hashcompat import sha_constructor
from django.utils import simplejson
from django.utils.functional import wraps
class BadSignature(ValueError, SuspiciousOperation):
    """Raised when a signed value cannot be verified.

    Extends ValueError, which makes it more convenient to catch and has
    basically the correct semantics.
    """
class SignedCookies(object):
    """Read-only mapping view over a cookie dict that verifies signatures.

    Looking up a cookie returns its unsigned value. A cookie whose signature
    does not verify is reported as missing (KeyError) rather than as an error.
    """

    def __init__(self, signer, cookies, extra_key=''):
        """Wrap `cookies`.

        signer: object providing unsign(value, extra_key=...).
        cookies: mapping of cookie name -> signed cookie value.
        extra_key: prefix combined with the cookie name to form the
            per-cookie extra key passed to the signer.
        """
        self.signer = signer
        # The original assigned the undefined name `request` here, which made
        # every construction fail with NameError.
        self.cookies = cookies
        self.extra_key = extra_key

    def __getitem__(self, name):
        """Return the verified value of cookie `name`.

        Raises KeyError if the cookie is absent or its signature is invalid.
        """
        cookie = self.cookies[name]
        try:
            return self.signer.unsign(cookie, extra_key=self.extra_key + name)
        except BadSignature:
            # Treat a tampered cookie exactly like a missing one.
            raise KeyError(name)

    def get(self, name, default=None):
        """Return the verified value of cookie `name`, or `default`."""
        try:
            return self[name]
        except KeyError:
            return default
class Signer(object):
    """Sign and verify byte strings with an HMAC.

    A signed value has the form ``<value><separator><signature>`` where the
    signature is the unpadded URL-safe base64 encoding of the HMAC digest of
    the value. dumps()/loads() additionally serialize (and optionally
    compress) arbitrary objects before signing.
    """

    def __init__(self, key=None, compress=False, serializer=simplejson, hash_constructor=sha_constructor, separator='.'):
        """Configure the signer.

        key: HMAC secret; settings.SECRET_KEY is used when it is falsy.
        compress: if true, dumps() zlib-compresses the payload when that
            makes it shorter.
        serializer: object exposing dumps(obj) and loads(s).
        hash_constructor: digest constructor handed to hmac.new().
        separator: string placed between the value and its signature.
        """
        self.key = key or settings.SECRET_KEY
        self.compress = compress
        self.serializer = serializer
        self.hash_constructor = hash_constructor
        self.separator = separator

    @staticmethod
    def _constant_time_compare(a, b):
        """Return True if strings `a` and `b` are equal.

        Every character pair is examined when the lengths match, so the
        running time does not reveal the position of the first mismatch.
        """
        if len(a) != len(b):
            return False
        diff = 0
        for x, y in zip(a, b):
            diff |= ord(x) ^ ord(y)
        return diff == 0

    def split(self, signed_value):
        """Split at the last separator; returns a list of one or two parts."""
        return signed_value.rsplit(self.separator, 1)

    def join(self, value, signature):
        """Return value + separator + signature."""
        return value + self.separator + signature

    def encode(self, s):
        """URL-safe base64-encode `s` and strip the trailing '=' padding."""
        return base64.urlsafe_b64encode(s).rstrip('=')

    def decode(self, s):
        """Inverse of encode(): restore the '=' padding and base64-decode."""
        pad = len(s) % 4
        if pad:
            s += '=' * (4 - pad)
        return base64.urlsafe_b64decode(s)

    def signature(self, value, extra_key=''):
        """Return the encoded HMAC of `value` keyed with key + extra_key."""
        return self.encode(hmac.new(self.key + extra_key, value, self.hash_constructor).digest())

    def sign(self, value, extra_key=''):
        """Return `value` with its signature appended.

        Raises TypeError if `value` is not a `str`.
        """
        if not isinstance(value, str):
            raise TypeError('sign() needs bytestring, got %s: %s' % (type(value), repr(value)))
        return self.join(value, self.signature(value, extra_key=extra_key))

    def unsign(self, signed_value, extra_key=''):
        """Verify `signed_value` and return the original value.

        Raises TypeError if `signed_value` is not a `str`, and BadSignature
        if the separator is missing or the signature does not match.
        """
        if not isinstance(signed_value, str):
            raise TypeError('unsign() needs bytestring, got %s: %s' % (type(signed_value), repr(signed_value)))
        try:
            value, signature = self.split(signed_value)
        except (ValueError, TypeError):
            # No separator present: split() returned a single element.
            raise BadSignature()
        expected = self.signature(value, extra_key=extra_key)
        # `signature` comes from untrusted input; a plain == would stop at
        # the first differing character and leak timing information.
        if self._constant_time_compare(expected, signature):
            return value
        raise BadSignature()

    def dumps(self, obj, extra_key=''):
        """Serialize `obj`, optionally compress it, encode it and sign it.

        A compressed payload is marked by prefixing the encoded string with
        the separator before signing.
        """
        s = self.serializer.dumps(obj)
        is_compressed = False  # Flag for if it's been compressed or not
        if self.compress:
            import zlib  # Avoid zlib dependency unless compress is being used
            compressed = zlib.compress(s)
            # Only keep the compressed form if it saves more than the one
            # character spent on the marker.
            if len(compressed) < (len(s) - 1):
                s = compressed
                is_compressed = True
        s = self.encode(s)
        if is_compressed:
            s = self.separator + s
        return self.sign(s, extra_key=extra_key)

    def loads(self, s, extra_key=''):
        """Verify `s` and return the deserialized object (inverse of dumps()).

        Raises whatever unsign() raises for an invalid signature.
        """
        value = self.unsign(s, extra_key=extra_key)
        # A leading separator is the "compressed" marker written by dumps().
        if value and value[0] == self.separator:
            value = value[1:]
            is_compressed = True
        else:
            is_compressed = False
        value = self.decode(value)
        if is_compressed:
            import zlib
            value = zlib.decompress(value)
        return self.serializer.loads(value)

    def get_cookie(self, request, name, extra_key=''):
        """Return the verified value of cookie `name` from request.COOKIES.

        The cookie name is appended to `extra_key` so a signed value cannot
        be moved to a differently named cookie.
        """
        cookie = request.COOKIES[name]
        return self.unsign(cookie, extra_key=extra_key + name)

    def set_cookie(self, response, name, value, extra_key='', **kwargs):
        """Sign `value` and set it on `response` as cookie `name`.

        Remaining keyword arguments are forwarded to response.set_cookie().
        """
        cookie = self.sign(value, extra_key=extra_key + name)
        response.set_cookie(name, cookie, **kwargs)

    def sign_cookies(self, extra_key=''):
        """Return a view decorator that transparently signs cookies.

        The decorated view sees request.COOKIES wrapped in SignedCookies,
        and every cookie on the returned response is signed.
        """
        def decorator(view_func):
            @wraps(view_func)
            def decorated(request, *args, **kwargs):
                request.COOKIES = SignedCookies(self, request.COOKIES, extra_key=extra_key)
                response = view_func(request, *args, **kwargs)
                for name in list(response.cookies.keys()):
                    # NOTE(review): assumes response.cookies is a
                    # SimpleCookie-style container whose items are Morsel
                    # objects (as on Django's HttpResponse) - confirm.
                    # sign() only accepts str, so sign the cookie's .value
                    # rather than the cookie object itself.
                    plain = response.cookies[name].value
                    response.cookies[name] = self.sign(plain, extra_key=extra_key + name)
                return response
            return decorated
        return decorator
class SignatureExpired(BadSignature):
    """BadSignature subclass for a signature whose timestamp is too old."""
class TimestampedSigner(Signer):
    """Signer that embeds the signing time and rejects values older than ttl.

    The value that gets signed is ``<value><separator><timestamp>`` where the
    timestamp is the encoded, struct-packed integer time of signing.
    """

    def __init__(self, ttl=60, **kwargs):
        """ttl: maximum accepted age in seconds; kwargs go to Signer."""
        self.ttl = ttl
        super(TimestampedSigner, self).__init__(**kwargs)

    def sign(self, value, extra_key=''):
        """Append the current timestamp to `value`, then sign the result."""
        # NOTE: format 'I' uses the platform's native byte order and size.
        timestamp = self.encode(struct.pack('I', int(time.time())))
        value = self.join(value, timestamp)
        return super(TimestampedSigner, self).sign(value, extra_key=extra_key)

    def unsign(self, s, extra_key=''):
        """Verify `s`, check its age and return the original value.

        Raises BadSignature if the signature is invalid or the embedded
        timestamp is missing or malformed, and SignatureExpired if the value
        is older than ttl seconds.
        """
        value = super(TimestampedSigner, self).unsign(s, extra_key=extra_key)
        try:
            value, timestamp = self.split(value)
            timestamp = struct.unpack('I', self.decode(timestamp))[0]
        except (ValueError, TypeError, struct.error):
            # A validly signed value without a well-formed timestamp (for
            # example one produced by a plain Signer with the same key) must
            # surface as BadSignature, not as an unrelated exception.
            raise BadSignature()
        if self.ttl < int(time.time()) - timestamp:
            raise SignatureExpired()
        return value

    def set_cookie(self, response, name, value, extra_key='', max_age=None, **kwargs):
        """Set a signed cookie whose max_age never exceeds ttl.

        When max_age is None the cookie lifetime defaults to ttl.
        """
        if max_age is None:
            # min(None, ttl) does not give ttl: it is None on Python 2 and a
            # TypeError on Python 3.
            max_age = self.ttl
        else:
            max_age = min(max_age, self.ttl)
        super(TimestampedSigner, self).set_cookie(response, name, value, extra_key=extra_key, max_age=max_age, **kwargs)
# shortcut functions
# Shared Signer built with default arguments; the module-level sign(),
# unsign(), dumps() and loads() delegate to it. With no explicit key the
# Signer reads settings.SECRET_KEY, so this happens at import time.
_default_signer = Signer()
def sign(value, extra_key=''):
    """Sign `value` using the module-level default signer."""
    signer = _default_signer
    return signer.sign(value, extra_key=extra_key)
def unsign(signed_value, extra_key=''):
    """Verify `signed_value` using the module-level default signer."""
    signer = _default_signer
    return signer.unsign(signed_value, extra_key=extra_key)
def dumps(value, extra_key=''):
    """Serialize and sign `value` using the module-level default signer."""
    signer = _default_signer
    return signer.dumps(value, extra_key=extra_key)
def loads(signed_value, extra_key=''):
    """Verify and deserialize `signed_value` using the default signer."""
    signer = _default_signer
    return signer.loads(signed_value, extra_key=extra_key)
# tests
if __name__ == "__main__":
    import unittest, pickle, hashlib

    class Test(unittest.TestCase):
        """Round-trip and tamper checks for Signer and TimestampedSigner."""

        # Inputs for sign()/unsign() and for dumps()/loads() respectively.
        sign_tests = ["a", "123", "\00\01\02", "foo\t.bar", "\n\n", "1.302.55", '!"$%&/()=?`"', "x"*1000]
        dumps_tests = [1, 2.3, False, True, None, {'foo': 'bar'}, {'foo': 'x'*1000}, ['foo', 'bar', 'baz'], u"fööbär"]

        def _test_signer(self, signer):
            """Exercise both method pairs of `signer` against their samples."""
            for sign_name, unsign_name in (('sign', 'unsign'), ('dumps', 'loads')):
                samples = getattr(self, "%s_tests" % sign_name)
                do_sign = getattr(signer, sign_name)
                do_unsign = getattr(signer, unsign_name)
                for sample in samples:
                    # Plain round trip.
                    signed = do_sign(sample)
                    self.assertEqual(sample, do_unsign(signed))
                    # Round trip with an extra key.
                    extra_signed = do_sign(sample, extra_key='foo')
                    self.assertEqual(sample, do_unsign(extra_signed, extra_key='foo'))
                    # Mismatched extra keys must be rejected both ways.
                    self.assertRaises(BadSignature, do_unsign, extra_signed)
                    self.assertRaises(BadSignature, do_unsign, signed, extra_key='foo')
                    # Unsigned, truncated and separator-less inputs.
                    tampered = (
                        repr(sample),
                        signed[:-3],
                        signed[1:],
                        signed.replace('.', '').replace('#', ''),
                    )
                    for bad_signed in tampered:
                        self.assertRaises(BadSignature, do_unsign, bad_signed)

        def test(self):
            """Run the checks for every combination of constructor options."""
            options = dict(key='SECRET', compress=True, serializer=pickle, hash_constructor=hashlib.md5, separator='#')
            # Build every subset of `options`: each option doubles the list
            # by adding a copy of every existing combination with it set.
            kwargs_list = [{}]
            for key, value in options.items():
                kwargs_list += [dict(existing, **{key: value}) for existing in kwargs_list]
            for cls in (Signer, TimestampedSigner):
                for kwargs in kwargs_list:
                    self._test_signer(cls(**kwargs))
            # A value signed with a one second ttl must expire after two.
            t_signer = TimestampedSigner(ttl=1)
            signed = t_signer.sign("foo")
            time.sleep(2)
            self.assertRaises(SignatureExpired, t_signer.unsign, signed)

    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment